Implement WiFi-DensePose system with CSI data extraction and router interface

- Added CSIExtractor class for extracting CSI data from WiFi routers.
- Implemented RouterInterface class for SSH communication with routers.
- Developed DensePoseHead class for body part segmentation and UV coordinate regression.
- Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling.
- Integrated paramiko for SSH connections and command execution.
- Established configuration validation for both extractor and router interface.
- Added context manager support for resource management in both classes.
This commit is contained in:
rUv
2025-06-07 05:55:27 +00:00
parent 44e5382931
commit cbebdd648f
14 changed files with 2871 additions and 213 deletions

View File

@@ -3,27 +3,27 @@ import torch
import torch.nn as nn
import numpy as np
from unittest.mock import Mock, patch
from src.models.modality_translation import ModalityTranslationNetwork
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
class TestModalityTranslationNetwork:
"""Test suite for Modality Translation Network following London School TDD principles"""
@pytest.fixture
def mock_csi_input(self):
"""Generate synthetic CSI input tensor for testing"""
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
return torch.randn(2, 3, 56, 100)
@pytest.fixture
def mock_config(self):
"""Configuration for modality translation network"""
return {
'input_channels': 3,
'hidden_dim': 256,
'output_dim': 512,
'num_layers': 3,
'dropout_rate': 0.1
'input_channels': 6, # Real and imaginary parts for 3 antennas
'hidden_channels': [64, 128, 256],
'output_channels': 256,
'kernel_size': 3,
'stride': 1,
'padding': 1,
'dropout_rate': 0.1,
'activation': 'relu',
'normalization': 'batch',
'use_attention': True,
'attention_heads': 8
}
@pytest.fixture
@@ -31,98 +31,263 @@ class TestModalityTranslationNetwork:
"""Create modality translation network instance for testing"""
return ModalityTranslationNetwork(mock_config)
@pytest.fixture
def mock_csi_input(self):
"""Generate mock CSI input tensor"""
batch_size = 4
channels = 6 # Real and imaginary parts for 3 antennas
height = 56 # Number of subcarriers
width = 100 # Time samples
return torch.randn(batch_size, channels, height, width)
@pytest.fixture
def mock_target_features(self):
"""Generate mock target feature tensor for training"""
batch_size = 4
feature_dim = 256
spatial_height = 56
spatial_width = 100
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
def test_network_initialization_creates_correct_architecture(self, mock_config):
"""Test that network initializes with correct architecture"""
"""Test that modality translation network initializes with correct architecture"""
# Act
network = ModalityTranslationNetwork(mock_config)
# Assert
assert network is not None
assert isinstance(network, nn.Module)
assert network.input_channels == mock_config['input_channels']
assert network.output_channels == mock_config['output_channels']
assert network.use_attention == mock_config['use_attention']
assert hasattr(network, 'encoder')
assert hasattr(network, 'decoder')
assert network.input_channels == mock_config['input_channels']
assert network.hidden_dim == mock_config['hidden_dim']
assert network.output_dim == mock_config['output_dim']
if mock_config['use_attention']:
assert hasattr(network, 'attention')
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
"""Test that forward pass produces correctly shaped output"""
# Act
with torch.no_grad():
output = translation_network(mock_csi_input)
output = translation_network(mock_csi_input)
# Assert
assert output is not None
assert isinstance(output, torch.Tensor)
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
assert output.shape[1] == translation_network.output_dim # Correct output dimension
assert len(output.shape) == 4 # Should maintain spatial dimensions
assert output.shape[1] == translation_network.output_channels # Correct output channels
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
def test_forward_pass_handles_different_batch_sizes(self, translation_network):
"""Test that network handles different batch sizes correctly"""
def test_forward_pass_handles_different_input_sizes(self, translation_network):
"""Test that forward pass handles different input sizes"""
# Arrange
batch_sizes = [1, 4, 8]
small_input = torch.randn(2, 6, 28, 50)
large_input = torch.randn(8, 6, 112, 200)
for batch_size in batch_sizes:
input_tensor = torch.randn(batch_size, 3, 56, 100)
# Act
with torch.no_grad():
output = translation_network(input_tensor)
# Assert
assert output.shape[0] == batch_size
assert output.shape[1] == translation_network.output_dim
# Act
small_output = translation_network(small_input)
large_output = translation_network(large_input)
# Assert
assert small_output.shape == (2, 256, 28, 50)
assert large_output.shape == (8, 256, 112, 200)
def test_network_is_trainable(self, translation_network, mock_csi_input):
"""Test that network parameters are trainable"""
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
"""Test that encoder extracts hierarchical features"""
# Act
features = translation_network.encode(mock_csi_input)
# Assert
assert features is not None
assert isinstance(features, list)
assert len(features) == len(translation_network.encoder)
# Check feature map sizes decrease with depth
for i in range(1, len(features)):
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
"""Test that decoder reconstructs target feature representation"""
# Arrange
criterion = nn.MSELoss()
encoded_features = translation_network.encode(mock_csi_input)
# Act
decoded_output = translation_network.decode(encoded_features)
# Assert
assert decoded_output is not None
assert isinstance(decoded_output, torch.Tensor)
assert decoded_output.shape[1] == translation_network.output_channels
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
"""Test that attention mechanism enhances feature representation"""
# Arrange
config_with_attention = mock_config.copy()
config_with_attention['use_attention'] = True
config_without_attention = mock_config.copy()
config_without_attention['use_attention'] = False
network_with_attention = ModalityTranslationNetwork(config_with_attention)
network_without_attention = ModalityTranslationNetwork(config_without_attention)
# Act
output_with_attention = network_with_attention(mock_csi_input)
output_without_attention = network_without_attention(mock_csi_input)
# Assert
assert output_with_attention.shape == output_without_attention.shape
# Outputs should be different due to attention mechanism
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
"""Test that training mode enables dropout for regularization"""
# Arrange
translation_network.train()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be different due to dropout
assert not torch.allclose(output1, output2, atol=1e-6)
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
"""Test that evaluation mode disables dropout for consistent inference"""
# Arrange
translation_network.eval()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be identical in eval mode
assert torch.allclose(output1, output2, atol=1e-6)
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss measures feature alignment"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
# Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss handles different loss types"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
# Assert
assert mse_loss is not None
assert l1_loss is not None
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
"""Test that get_feature_statistics provides feature analysis"""
# Arrange
output = translation_network(mock_csi_input)
# Act
stats = translation_network.get_feature_statistics(output)
# Assert
assert stats is not None
assert isinstance(stats, dict)
assert 'mean' in stats
assert 'std' in stats
assert 'min' in stats
assert 'max' in stats
assert 'sparsity' in stats
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
"""Test that network supports gradient computation for training"""
# Arrange
translation_network.train()
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
# Act
output = translation_network(mock_csi_input)
# Create target with same shape as output
target = torch.randn_like(output)
loss = criterion(output, target)
loss = translation_network.compute_translation_loss(output, mock_target_features)
optimizer.zero_grad()
loss.backward()
# Assert
assert loss.item() > 0
# Check that gradients are computed
for param in translation_network.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_network_handles_invalid_input_shape(self, translation_network):
"""Test that network handles invalid input shapes gracefully"""
def test_network_validates_input_dimensions(self, translation_network):
"""Test that network validates input dimensions"""
# Arrange
invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
# Act & Assert
with pytest.raises(RuntimeError):
with pytest.raises(ModalityTranslationError):
translation_network(invalid_input)
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
"""Test that network supports evaluation mode"""
# Act
translation_network.eval()
def test_network_handles_batch_size_one(self, translation_network):
"""Test that network handles single sample inference"""
# Arrange
single_input = torch.randn(1, 6, 56, 100)
with torch.no_grad():
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - In eval mode with same input, outputs should be identical
assert torch.allclose(output1, output2, atol=1e-6)
def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
"""Test that network extracts meaningful features"""
# Act
with torch.no_grad():
output = translation_network(mock_csi_input)
output = translation_network(single_input)
# Assert
# Features should have reasonable statistics
assert not torch.isnan(output).any()
assert not torch.isinf(output).any()
assert output.std() > 0.01 # Features should have some variance
assert output.std() < 10.0 # But not be too extreme
assert output.shape == (1, 256, 56, 100)
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
"""Test that model state can be saved and loaded"""
# Arrange
original_output = translation_network(mock_csi_input)
# Act - Save state
state_dict = translation_network.state_dict()
# Create new network and load state
new_network = ModalityTranslationNetwork(translation_network.config)
new_network.load_state_dict(state_dict)
new_output = new_network(mock_csi_input)
# Assert
assert torch.allclose(original_output, new_output, atol=1e-6)
def test_network_configuration_validation(self):
"""Test that network validates configuration parameters"""
# Arrange
invalid_config = {
'input_channels': 0, # Invalid
'hidden_channels': [], # Invalid
'output_channels': 256
}
# Act & Assert
with pytest.raises(ValueError):
ModalityTranslationNetwork(invalid_config)
def test_feature_visualization_support(self, translation_network, mock_csi_input):
"""Test that network supports feature visualization"""
# Act
features = translation_network.get_intermediate_features(mock_csi_input)
# Assert
assert features is not None
assert isinstance(features, dict)
assert 'encoder_features' in features
assert 'decoder_features' in features
if translation_network.use_attention:
assert 'attention_weights' in features