128 lines
4.9 KiB
Python
128 lines
4.9 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from unittest.mock import Mock, patch
|
|
from src.models.modality_translation import ModalityTranslationNetwork
|
|
|
|
|
|
class TestModalityTranslationNetwork:
|
|
"""Test suite for Modality Translation Network following London School TDD principles"""
|
|
|
|
@pytest.fixture
|
|
def mock_csi_input(self):
|
|
"""Generate synthetic CSI input tensor for testing"""
|
|
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
|
|
return torch.randn(2, 3, 56, 100)
|
|
|
|
@pytest.fixture
|
|
def mock_config(self):
|
|
"""Configuration for modality translation network"""
|
|
return {
|
|
'input_channels': 3,
|
|
'hidden_dim': 256,
|
|
'output_dim': 512,
|
|
'num_layers': 3,
|
|
'dropout_rate': 0.1
|
|
}
|
|
|
|
@pytest.fixture
|
|
def translation_network(self, mock_config):
|
|
"""Create modality translation network instance for testing"""
|
|
return ModalityTranslationNetwork(mock_config)
|
|
|
|
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
|
"""Test that network initializes with correct architecture"""
|
|
# Act
|
|
network = ModalityTranslationNetwork(mock_config)
|
|
|
|
# Assert
|
|
assert network is not None
|
|
assert isinstance(network, nn.Module)
|
|
assert hasattr(network, 'encoder')
|
|
assert hasattr(network, 'decoder')
|
|
assert network.input_channels == mock_config['input_channels']
|
|
assert network.hidden_dim == mock_config['hidden_dim']
|
|
assert network.output_dim == mock_config['output_dim']
|
|
|
|
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
|
"""Test that forward pass produces correctly shaped output"""
|
|
# Act
|
|
with torch.no_grad():
|
|
output = translation_network(mock_csi_input)
|
|
|
|
# Assert
|
|
assert output is not None
|
|
assert isinstance(output, torch.Tensor)
|
|
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
|
assert output.shape[1] == translation_network.output_dim # Correct output dimension
|
|
assert len(output.shape) == 4 # Should maintain spatial dimensions
|
|
|
|
def test_forward_pass_handles_different_batch_sizes(self, translation_network):
|
|
"""Test that network handles different batch sizes correctly"""
|
|
# Arrange
|
|
batch_sizes = [1, 4, 8]
|
|
|
|
for batch_size in batch_sizes:
|
|
input_tensor = torch.randn(batch_size, 3, 56, 100)
|
|
|
|
# Act
|
|
with torch.no_grad():
|
|
output = translation_network(input_tensor)
|
|
|
|
# Assert
|
|
assert output.shape[0] == batch_size
|
|
assert output.shape[1] == translation_network.output_dim
|
|
|
|
def test_network_is_trainable(self, translation_network, mock_csi_input):
|
|
"""Test that network parameters are trainable"""
|
|
# Arrange
|
|
criterion = nn.MSELoss()
|
|
|
|
# Act
|
|
output = translation_network(mock_csi_input)
|
|
# Create target with same shape as output
|
|
target = torch.randn_like(output)
|
|
loss = criterion(output, target)
|
|
loss.backward()
|
|
|
|
# Assert
|
|
assert loss.item() > 0
|
|
# Check that gradients are computed
|
|
for param in translation_network.parameters():
|
|
if param.requires_grad:
|
|
assert param.grad is not None
|
|
|
|
def test_network_handles_invalid_input_shape(self, translation_network):
|
|
"""Test that network handles invalid input shapes gracefully"""
|
|
# Arrange
|
|
invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels
|
|
|
|
# Act & Assert
|
|
with pytest.raises(RuntimeError):
|
|
translation_network(invalid_input)
|
|
|
|
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
|
|
"""Test that network supports evaluation mode"""
|
|
# Act
|
|
translation_network.eval()
|
|
|
|
with torch.no_grad():
|
|
output1 = translation_network(mock_csi_input)
|
|
output2 = translation_network(mock_csi_input)
|
|
|
|
# Assert - In eval mode with same input, outputs should be identical
|
|
assert torch.allclose(output1, output2, atol=1e-6)
|
|
|
|
def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
|
|
"""Test that network extracts meaningful features"""
|
|
# Act
|
|
with torch.no_grad():
|
|
output = translation_network(mock_csi_input)
|
|
|
|
# Assert
|
|
# Features should have reasonable statistics
|
|
assert not torch.isnan(output).any()
|
|
assert not torch.isinf(output).any()
|
|
assert output.std() > 0.01 # Features should have some variance
|
|
assert output.std() < 10.0 # But not be too extreme |