Implement WiFi-DensePose system with CSI data extraction and router interface
- Added CSIExtractor class for extracting CSI data from WiFi routers. - Implemented RouterInterface class for SSH communication with routers. - Developed DensePoseHead class for body part segmentation and UV coordinate regression. - Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling. - Integrated paramiko for SSH connections and command execution. - Established configuration validation for both extractor and router interface. - Added context manager support for resource management in both classes.
This commit is contained in:
@@ -3,27 +3,27 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
|
||||
|
||||
|
||||
class TestModalityTranslationNetwork:
|
||||
"""Test suite for Modality Translation Network following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_input(self):
|
||||
"""Generate synthetic CSI input tensor for testing"""
|
||||
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
|
||||
return torch.randn(2, 3, 56, 100)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for modality translation network"""
|
||||
return {
|
||||
'input_channels': 3,
|
||||
'hidden_dim': 256,
|
||||
'output_dim': 512,
|
||||
'num_layers': 3,
|
||||
'dropout_rate': 0.1
|
||||
'input_channels': 6, # Real and imaginary parts for 3 antennas
|
||||
'hidden_channels': [64, 128, 256],
|
||||
'output_channels': 256,
|
||||
'kernel_size': 3,
|
||||
'stride': 1,
|
||||
'padding': 1,
|
||||
'dropout_rate': 0.1,
|
||||
'activation': 'relu',
|
||||
'normalization': 'batch',
|
||||
'use_attention': True,
|
||||
'attention_heads': 8
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
@@ -31,98 +31,263 @@ class TestModalityTranslationNetwork:
|
||||
"""Create modality translation network instance for testing"""
|
||||
return ModalityTranslationNetwork(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_input(self):
|
||||
"""Generate mock CSI input tensor"""
|
||||
batch_size = 4
|
||||
channels = 6 # Real and imaginary parts for 3 antennas
|
||||
height = 56 # Number of subcarriers
|
||||
width = 100 # Time samples
|
||||
return torch.randn(batch_size, channels, height, width)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_features(self):
|
||||
"""Generate mock target feature tensor for training"""
|
||||
batch_size = 4
|
||||
feature_dim = 256
|
||||
spatial_height = 56
|
||||
spatial_width = 100
|
||||
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
|
||||
|
||||
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
||||
"""Test that network initializes with correct architecture"""
|
||||
"""Test that modality translation network initializes with correct architecture"""
|
||||
# Act
|
||||
network = ModalityTranslationNetwork(mock_config)
|
||||
|
||||
# Assert
|
||||
assert network is not None
|
||||
assert isinstance(network, nn.Module)
|
||||
assert network.input_channels == mock_config['input_channels']
|
||||
assert network.output_channels == mock_config['output_channels']
|
||||
assert network.use_attention == mock_config['use_attention']
|
||||
assert hasattr(network, 'encoder')
|
||||
assert hasattr(network, 'decoder')
|
||||
assert network.input_channels == mock_config['input_channels']
|
||||
assert network.hidden_dim == mock_config['hidden_dim']
|
||||
assert network.output_dim == mock_config['output_dim']
|
||||
if mock_config['use_attention']:
|
||||
assert hasattr(network, 'attention')
|
||||
|
||||
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
||||
"""Test that forward pass produces correctly shaped output"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
output = translation_network(mock_csi_input)
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert output is not None
|
||||
assert isinstance(output, torch.Tensor)
|
||||
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
||||
assert output.shape[1] == translation_network.output_dim # Correct output dimension
|
||||
assert len(output.shape) == 4 # Should maintain spatial dimensions
|
||||
assert output.shape[1] == translation_network.output_channels # Correct output channels
|
||||
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
|
||||
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
|
||||
|
||||
def test_forward_pass_handles_different_batch_sizes(self, translation_network):
|
||||
"""Test that network handles different batch sizes correctly"""
|
||||
def test_forward_pass_handles_different_input_sizes(self, translation_network):
|
||||
"""Test that forward pass handles different input sizes"""
|
||||
# Arrange
|
||||
batch_sizes = [1, 4, 8]
|
||||
small_input = torch.randn(2, 6, 28, 50)
|
||||
large_input = torch.randn(8, 6, 112, 200)
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
input_tensor = torch.randn(batch_size, 3, 56, 100)
|
||||
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
output = translation_network(input_tensor)
|
||||
|
||||
# Assert
|
||||
assert output.shape[0] == batch_size
|
||||
assert output.shape[1] == translation_network.output_dim
|
||||
# Act
|
||||
small_output = translation_network(small_input)
|
||||
large_output = translation_network(large_input)
|
||||
|
||||
# Assert
|
||||
assert small_output.shape == (2, 256, 28, 50)
|
||||
assert large_output.shape == (8, 256, 112, 200)
|
||||
|
||||
def test_network_is_trainable(self, translation_network, mock_csi_input):
|
||||
"""Test that network parameters are trainable"""
|
||||
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
|
||||
"""Test that encoder extracts hierarchical features"""
|
||||
# Act
|
||||
features = translation_network.encode(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
assert len(features) == len(translation_network.encoder)
|
||||
|
||||
# Check feature map sizes decrease with depth
|
||||
for i in range(1, len(features)):
|
||||
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
|
||||
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
|
||||
|
||||
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
|
||||
"""Test that decoder reconstructs target feature representation"""
|
||||
# Arrange
|
||||
criterion = nn.MSELoss()
|
||||
encoded_features = translation_network.encode(mock_csi_input)
|
||||
|
||||
# Act
|
||||
decoded_output = translation_network.decode(encoded_features)
|
||||
|
||||
# Assert
|
||||
assert decoded_output is not None
|
||||
assert isinstance(decoded_output, torch.Tensor)
|
||||
assert decoded_output.shape[1] == translation_network.output_channels
|
||||
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
|
||||
|
||||
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
|
||||
"""Test that attention mechanism enhances feature representation"""
|
||||
# Arrange
|
||||
config_with_attention = mock_config.copy()
|
||||
config_with_attention['use_attention'] = True
|
||||
|
||||
config_without_attention = mock_config.copy()
|
||||
config_without_attention['use_attention'] = False
|
||||
|
||||
network_with_attention = ModalityTranslationNetwork(config_with_attention)
|
||||
network_without_attention = ModalityTranslationNetwork(config_without_attention)
|
||||
|
||||
# Act
|
||||
output_with_attention = network_with_attention(mock_csi_input)
|
||||
output_without_attention = network_without_attention(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert output_with_attention.shape == output_without_attention.shape
|
||||
# Outputs should be different due to attention mechanism
|
||||
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
|
||||
|
||||
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
|
||||
"""Test that training mode enables dropout for regularization"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
|
||||
# Act
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - outputs should be different due to dropout
|
||||
assert not torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
|
||||
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||
# Arrange
|
||||
translation_network.eval()
|
||||
|
||||
# Act
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - outputs should be identical in eval mode
|
||||
assert torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that compute_translation_loss measures feature alignment"""
|
||||
# Arrange
|
||||
predicted_features = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that compute_translation_loss handles different loss types"""
|
||||
# Arrange
|
||||
predicted_features = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
|
||||
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
|
||||
|
||||
# Assert
|
||||
assert mse_loss is not None
|
||||
assert l1_loss is not None
|
||||
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
|
||||
|
||||
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
|
||||
"""Test that get_feature_statistics provides feature analysis"""
|
||||
# Arrange
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
stats = translation_network.get_feature_statistics(output)
|
||||
|
||||
# Assert
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
assert 'mean' in stats
|
||||
assert 'std' in stats
|
||||
assert 'min' in stats
|
||||
assert 'max' in stats
|
||||
assert 'sparsity' in stats
|
||||
|
||||
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that network supports gradient computation for training"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
|
||||
|
||||
# Act
|
||||
output = translation_network(mock_csi_input)
|
||||
# Create target with same shape as output
|
||||
target = torch.randn_like(output)
|
||||
loss = criterion(output, target)
|
||||
loss = translation_network.compute_translation_loss(output, mock_target_features)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Assert
|
||||
assert loss.item() > 0
|
||||
# Check that gradients are computed
|
||||
for param in translation_network.parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||
|
||||
def test_network_handles_invalid_input_shape(self, translation_network):
|
||||
"""Test that network handles invalid input shapes gracefully"""
|
||||
def test_network_validates_input_dimensions(self, translation_network):
|
||||
"""Test that network validates input dimensions"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels
|
||||
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(ModalityTranslationError):
|
||||
translation_network(invalid_input)
|
||||
|
||||
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
|
||||
"""Test that network supports evaluation mode"""
|
||||
# Act
|
||||
translation_network.eval()
|
||||
def test_network_handles_batch_size_one(self, translation_network):
|
||||
"""Test that network handles single sample inference"""
|
||||
# Arrange
|
||||
single_input = torch.randn(1, 6, 56, 100)
|
||||
|
||||
with torch.no_grad():
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - In eval mode with same input, outputs should be identical
|
||||
assert torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
|
||||
"""Test that network extracts meaningful features"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
output = translation_network(mock_csi_input)
|
||||
output = translation_network(single_input)
|
||||
|
||||
# Assert
|
||||
# Features should have reasonable statistics
|
||||
assert not torch.isnan(output).any()
|
||||
assert not torch.isinf(output).any()
|
||||
assert output.std() > 0.01 # Features should have some variance
|
||||
assert output.std() < 10.0 # But not be too extreme
|
||||
assert output.shape == (1, 256, 56, 100)
|
||||
|
||||
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
|
||||
"""Test that model state can be saved and loaded"""
|
||||
# Arrange
|
||||
original_output = translation_network(mock_csi_input)
|
||||
|
||||
# Act - Save state
|
||||
state_dict = translation_network.state_dict()
|
||||
|
||||
# Create new network and load state
|
||||
new_network = ModalityTranslationNetwork(translation_network.config)
|
||||
new_network.load_state_dict(state_dict)
|
||||
new_output = new_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert torch.allclose(original_output, new_output, atol=1e-6)
|
||||
|
||||
def test_network_configuration_validation(self):
|
||||
"""Test that network validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'input_channels': 0, # Invalid
|
||||
'hidden_channels': [], # Invalid
|
||||
'output_channels': 256
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
ModalityTranslationNetwork(invalid_config)
|
||||
|
||||
def test_feature_visualization_support(self, translation_network, mock_csi_input):
|
||||
"""Test that network supports feature visualization"""
|
||||
# Act
|
||||
features = translation_network.get_intermediate_features(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, dict)
|
||||
assert 'encoder_features' in features
|
||||
assert 'decoder_features' in features
|
||||
if translation_network.use_attention:
|
||||
assert 'attention_weights' in features
|
||||
Reference in New Issue
Block a user