# wifi-densepose/tests/unit/test_modality_translation.py
# (web file-viewer metadata: 128 lines, 4.9 KiB, Python)

import pytest
import torch
import torch.nn as nn
import numpy as np
from unittest.mock import Mock, patch
from src.models.modality_translation import ModalityTranslationNetwork
class TestModalityTranslationNetwork:
    """Test suite for Modality Translation Network following London School TDD principles"""

    # Geometry of the synthetic CSI tensors: 56 subcarriers x 100 temporal
    # samples per antenna channel (values taken from the original fixture).
    NUM_SUBCARRIERS = 56
    NUM_TIME_SAMPLES = 100

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's global RNG before every test in this class.

        The synthetic inputs are drawn with ``torch.randn``; weight init and
        dropout presumably draw from the same global RNG. Without a fixed
        seed, the statistical assertions below (std bounds, non-zero loss)
        could pass or fail from run to run. Autouse fixtures run before the
        other function-scoped fixtures, so the seed is set first.
        """
        torch.manual_seed(0)

    @pytest.fixture
    def mock_csi_input(self, mock_config):
        """Generate synthetic CSI input tensor for testing"""
        # Batch size 2, `input_channels` (3) antennas, 56 subcarriers,
        # 100 temporal samples.
        return torch.randn(
            2,
            mock_config['input_channels'],
            self.NUM_SUBCARRIERS,
            self.NUM_TIME_SAMPLES,
        )

    @pytest.fixture
    def mock_config(self):
        """Configuration for modality translation network"""
        return {
            'input_channels': 3,
            'hidden_dim': 256,
            'output_dim': 512,
            'num_layers': 3,
            'dropout_rate': 0.1,
        }

    @pytest.fixture
    def translation_network(self, mock_config):
        """Create modality translation network instance for testing"""
        return ModalityTranslationNetwork(mock_config)

    def test_network_initialization_creates_correct_architecture(self, mock_config):
        """Test that network initializes with correct architecture"""
        # Act
        network = ModalityTranslationNetwork(mock_config)

        # Assert
        assert network is not None
        assert isinstance(network, nn.Module)
        assert hasattr(network, 'encoder')
        assert hasattr(network, 'decoder')
        assert network.input_channels == mock_config['input_channels']
        assert network.hidden_dim == mock_config['hidden_dim']
        assert network.output_dim == mock_config['output_dim']

    def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
        """Test that forward pass produces correctly shaped output"""
        # Act
        with torch.no_grad():
            output = translation_network(mock_csi_input)

        # Assert
        assert output is not None
        assert isinstance(output, torch.Tensor)
        assert output.shape[0] == mock_csi_input.shape[0]  # Batch size preserved
        assert output.shape[1] == translation_network.output_dim  # Correct output dimension
        assert output.dim() == 4  # Should maintain spatial dimensions

    def test_forward_pass_handles_different_batch_sizes(self, translation_network, mock_config):
        """Test that network handles different batch sizes correctly"""
        # Arrange
        batch_sizes = [1, 4, 8]
        channels = mock_config['input_channels']

        for batch_size in batch_sizes:
            input_tensor = torch.randn(
                batch_size, channels, self.NUM_SUBCARRIERS, self.NUM_TIME_SAMPLES
            )

            # Act
            with torch.no_grad():
                output = translation_network(input_tensor)

            # Assert
            assert output.shape[0] == batch_size
            assert output.shape[1] == translation_network.output_dim

    def test_network_is_trainable(self, translation_network, mock_csi_input):
        """Test that network parameters are trainable"""
        # Arrange
        criterion = nn.MSELoss()
        # Be explicit: gradients are checked in training mode.
        translation_network.train()

        # Act
        output = translation_network(mock_csi_input)
        # Create target with same shape as output
        target = torch.randn_like(output)
        loss = criterion(output, target)
        loss.backward()

        # Assert
        assert loss.item() > 0
        # Every trainable parameter must receive a usable gradient. A missing
        # gradient means the parameter is disconnected from the output; a
        # non-finite one means training would diverge immediately.
        for name, param in translation_network.named_parameters():
            if not param.requires_grad:
                continue
            assert param.grad is not None, f"no gradient for parameter {name!r}"
            assert torch.isfinite(param.grad).all(), f"non-finite gradient for {name!r}"

    def test_network_handles_invalid_input_shape(self, translation_network, mock_config):
        """Test that network handles invalid input shapes gracefully"""
        # Arrange: wrong number of channels (5 instead of the configured 3).
        invalid_input = torch.randn(
            2,
            mock_config['input_channels'] + 2,
            self.NUM_SUBCARRIERS,
            self.NUM_TIME_SAMPLES,
        )

        # Act & Assert
        # RuntimeError is what torch layers raise on a channel mismatch;
        # ValueError is what an explicit input-validation check would raise.
        # Either counts as a clean rejection of the bad input.
        with pytest.raises((RuntimeError, ValueError)):
            translation_network(invalid_input)

    def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
        """Test that network supports evaluation mode"""
        # Act
        translation_network.eval()
        with torch.no_grad():
            output1 = translation_network(mock_csi_input)
            output2 = translation_network(mock_csi_input)

        # Assert
        assert not translation_network.training
        # In eval mode with the same input, outputs should be identical.
        assert torch.allclose(output1, output2, atol=1e-6)

    def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
        """Test that network extracts meaningful features"""
        # Act
        with torch.no_grad():
            output = translation_network(mock_csi_input)

        # Assert
        # Features should be finite...
        assert not torch.isnan(output).any()
        assert not torch.isinf(output).any()
        # ...with some variance, but not extreme values.
        std = output.std().item()
        assert std > 0.01, f"output std {std} too small; features look collapsed"
        assert std < 10.0, f"output std {std} too large; features look unstable"