Implement WiFi-DensePose system with CSI data extraction and router interface
- Added CSIExtractor class for extracting CSI data from WiFi routers.
- Implemented RouterInterface class for SSH communication with routers.
- Developed DensePoseHead class for body part segmentation and UV coordinate regression.
- Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling.
- Integrated paramiko for SSH connections and command execution.
- Established configuration validation for both extractor and router interface.
- Added context manager support for resource management in both classes.
This commit is contained in:
353
tests/integration/test_csi_pipeline.py
Normal file
353
tests/integration/test_csi_pipeline.py
Normal file
@@ -0,0 +1,353 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.hardware.router_interface import RouterInterface
|
||||
from src.hardware.csi_extractor import CSIExtractor
|
||||
|
||||
|
||||
class TestCSIPipeline:
    """Integration tests for the CSI processing pipeline.

    Follows London School TDD: the hardware-facing collaborators
    (``RouterInterface``, ``CSIExtractor``) are mocked with
    ``unittest.mock.patch.object``, while the pure processing stages
    (``CSIProcessor``, ``PhaseSanitizer``) execute for real.
    """

    @pytest.fixture
    def mock_router_config(self):
        """Configuration for router interface."""
        return {
            'router_ip': '192.168.1.1',
            'username': 'admin',
            'password': 'password',
            'ssh_port': 22,
            'timeout': 30,
            'max_retries': 3
        }

    @pytest.fixture
    def mock_extractor_config(self):
        """Configuration for CSI extractor."""
        return {
            'interface': 'wlan0',
            'channel': 6,
            'bandwidth': 20,
            'antenna_count': 3,
            'subcarrier_count': 56,
            'sample_rate': 1000,
            'buffer_size': 1024
        }

    @pytest.fixture
    def mock_processor_config(self):
        """Configuration for CSI processor."""
        return {
            'window_size': 100,
            'overlap': 0.5,
            'filter_type': 'butterworth',
            'filter_order': 4,
            'cutoff_frequency': 50,
            'normalization': 'minmax',
            'outlier_threshold': 3.0
        }

    @pytest.fixture
    def mock_sanitizer_config(self):
        """Configuration for phase sanitizer."""
        return {
            'unwrap_method': 'numpy',
            'smoothing_window': 5,
            'outlier_threshold': 2.0,
            'interpolation_method': 'linear',
            'phase_correction': True
        }

    @pytest.fixture
    def csi_pipeline_components(self, mock_router_config, mock_extractor_config,
                                mock_processor_config, mock_sanitizer_config):
        """Create real pipeline components for testing.

        Note: constructing RouterInterface/CSIExtractor here assumes their
        __init__ only validates config and does not open connections —
        TODO confirm against the hardware classes.
        """
        router = RouterInterface(mock_router_config)
        extractor = CSIExtractor(mock_extractor_config)
        processor = CSIProcessor(mock_processor_config)
        sanitizer = PhaseSanitizer(mock_sanitizer_config)

        return {
            'router': router,
            'extractor': extractor,
            'processor': processor,
            'sanitizer': sanitizer
        }

    @pytest.fixture
    def mock_raw_csi_data(self):
        """Generate mock raw CSI data.

        Shape convention used throughout this module:
        (batch, antennas, subcarriers, time_samples), complex-valued.
        """
        batch_size = 10
        antennas = 3
        subcarriers = 56
        time_samples = 100

        # Generate complex CSI data
        real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
        imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)

        return {
            'csi_data': real_part + 1j * imag_part,
            'timestamps': np.linspace(0, 1, time_samples),
            'metadata': {
                'channel': 6,
                'bandwidth': 20,
                'rssi': -45,
                'noise_floor': -90
            }
        }

    def test_end_to_end_csi_pipeline_processes_data_correctly(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that end-to-end CSI pipeline processes data correctly."""
        # Arrange
        router = csi_pipeline_components['router']
        extractor = csi_pipeline_components['extractor']
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        # Mock the hardware extraction so no real router/NIC is needed.
        with patch.object(extractor, 'extract_csi_data', return_value=mock_raw_csi_data):
            with patch.object(router, 'connect', return_value=True):
                with patch.object(router, 'configure_monitor_mode', return_value=True):

                    # Act - Run the pipeline
                    # 1. Connect to router and configure
                    router.connect()
                    router.configure_monitor_mode('wlan0', 6)

                    # 2. Extract CSI data (mocked)
                    raw_data = extractor.extract_csi_data()

                    # 3. Process CSI data (real processing stage)
                    processed_data = processor.process_csi_batch(raw_data['csi_data'])

                    # 4. Sanitize phase information (real processing stage)
                    sanitized_data = sanitizer.sanitize_phase_batch(processed_data)

        # Assert
        assert raw_data is not None
        assert processed_data is not None
        assert sanitized_data is not None

        # Check data flow integrity: both stages should produce tensors of
        # matching shape so downstream consumers can treat them uniformly.
        assert isinstance(processed_data, torch.Tensor)
        assert isinstance(sanitized_data, torch.Tensor)
        assert processed_data.shape == sanitized_data.shape

    def test_pipeline_handles_hardware_connection_failure(self, csi_pipeline_components):
        """Test that pipeline handles hardware connection failures gracefully."""
        # Arrange
        router = csi_pipeline_components['router']

        # Mock connection failure
        with patch.object(router, 'connect', return_value=False):

            # Act & Assert
            connection_result = router.connect()
            assert connection_result is False

            # Pipeline should handle this gracefully.
            # NOTE(review): pytest.raises(Exception) is deliberately broad
            # because the concrete exception type raised when configuring an
            # unconnected router is not pinned down by the interface yet —
            # tighten this once RouterInterface defines it.
            with pytest.raises(Exception):
                router.configure_monitor_mode('wlan0', 6)

    def test_pipeline_handles_csi_extraction_timeout(self, csi_pipeline_components):
        """Test that pipeline handles CSI extraction timeouts."""
        # Arrange
        extractor = csi_pipeline_components['extractor']

        # Mock extraction timeout
        with patch.object(extractor, 'extract_csi_data', side_effect=TimeoutError("CSI extraction timeout")):

            # Act & Assert
            with pytest.raises(TimeoutError):
                extractor.extract_csi_data()

    def test_pipeline_handles_invalid_csi_data_format(self, csi_pipeline_components):
        """Test that pipeline handles invalid CSI data formats."""
        # Arrange
        processor = csi_pipeline_components['processor']

        # Invalid data format: 3-D array, missing the time dimension.
        invalid_data = np.random.randn(10, 2, 56)

        # Act & Assert
        with pytest.raises(ValueError):
            processor.process_csi_batch(invalid_data)

    def test_pipeline_maintains_data_consistency_across_stages(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline maintains data consistency across processing stages."""
        # Arrange
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        csi_data = mock_raw_csi_data['csi_data']

        # Act
        processed_data = processor.process_csi_batch(csi_data)
        sanitized_data = sanitizer.sanitize_phase_batch(processed_data)

        # Assert - sanitization must not change the leading dimensions.
        assert processed_data.shape[0] == sanitized_data.shape[0]  # Batch size preserved
        assert processed_data.shape[1] == sanitized_data.shape[1]  # Antenna count preserved
        assert processed_data.shape[2] == sanitized_data.shape[2]  # Subcarrier count preserved

        # Check that data is not corrupted (no NaN or infinite values)
        assert not torch.isnan(processed_data).any()
        assert not torch.isinf(processed_data).any()
        assert not torch.isnan(sanitized_data).any()
        assert not torch.isinf(sanitized_data).any()

    def test_pipeline_performance_meets_real_time_requirements(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline performance meets real-time processing requirements.

        NOTE(review): wall-clock thresholds are inherently machine-dependent;
        this budget (100 ms for a 10x3x56x100 batch) may need tuning on CI.
        """
        import time

        # Arrange
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        csi_data = mock_raw_csi_data['csi_data']

        # Act - Measure processing time
        start_time = time.time()

        processed_data = processor.process_csi_batch(csi_data)
        sanitized_data = sanitizer.sanitize_phase_batch(processed_data)

        end_time = time.time()
        processing_time = end_time - start_time

        # Assert - Should process within reasonable time (< 100ms for this data size)
        assert processing_time < 0.1, f"Processing took {processing_time:.3f}s, expected < 0.1s"

    def test_pipeline_handles_different_data_sizes(self, csi_pipeline_components):
        """Test that pipeline handles different CSI data sizes."""
        # Arrange
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        # Different data sizes (batch and time dimensions vary)
        small_data = np.random.randn(1, 3, 56, 50) + 1j * np.random.randn(1, 3, 56, 50)
        large_data = np.random.randn(20, 3, 56, 200) + 1j * np.random.randn(20, 3, 56, 200)

        # Act
        small_processed = processor.process_csi_batch(small_data)
        small_sanitized = sanitizer.sanitize_phase_batch(small_processed)

        large_processed = processor.process_csi_batch(large_data)
        large_sanitized = sanitizer.sanitize_phase_batch(large_processed)

        # Assert
        assert small_processed.shape == small_sanitized.shape
        assert large_processed.shape == large_sanitized.shape
        assert small_processed.shape != large_processed.shape  # Different sizes

    def test_pipeline_configuration_validation(self, mock_router_config, mock_extractor_config,
                                               mock_processor_config, mock_sanitizer_config):
        """Test that pipeline components validate configurations properly."""
        # Arrange - Invalid configurations (one invalid field each, built on
        # copies so the shared fixtures stay intact).
        invalid_router_config = mock_router_config.copy()
        invalid_router_config['router_ip'] = 'invalid_ip'

        invalid_extractor_config = mock_extractor_config.copy()
        invalid_extractor_config['antenna_count'] = 0

        invalid_processor_config = mock_processor_config.copy()
        invalid_processor_config['window_size'] = -1

        invalid_sanitizer_config = mock_sanitizer_config.copy()
        invalid_sanitizer_config['smoothing_window'] = 0

        # Act & Assert - every component must reject its invalid config.
        with pytest.raises(ValueError):
            RouterInterface(invalid_router_config)

        with pytest.raises(ValueError):
            CSIExtractor(invalid_extractor_config)

        with pytest.raises(ValueError):
            CSIProcessor(invalid_processor_config)

        with pytest.raises(ValueError):
            PhaseSanitizer(invalid_sanitizer_config)

    def test_pipeline_error_recovery_and_logging(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline handles errors gracefully and logs appropriately."""
        # Arrange
        processor = csi_pipeline_components['processor']

        # Corrupt some data to trigger error handling
        corrupted_data = mock_raw_csi_data['csi_data'].copy()
        corrupted_data[0, 0, 0, :] = np.inf  # Introduce infinite values

        # Act & Assert
        with pytest.raises(ValueError):  # Should detect and handle corrupted data
            processor.process_csi_batch(corrupted_data)

    def test_pipeline_memory_usage_optimization(self, csi_pipeline_components):
        """Test that pipeline optimizes memory usage for large datasets."""
        # Arrange
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        # Large dataset
        large_data = np.random.randn(100, 3, 56, 1000) + 1j * np.random.randn(100, 3, 56, 1000)

        # Act - Process in chunks to test memory optimization
        chunk_size = 10
        results = []

        for i in range(0, large_data.shape[0], chunk_size):
            chunk = large_data[i:i+chunk_size]
            processed_chunk = processor.process_csi_batch(chunk)
            sanitized_chunk = sanitizer.sanitize_phase_batch(processed_chunk)
            results.append(sanitized_chunk)

        # Assert
        assert len(results) == 10  # 100 samples / 10 chunk_size
        for result in results:
            assert result.shape[0] <= chunk_size

    def test_pipeline_supports_concurrent_processing(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline supports concurrent processing of multiple streams."""
        import threading
        import queue

        # Arrange
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        results_queue = queue.Queue()

        def process_stream(stream_id, data):
            # Worker: push either the sanitized tensor or the raised
            # exception so the main thread can report which stream failed.
            try:
                processed = processor.process_csi_batch(data)
                sanitized = sanitizer.sanitize_phase_batch(processed)
                results_queue.put((stream_id, sanitized))
            except Exception as e:
                results_queue.put((stream_id, e))

        # Act - Process multiple streams concurrently
        threads = []
        for i in range(3):
            thread = threading.Thread(
                target=process_stream,
                args=(i, mock_raw_csi_data['csi_data'])
            )
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join()

        # Assert
        results = []
        while not results_queue.empty():
            results.append(results_queue.get())

        assert len(results) == 3
        for stream_id, result in results:
            # Fix: check for a worker failure FIRST so a captured exception
            # surfaces with its message, instead of failing the isinstance
            # assertion with no context about what went wrong.
            assert not isinstance(result, Exception), f"Stream {stream_id} failed: {result}"
            assert isinstance(result, torch.Tensor)
|
||||
459
tests/integration/test_inference_pipeline.py
Normal file
459
tests/integration/test_inference_pipeline.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
|
||||
|
||||
class TestInferencePipeline:
    """Integration tests for the CSI-to-DensePose inference pipeline.

    Follows London School TDD principles. Unlike the CSI pipeline tests,
    nothing is mocked here: the real CSIProcessor, PhaseSanitizer,
    ModalityTranslationNetwork and DensePoseHead are wired together and
    exercised end to end.
    """

    @pytest.fixture
    def mock_csi_processor_config(self):
        """Configuration for CSI processor."""
        return {
            'window_size': 100,
            'overlap': 0.5,
            'filter_type': 'butterworth',
            'filter_order': 4,
            'cutoff_frequency': 50,
            'normalization': 'minmax',
            'outlier_threshold': 3.0
        }

    @pytest.fixture
    def mock_sanitizer_config(self):
        """Configuration for phase sanitizer."""
        return {
            'unwrap_method': 'numpy',
            'smoothing_window': 5,
            'outlier_threshold': 2.0,
            'interpolation_method': 'linear',
            'phase_correction': True
        }

    @pytest.fixture
    def mock_translation_config(self):
        """Configuration for modality translation network.

        NOTE(review): 'input_channels': 6 presumably corresponds to
        amplitude+phase across 3 antennas — confirm against
        ModalityTranslationNetwork's expected input layout.
        """
        return {
            'input_channels': 6,
            'output_channels': 256,
            'hidden_channels': [64, 128, 256],
            'kernel_sizes': [7, 5, 3],
            'strides': [2, 2, 1],
            'dropout_rate': 0.1,
            'use_attention': True,
            'attention_heads': 8,
            'use_residual': True,
            'activation': 'relu',
            'normalization': 'batch'
        }

    @pytest.fixture
    def mock_densepose_config(self):
        """Configuration for DensePose head.

        24 body parts plus background, with 2-channel (U, V) regression.
        """
        return {
            'input_channels': 256,
            'num_body_parts': 24,
            'num_uv_coordinates': 2,
            'hidden_channels': [128, 64],
            'kernel_size': 3,
            'padding': 1,
            'dropout_rate': 0.1,
            'use_deformable_conv': False,
            'use_fpn': True,
            'fpn_levels': [2, 3, 4, 5],
            'output_stride': 4
        }

    @pytest.fixture
    def inference_pipeline_components(self, mock_csi_processor_config, mock_sanitizer_config,
                                      mock_translation_config, mock_densepose_config):
        """Create real inference pipeline components for testing."""
        csi_processor = CSIProcessor(mock_csi_processor_config)
        phase_sanitizer = PhaseSanitizer(mock_sanitizer_config)
        translation_network = ModalityTranslationNetwork(mock_translation_config)
        densepose_head = DensePoseHead(mock_densepose_config)

        return {
            'csi_processor': csi_processor,
            'phase_sanitizer': phase_sanitizer,
            'translation_network': translation_network,
            'densepose_head': densepose_head
        }

    @pytest.fixture
    def mock_raw_csi_input(self):
        """Generate mock raw CSI input data.

        Shape: (batch=4, antennas=3, subcarriers=56, time_samples=100),
        complex-valued.
        """
        batch_size = 4
        antennas = 3
        subcarriers = 56
        time_samples = 100

        # Generate complex CSI data
        real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
        imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)

        return real_part + 1j * imag_part

    @pytest.fixture
    def mock_ground_truth_densepose(self):
        """Generate mock ground truth DensePose annotations.

        Segmentation labels are in [0, 24] (0 = background); UV maps are
        unconstrained random values here since only shapes/gradients matter.
        """
        batch_size = 4
        height = 224
        width = 224
        num_parts = 24

        # Segmentation masks
        seg_masks = torch.randint(0, num_parts + 1, (batch_size, height, width))

        # UV coordinates
        uv_coords = torch.randn(batch_size, 2, height, width)

        return {
            'segmentation': seg_masks,
            'uv_coordinates': uv_coords
        }

    def test_end_to_end_inference_pipeline_produces_valid_output(self, inference_pipeline_components, mock_raw_csi_input):
        """Test that end-to-end inference pipeline produces valid DensePose output."""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Set models to evaluation mode (disables dropout/batch-norm updates)
        translation_network.eval()
        densepose_head.eval()

        # Act - Run the complete inference pipeline
        with torch.no_grad():
            # 1. Process CSI data
            processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)

            # 2. Sanitize phase information
            sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)

            # 3. Translate CSI to visual features
            visual_features = translation_network(sanitized_csi)

            # 4. Generate DensePose predictions
            densepose_output = densepose_head(visual_features)

        # Assert - head returns a dict with segmentation logits and UV maps
        assert densepose_output is not None
        assert isinstance(densepose_output, dict)
        assert 'segmentation' in densepose_output
        assert 'uv_coordinates' in densepose_output

        seg_output = densepose_output['segmentation']
        uv_output = densepose_output['uv_coordinates']

        # Check output shapes
        assert seg_output.shape[0] == mock_raw_csi_input.shape[0]  # Batch size preserved
        assert seg_output.shape[1] == 25  # 24 body parts + 1 background
        assert uv_output.shape[0] == mock_raw_csi_input.shape[0]  # Batch size preserved
        assert uv_output.shape[1] == 2  # U and V coordinates

        # Check output ranges
        assert torch.all(uv_output >= 0) and torch.all(uv_output <= 1)  # UV in [0, 1]

    def test_inference_pipeline_handles_different_batch_sizes(self, inference_pipeline_components):
        """Test that inference pipeline handles different batch sizes."""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Different batch sizes (1 vs 8), same per-sample shape
        small_batch = np.random.randn(1, 3, 56, 100) + 1j * np.random.randn(1, 3, 56, 100)
        large_batch = np.random.randn(8, 3, 56, 100) + 1j * np.random.randn(8, 3, 56, 100)

        # Set models to evaluation mode
        translation_network.eval()
        densepose_head.eval()

        # Act
        with torch.no_grad():
            # Small batch
            small_processed = csi_processor.process_csi_batch(small_batch)
            small_sanitized = phase_sanitizer.sanitize_phase_batch(small_processed)
            small_features = translation_network(small_sanitized)
            small_output = densepose_head(small_features)

            # Large batch
            large_processed = csi_processor.process_csi_batch(large_batch)
            large_sanitized = phase_sanitizer.sanitize_phase_batch(large_processed)
            large_features = translation_network(large_sanitized)
            large_output = densepose_head(large_features)

        # Assert - batch dimension tracks the input in both outputs
        assert small_output['segmentation'].shape[0] == 1
        assert large_output['segmentation'].shape[0] == 8
        assert small_output['uv_coordinates'].shape[0] == 1
        assert large_output['uv_coordinates'].shape[0] == 8

    def test_inference_pipeline_maintains_gradient_flow_during_training(self, inference_pipeline_components,
                                                                        mock_raw_csi_input, mock_ground_truth_densepose):
        """Test that inference pipeline maintains gradient flow during training.

        Runs one forward/backward step and asserts every trainable parameter
        in both networks received a non-zero gradient.
        """
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Set models to training mode
        translation_network.train()
        densepose_head.train()

        # Create optimizer over both networks' parameters
        optimizer = torch.optim.Adam(
            list(translation_network.parameters()) + list(densepose_head.parameters()),
            lr=0.001
        )

        # Act
        optimizer.zero_grad()

        # Forward pass
        processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
        sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
        visual_features = translation_network(sanitized_csi)
        densepose_output = densepose_head(visual_features)

        # Resize ground truth to match output resolution.
        # Segmentation uses nearest-neighbour so class labels stay integral.
        seg_target = torch.nn.functional.interpolate(
            mock_ground_truth_densepose['segmentation'].float().unsqueeze(1),
            size=densepose_output['segmentation'].shape[2:],
            mode='nearest'
        ).squeeze(1).long()

        # UV targets are continuous, so bilinear interpolation is fine.
        uv_target = torch.nn.functional.interpolate(
            mock_ground_truth_densepose['uv_coordinates'],
            size=densepose_output['uv_coordinates'].shape[2:],
            mode='bilinear',
            align_corners=False
        )

        # Compute loss
        loss = densepose_head.compute_total_loss(densepose_output, seg_target, uv_target)

        # Backward pass
        loss.backward()

        # Assert - Check that gradients are computed.
        # NOTE(review): requiring every gradient to be non-zero can be flaky
        # for rarely-activated parameters — revisit if this test flakes.
        for param in translation_network.parameters():
            if param.requires_grad:
                assert param.grad is not None
                assert not torch.allclose(param.grad, torch.zeros_like(param.grad))

        for param in densepose_head.parameters():
            if param.requires_grad:
                assert param.grad is not None
                assert not torch.allclose(param.grad, torch.zeros_like(param.grad))

    def test_inference_pipeline_performance_benchmarking(self, inference_pipeline_components, mock_raw_csi_input):
        """Test inference pipeline performance for real-time requirements.

        NOTE(review): wall-clock thresholds are machine-dependent; the 50 ms
        budget may need tuning on slower CI hardware.
        """
        import time

        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Set models to evaluation mode for inference
        translation_network.eval()
        densepose_head.eval()

        # Warm up (first inference is often slower due to lazy init/caching)
        with torch.no_grad():
            processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
            visual_features = translation_network(sanitized_csi)
            _ = densepose_head(visual_features)

        # Act - Measure inference time of the full pipeline
        start_time = time.time()

        with torch.no_grad():
            processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
            visual_features = translation_network(sanitized_csi)
            densepose_output = densepose_head(visual_features)

        end_time = time.time()
        inference_time = end_time - start_time

        # Assert - Should meet real-time requirements (< 50ms for batch of 4)
        assert inference_time < 0.05, f"Inference took {inference_time:.3f}s, expected < 0.05s"

    def test_inference_pipeline_handles_edge_cases(self, inference_pipeline_components):
        """Test that inference pipeline handles edge cases gracefully."""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Edge cases: all-zero signal and extremely high-amplitude noise
        zero_input = np.zeros((1, 3, 56, 100), dtype=complex)
        noisy_input = np.random.randn(1, 3, 56, 100) * 100 + 1j * np.random.randn(1, 3, 56, 100) * 100

        translation_network.eval()
        densepose_head.eval()

        # Act & Assert - neither input may produce NaNs anywhere downstream
        with torch.no_grad():
            # Zero input
            zero_processed = csi_processor.process_csi_batch(zero_input)
            zero_sanitized = phase_sanitizer.sanitize_phase_batch(zero_processed)
            zero_features = translation_network(zero_sanitized)
            zero_output = densepose_head(zero_features)

            assert not torch.isnan(zero_output['segmentation']).any()
            assert not torch.isnan(zero_output['uv_coordinates']).any()

            # Noisy input
            noisy_processed = csi_processor.process_csi_batch(noisy_input)
            noisy_sanitized = phase_sanitizer.sanitize_phase_batch(noisy_processed)
            noisy_features = translation_network(noisy_sanitized)
            noisy_output = densepose_head(noisy_features)

            assert not torch.isnan(noisy_output['segmentation']).any()
            assert not torch.isnan(noisy_output['uv_coordinates']).any()

    def test_inference_pipeline_memory_efficiency(self, inference_pipeline_components):
        """Test that inference pipeline is memory efficient."""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Large batch to test memory usage
        large_input = np.random.randn(16, 3, 56, 100) + 1j * np.random.randn(16, 3, 56, 100)

        translation_network.eval()
        densepose_head.eval()

        # Act - Process in chunks to manage memory
        chunk_size = 4
        outputs = []

        with torch.no_grad():
            for i in range(0, large_input.shape[0], chunk_size):
                chunk = large_input[i:i+chunk_size]

                processed_chunk = csi_processor.process_csi_batch(chunk)
                sanitized_chunk = phase_sanitizer.sanitize_phase_batch(processed_chunk)
                feature_chunk = translation_network(sanitized_chunk)
                output_chunk = densepose_head(feature_chunk)

                outputs.append(output_chunk)

                # Clear intermediate tensors to free memory before next chunk
                del processed_chunk, sanitized_chunk, feature_chunk

        # Assert
        assert len(outputs) == 4  # 16 samples / 4 chunk_size
        for output in outputs:
            assert output['segmentation'].shape[0] <= chunk_size

    def test_inference_pipeline_deterministic_output(self, inference_pipeline_components, mock_raw_csi_input):
        """Test that inference pipeline produces deterministic output in eval mode.

        In eval mode dropout is disabled, so two runs over identical input
        must be numerically identical (up to float tolerance).
        """
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Set models to evaluation mode
        translation_network.eval()
        densepose_head.eval()

        # Act - Run inference twice on the same input
        with torch.no_grad():
            # First run
            processed_csi_1 = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized_csi_1 = phase_sanitizer.sanitize_phase_batch(processed_csi_1)
            visual_features_1 = translation_network(sanitized_csi_1)
            output_1 = densepose_head(visual_features_1)

            # Second run
            processed_csi_2 = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized_csi_2 = phase_sanitizer.sanitize_phase_batch(processed_csi_2)
            visual_features_2 = translation_network(sanitized_csi_2)
            output_2 = densepose_head(visual_features_2)

        # Assert - Outputs should be identical in eval mode
        assert torch.allclose(output_1['segmentation'], output_2['segmentation'], atol=1e-6)
        assert torch.allclose(output_1['uv_coordinates'], output_2['uv_coordinates'], atol=1e-6)

    def test_inference_pipeline_confidence_estimation(self, inference_pipeline_components, mock_raw_csi_input):
        """Test that inference pipeline provides confidence estimates."""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        translation_network.eval()
        densepose_head.eval()

        # Act
        with torch.no_grad():
            processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
            visual_features = translation_network(sanitized_csi)
            densepose_output = densepose_head(visual_features)

            # Get confidence estimates from the head
            confidence = densepose_head.get_prediction_confidence(densepose_output)

        # Assert
        assert 'segmentation_confidence' in confidence
        assert 'uv_confidence' in confidence

        seg_conf = confidence['segmentation_confidence']
        uv_conf = confidence['uv_confidence']

        assert seg_conf.shape[0] == mock_raw_csi_input.shape[0]
        assert uv_conf.shape[0] == mock_raw_csi_input.shape[0]
        # Segmentation confidence is probability-like; UV confidence is only
        # required to be non-negative (unbounded above).
        assert torch.all(seg_conf >= 0) and torch.all(seg_conf <= 1)
        assert torch.all(uv_conf >= 0)

    def test_inference_pipeline_post_processing(self, inference_pipeline_components, mock_raw_csi_input):
        """Test that inference pipeline post-processes predictions correctly."""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        translation_network.eval()
        densepose_head.eval()

        # Act
        with torch.no_grad():
            processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
            sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
            visual_features = translation_network(sanitized_csi)
            raw_output = densepose_head(visual_features)

            # Post-process predictions (logits -> discrete labels + scores)
            processed_output = densepose_head.post_process_predictions(raw_output)

        # Assert
        assert 'body_parts' in processed_output
        assert 'uv_coordinates' in processed_output
        assert 'confidence_scores' in processed_output

        body_parts = processed_output['body_parts']
        assert body_parts.dtype == torch.long  # Class indices
        assert torch.all(body_parts >= 0) and torch.all(body_parts <= 24)  # Valid class range
||||
264
tests/unit/test_csi_extractor.py
Normal file
264
tests/unit/test_csi_extractor.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
|
||||
|
||||
|
||||
class TestCSIExtractor:
|
||||
"""Test suite for CSI Extractor following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for CSI extractor"""
|
||||
return {
|
||||
'interface': 'wlan0',
|
||||
'channel': 6,
|
||||
'bandwidth': 20,
|
||||
'sample_rate': 1000,
|
||||
'buffer_size': 1024,
|
||||
'extraction_timeout': 5.0
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_router_interface(self):
|
||||
"""Mock router interface for testing"""
|
||||
mock_router = Mock()
|
||||
mock_router.is_connected = True
|
||||
mock_router.execute_command = Mock()
|
||||
return mock_router
|
||||
|
||||
@pytest.fixture
|
||||
def csi_extractor(self, mock_config, mock_router_interface):
|
||||
"""Create CSI extractor instance for testing"""
|
||||
return CSIExtractor(mock_config, mock_router_interface)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_data(self):
|
||||
"""Generate synthetic CSI data for testing"""
|
||||
# Simulate CSI data: complex values for multiple subcarriers
|
||||
num_subcarriers = 56
|
||||
num_antennas = 3
|
||||
amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
|
||||
phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
|
||||
return amplitude * np.exp(1j * phase)
|
||||
|
||||
def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
|
||||
"""Test that CSI extractor initializes with correct configuration"""
|
||||
# Act
|
||||
extractor = CSIExtractor(mock_config, mock_router_interface)
|
||||
|
||||
# Assert
|
||||
assert extractor is not None
|
||||
assert extractor.interface == mock_config['interface']
|
||||
assert extractor.channel == mock_config['channel']
|
||||
assert extractor.bandwidth == mock_config['bandwidth']
|
||||
assert extractor.sample_rate == mock_config['sample_rate']
|
||||
assert extractor.buffer_size == mock_config['buffer_size']
|
||||
assert extractor.extraction_timeout == mock_config['extraction_timeout']
|
||||
assert extractor.router_interface == mock_router_interface
|
||||
assert not extractor.is_extracting
|
||||
|
||||
def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
|
||||
"""Test that start_extraction configures monitor mode"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
# Act
|
||||
result = csi_extractor.start_extraction()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert csi_extractor.is_extracting is True
|
||||
mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)
|
||||
|
||||
def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
|
||||
"""Test that start_extraction handles monitor mode configuration failure"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(CSIExtractionError):
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
assert csi_extractor.is_extracting is False
|
||||
|
||||
def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
|
||||
"""Test that stop_extraction disables monitor mode"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.disable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act
|
||||
result = csi_extractor.stop_extraction()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert csi_extractor.is_extracting is False
|
||||
mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)
|
||||
|
||||
def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
|
||||
"""Test that extract_csi_data returns data in valid format"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
# Mock the CSI data extraction
|
||||
with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act
|
||||
csi_data = csi_extractor.extract_csi_data()
|
||||
|
||||
# Assert
|
||||
assert csi_data is not None
|
||||
assert isinstance(csi_data, np.ndarray)
|
||||
assert csi_data.dtype == np.complex128
|
||||
assert csi_data.shape == mock_csi_data.shape
|
||||
|
||||
def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
|
||||
"""Test that extract_csi_data requires active extraction"""
|
||||
# Act & Assert
|
||||
with pytest.raises(CSIExtractionError):
|
||||
csi_extractor.extract_csi_data()
|
||||
|
||||
def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
|
||||
"""Test that extract_csi_data handles extraction timeout"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.side_effect = [
|
||||
"CSI extraction started",
|
||||
Exception("Timeout")
|
||||
]
|
||||
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(CSIExtractionError):
|
||||
csi_extractor.extract_csi_data()
|
||||
|
||||
def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
|
||||
"""Test that convert_to_tensor produces correctly formatted tensor"""
|
||||
# Act
|
||||
tensor = csi_extractor.convert_to_tensor(mock_csi_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
assert tensor.dtype == torch.float32
|
||||
assert tensor.shape[0] == mock_csi_data.shape[0] * 2 # Real and imaginary parts
|
||||
assert tensor.shape[1] == mock_csi_data.shape[1]
|
||||
|
||||
def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
|
||||
"""Test that convert_to_tensor handles invalid input"""
|
||||
# Arrange
|
||||
invalid_data = "not an array"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
csi_extractor.convert_to_tensor(invalid_data)
|
||||
|
||||
def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
|
||||
"""Test that get_extraction_stats returns valid statistics"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act
|
||||
stats = csi_extractor.get_extraction_stats()
|
||||
|
||||
# Assert
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
assert 'samples_extracted' in stats
|
||||
assert 'extraction_rate' in stats
|
||||
assert 'buffer_utilization' in stats
|
||||
assert 'last_extraction_time' in stats
|
||||
|
||||
def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
|
||||
"""Test that set_channel configures WiFi channel"""
|
||||
# Arrange
|
||||
new_channel = 11
|
||||
mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"
|
||||
|
||||
# Act
|
||||
result = csi_extractor.set_channel(new_channel)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert csi_extractor.channel == new_channel
|
||||
mock_router_interface.execute_command.assert_called()
|
||||
|
||||
def test_set_channel_validates_channel_range(self, csi_extractor):
|
||||
"""Test that set_channel validates channel range"""
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
csi_extractor.set_channel(0) # Invalid channel
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
csi_extractor.set_channel(15) # Invalid channel
|
||||
|
||||
def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
|
||||
"""Test that CSI extractor supports context manager protocol"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.disable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
# Act
|
||||
with csi_extractor as extractor:
|
||||
# Assert
|
||||
assert extractor.is_extracting is True
|
||||
|
||||
# Assert - extraction should be stopped after context
|
||||
assert csi_extractor.is_extracting is False
|
||||
|
||||
def test_extractor_validates_configuration(self, mock_router_interface):
|
||||
"""Test that CSI extractor validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'interface': '', # Invalid interface
|
||||
'channel': 6,
|
||||
'bandwidth': 20
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
CSIExtractor(invalid_config, mock_router_interface)
|
||||
|
||||
def test_parse_csi_output_processes_raw_data(self, csi_extractor):
|
||||
"""Test that _parse_csi_output processes raw CSI data correctly"""
|
||||
# Arrange
|
||||
raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"
|
||||
|
||||
# Act
|
||||
parsed_data = csi_extractor._parse_csi_output(raw_output)
|
||||
|
||||
# Assert
|
||||
assert parsed_data is not None
|
||||
assert isinstance(parsed_data, np.ndarray)
|
||||
assert parsed_data.dtype == np.complex128
|
||||
|
||||
def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
|
||||
"""Test that buffer management handles overflow correctly"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Fill buffer beyond capacity
|
||||
for _ in range(csi_extractor.buffer_size + 10):
|
||||
csi_extractor._add_to_buffer(mock_csi_data)
|
||||
|
||||
# Act
|
||||
stats = csi_extractor.get_extraction_stats()
|
||||
|
||||
# Assert
|
||||
assert stats['buffer_utilization'] <= 1.0 # Should not exceed 100%
|
||||
@@ -3,27 +3,27 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
from src.models.densepose_head import DensePoseHead, DensePoseError
|
||||
|
||||
|
||||
class TestDensePoseHead:
|
||||
"""Test suite for DensePose Head following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_input(self):
|
||||
"""Generate synthetic feature input tensor for testing"""
|
||||
# Batch size 2, 512 channels, 56 height, 100 width (from modality translation)
|
||||
return torch.randn(2, 512, 56, 100)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for DensePose head"""
|
||||
return {
|
||||
'input_channels': 512,
|
||||
'num_body_parts': 24, # Standard DensePose body parts
|
||||
'num_uv_coordinates': 2, # U and V coordinates
|
||||
'hidden_dim': 256,
|
||||
'dropout_rate': 0.1
|
||||
'input_channels': 256,
|
||||
'num_body_parts': 24,
|
||||
'num_uv_coordinates': 2,
|
||||
'hidden_channels': [128, 64],
|
||||
'kernel_size': 3,
|
||||
'padding': 1,
|
||||
'dropout_rate': 0.1,
|
||||
'use_deformable_conv': False,
|
||||
'use_fpn': True,
|
||||
'fpn_levels': [2, 3, 4, 5],
|
||||
'output_stride': 4
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
@@ -31,6 +31,33 @@ class TestDensePoseHead:
|
||||
"""Create DensePose head instance for testing"""
|
||||
return DensePoseHead(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_input(self):
|
||||
"""Generate mock feature input tensor"""
|
||||
batch_size = 2
|
||||
channels = 256
|
||||
height = 56
|
||||
width = 56
|
||||
return torch.randn(batch_size, channels, height, width)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_masks(self):
|
||||
"""Generate mock target segmentation masks"""
|
||||
batch_size = 2
|
||||
num_parts = 24
|
||||
height = 224
|
||||
width = 224
|
||||
return torch.randint(0, num_parts + 1, (batch_size, height, width))
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_uv(self):
|
||||
"""Generate mock target UV coordinates"""
|
||||
batch_size = 2
|
||||
num_coords = 2
|
||||
height = 224
|
||||
width = 224
|
||||
return torch.randn(batch_size, num_coords, height, width)
|
||||
|
||||
def test_head_initialization_creates_correct_architecture(self, mock_config):
|
||||
"""Test that DensePose head initializes with correct architecture"""
|
||||
# Act
|
||||
@@ -39,135 +66,302 @@ class TestDensePoseHead:
|
||||
# Assert
|
||||
assert head is not None
|
||||
assert isinstance(head, nn.Module)
|
||||
assert hasattr(head, 'segmentation_head')
|
||||
assert hasattr(head, 'uv_regression_head')
|
||||
assert head.input_channels == mock_config['input_channels']
|
||||
assert head.num_body_parts == mock_config['num_body_parts']
|
||||
assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
|
||||
assert head.use_fpn == mock_config['use_fpn']
|
||||
assert hasattr(head, 'segmentation_head')
|
||||
assert hasattr(head, 'uv_regression_head')
|
||||
if mock_config['use_fpn']:
|
||||
assert hasattr(head, 'fpn')
|
||||
|
||||
def test_forward_pass_produces_correct_output_shapes(self, densepose_head, mock_feature_input):
|
||||
"""Test that forward pass produces correctly shaped outputs"""
|
||||
def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
|
||||
"""Test that forward pass produces correctly formatted output"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
segmentation, uv_coords = densepose_head(mock_feature_input)
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Assert
|
||||
assert segmentation is not None
|
||||
assert uv_coords is not None
|
||||
assert isinstance(segmentation, torch.Tensor)
|
||||
assert isinstance(uv_coords, torch.Tensor)
|
||||
assert output is not None
|
||||
assert isinstance(output, dict)
|
||||
assert 'segmentation' in output
|
||||
assert 'uv_coordinates' in output
|
||||
|
||||
# Check segmentation output shape
|
||||
assert segmentation.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||
assert segmentation.shape[1] == densepose_head.num_body_parts # Correct number of body parts
|
||||
assert segmentation.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
|
||||
seg_output = output['segmentation']
|
||||
uv_output = output['uv_coordinates']
|
||||
|
||||
# Check UV coordinates output shape
|
||||
assert uv_coords.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||
assert uv_coords.shape[1] == densepose_head.num_uv_coordinates # U and V coordinates
|
||||
assert uv_coords.shape[2:] == mock_feature_input.shape[2:] # Spatial dimensions preserved
|
||||
assert isinstance(seg_output, torch.Tensor)
|
||||
assert isinstance(uv_output, torch.Tensor)
|
||||
assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||
assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||
|
||||
def test_segmentation_output_has_valid_probabilities(self, densepose_head, mock_feature_input):
|
||||
"""Test that segmentation output has valid probability distributions"""
|
||||
def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
|
||||
"""Test that segmentation head produces correct output shape"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
segmentation, _ = densepose_head(mock_feature_input)
|
||||
output = densepose_head(mock_feature_input)
|
||||
seg_output = output['segmentation']
|
||||
|
||||
# Assert
|
||||
# After softmax, values should be between 0 and 1
|
||||
assert torch.all(segmentation >= 0.0)
|
||||
assert torch.all(segmentation <= 1.0)
|
||||
|
||||
# Sum across body parts dimension should be approximately 1
|
||||
part_sums = torch.sum(segmentation, dim=1)
|
||||
assert torch.allclose(part_sums, torch.ones_like(part_sums), atol=1e-5)
|
||||
expected_channels = densepose_head.num_body_parts + 1 # +1 for background
|
||||
assert seg_output.shape[1] == expected_channels
|
||||
assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
|
||||
assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
|
||||
|
||||
def test_uv_coordinates_output_in_valid_range(self, densepose_head, mock_feature_input):
|
||||
"""Test that UV coordinates are in valid range [0, 1]"""
|
||||
def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
|
||||
"""Test that UV regression head produces correct output shape"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
_, uv_coords = densepose_head(mock_feature_input)
|
||||
output = densepose_head(mock_feature_input)
|
||||
uv_output = output['uv_coordinates']
|
||||
|
||||
# Assert
|
||||
# UV coordinates should be in range [0, 1] after sigmoid
|
||||
assert torch.all(uv_coords >= 0.0)
|
||||
assert torch.all(uv_coords <= 1.0)
|
||||
assert uv_output.shape[1] == densepose_head.num_uv_coordinates
|
||||
assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
|
||||
assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
|
||||
|
||||
def test_head_handles_different_batch_sizes(self, densepose_head):
|
||||
"""Test that head handles different batch sizes correctly"""
|
||||
def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
|
||||
"""Test that compute_segmentation_loss measures pixel classification accuracy"""
|
||||
# Arrange
|
||||
batch_sizes = [1, 4, 8]
|
||||
output = densepose_head(mock_feature_input)
|
||||
seg_logits = output['segmentation']
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
input_tensor = torch.randn(batch_size, 512, 56, 100)
|
||||
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
segmentation, uv_coords = densepose_head(input_tensor)
|
||||
|
||||
# Assert
|
||||
assert segmentation.shape[0] == batch_size
|
||||
assert uv_coords.shape[0] == batch_size
|
||||
|
||||
def test_head_is_trainable(self, densepose_head, mock_feature_input):
|
||||
"""Test that head parameters are trainable"""
|
||||
# Arrange
|
||||
seg_criterion = nn.CrossEntropyLoss()
|
||||
uv_criterion = nn.MSELoss()
|
||||
|
||||
# Create targets with correct shapes
|
||||
seg_target = torch.randint(0, 24, (2, 56, 100)) # Class indices for segmentation
|
||||
uv_target = torch.rand(2, 2, 56, 100) # UV coordinates target
|
||||
# Resize target to match output
|
||||
target_resized = torch.nn.functional.interpolate(
|
||||
mock_target_masks.float().unsqueeze(1),
|
||||
size=seg_logits.shape[2:],
|
||||
mode='nearest'
|
||||
).squeeze(1).long()
|
||||
|
||||
# Act
|
||||
segmentation, uv_coords = densepose_head(mock_feature_input)
|
||||
seg_loss = seg_criterion(segmentation, seg_target)
|
||||
uv_loss = uv_criterion(uv_coords, uv_target)
|
||||
total_loss = seg_loss + uv_loss
|
||||
total_loss.backward()
|
||||
loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
|
||||
"""Test that compute_uv_loss measures UV coordinate regression accuracy"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
uv_pred = output['uv_coordinates']
|
||||
|
||||
# Resize target to match output
|
||||
target_resized = torch.nn.functional.interpolate(
|
||||
mock_target_uv,
|
||||
size=uv_pred.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Act
|
||||
loss = densepose_head.compute_uv_loss(uv_pred, target_resized)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
|
||||
"""Test that compute_total_loss combines segmentation and UV losses"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Resize targets to match outputs
|
||||
seg_target = torch.nn.functional.interpolate(
|
||||
mock_target_masks.float().unsqueeze(1),
|
||||
size=output['segmentation'].shape[2:],
|
||||
mode='nearest'
|
||||
).squeeze(1).long()
|
||||
|
||||
uv_target = torch.nn.functional.interpolate(
|
||||
mock_target_uv,
|
||||
size=output['uv_coordinates'].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Act
|
||||
total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
|
||||
seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target)
|
||||
uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target)
|
||||
|
||||
# Assert
|
||||
assert total_loss is not None
|
||||
assert isinstance(total_loss, torch.Tensor)
|
||||
assert total_loss.item() > 0
|
||||
# Check that gradients are computed
|
||||
# Total loss should be combination of individual losses
|
||||
expected_total = seg_loss + uv_loss
|
||||
assert torch.allclose(total_loss, expected_total, atol=1e-6)
|
||||
|
||||
def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
|
||||
"""Test that FPN integration enhances multi-scale feature processing"""
|
||||
# Arrange
|
||||
config_with_fpn = mock_config.copy()
|
||||
config_with_fpn['use_fpn'] = True
|
||||
|
||||
config_without_fpn = mock_config.copy()
|
||||
config_without_fpn['use_fpn'] = False
|
||||
|
||||
head_with_fpn = DensePoseHead(config_with_fpn)
|
||||
head_without_fpn = DensePoseHead(config_without_fpn)
|
||||
|
||||
# Act
|
||||
output_with_fpn = head_with_fpn(mock_feature_input)
|
||||
output_without_fpn = head_without_fpn(mock_feature_input)
|
||||
|
||||
# Assert
|
||||
assert output_with_fpn['segmentation'].shape == output_without_fpn['segmentation'].shape
|
||||
assert output_with_fpn['uv_coordinates'].shape == output_without_fpn['uv_coordinates'].shape
|
||||
# Outputs should be different due to FPN
|
||||
assert not torch.allclose(output_with_fpn['segmentation'], output_without_fpn['segmentation'], atol=1e-6)
|
||||
|
||||
def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
|
||||
"""Test that get_prediction_confidence provides uncertainty estimates"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Act
|
||||
confidence = densepose_head.get_prediction_confidence(output)
|
||||
|
||||
# Assert
|
||||
assert confidence is not None
|
||||
assert isinstance(confidence, dict)
|
||||
assert 'segmentation_confidence' in confidence
|
||||
assert 'uv_confidence' in confidence
|
||||
|
||||
seg_conf = confidence['segmentation_confidence']
|
||||
uv_conf = confidence['uv_confidence']
|
||||
|
||||
assert isinstance(seg_conf, torch.Tensor)
|
||||
assert isinstance(uv_conf, torch.Tensor)
|
||||
assert seg_conf.shape[0] == mock_feature_input.shape[0]
|
||||
assert uv_conf.shape[0] == mock_feature_input.shape[0]
|
||||
|
||||
def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
|
||||
"""Test that post_process_predictions formats output correctly"""
|
||||
# Arrange
|
||||
raw_output = densepose_head(mock_feature_input)
|
||||
|
||||
# Act
|
||||
processed = densepose_head.post_process_predictions(raw_output)
|
||||
|
||||
# Assert
|
||||
assert processed is not None
|
||||
assert isinstance(processed, dict)
|
||||
assert 'body_parts' in processed
|
||||
assert 'uv_coordinates' in processed
|
||||
assert 'confidence_scores' in processed
|
||||
|
||||
def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
|
||||
"""Test that training mode enables dropout for regularization"""
|
||||
# Arrange
|
||||
densepose_head.train()
|
||||
|
||||
# Act
|
||||
output1 = densepose_head(mock_feature_input)
|
||||
output2 = densepose_head(mock_feature_input)
|
||||
|
||||
# Assert - outputs should be different due to dropout
|
||||
assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
|
||||
assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
|
||||
|
||||
def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
|
||||
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||
# Arrange
|
||||
densepose_head.eval()
|
||||
|
||||
# Act
|
||||
output1 = densepose_head(mock_feature_input)
|
||||
output2 = densepose_head(mock_feature_input)
|
||||
|
||||
# Assert - outputs should be identical in eval mode
|
||||
assert torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
|
||||
assert torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
|
||||
|
||||
def test_head_validates_input_dimensions(self, densepose_head):
|
||||
"""Test that head validates input dimensions"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(2, 128, 56, 56) # Wrong number of channels
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DensePoseError):
|
||||
densepose_head(invalid_input)
|
||||
|
||||
def test_head_handles_different_input_sizes(self, densepose_head):
|
||||
"""Test that head handles different input sizes"""
|
||||
# Arrange
|
||||
small_input = torch.randn(1, 256, 28, 28)
|
||||
large_input = torch.randn(1, 256, 112, 112)
|
||||
|
||||
# Act
|
||||
small_output = densepose_head(small_input)
|
||||
large_output = densepose_head(large_input)
|
||||
|
||||
# Assert
|
||||
assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:]
|
||||
assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:]
|
||||
|
||||
def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
|
||||
"""Test that head supports gradient computation for training"""
|
||||
# Arrange
|
||||
densepose_head.train()
|
||||
optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)
|
||||
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Resize targets
|
||||
seg_target = torch.nn.functional.interpolate(
|
||||
mock_target_masks.float().unsqueeze(1),
|
||||
size=output['segmentation'].shape[2:],
|
||||
mode='nearest'
|
||||
).squeeze(1).long()
|
||||
|
||||
uv_target = torch.nn.functional.interpolate(
|
||||
mock_target_uv,
|
||||
size=output['uv_coordinates'].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Act
|
||||
loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Assert
|
||||
for param in densepose_head.parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||
|
||||
def test_head_handles_invalid_input_shape(self, densepose_head):
|
||||
"""Test that head handles invalid input shapes gracefully"""
|
||||
def test_head_configuration_validation(self):
|
||||
"""Test that head validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(2, 256, 56, 100) # Wrong number of channels
|
||||
invalid_config = {
|
||||
'input_channels': 0, # Invalid
|
||||
'num_body_parts': -1, # Invalid
|
||||
'num_uv_coordinates': 2
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError):
|
||||
densepose_head(invalid_input)
|
||||
with pytest.raises(ValueError):
|
||||
DensePoseHead(invalid_config)
|
||||
|
||||
def test_head_supports_evaluation_mode(self, densepose_head, mock_feature_input):
|
||||
"""Test that head supports evaluation mode"""
|
||||
# Act
|
||||
densepose_head.eval()
|
||||
def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
|
||||
"""Test that model state can be saved and loaded"""
|
||||
# Arrange
|
||||
original_output = densepose_head(mock_feature_input)
|
||||
|
||||
with torch.no_grad():
|
||||
seg1, uv1 = densepose_head(mock_feature_input)
|
||||
seg2, uv2 = densepose_head(mock_feature_input)
|
||||
# Act - Save state
|
||||
state_dict = densepose_head.state_dict()
|
||||
|
||||
# Assert - In eval mode with same input, outputs should be identical
|
||||
assert torch.allclose(seg1, seg2, atol=1e-6)
|
||||
assert torch.allclose(uv1, uv2, atol=1e-6)
|
||||
|
||||
def test_head_output_quality(self, densepose_head, mock_feature_input):
|
||||
"""Test that head produces meaningful outputs"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
segmentation, uv_coords = densepose_head(mock_feature_input)
|
||||
# Create new head and load state
|
||||
new_head = DensePoseHead(densepose_head.config)
|
||||
new_head.load_state_dict(state_dict)
|
||||
new_output = new_head(mock_feature_input)
|
||||
|
||||
# Assert
|
||||
# Outputs should not contain NaN or Inf values
|
||||
assert not torch.isnan(segmentation).any()
|
||||
assert not torch.isinf(segmentation).any()
|
||||
assert not torch.isnan(uv_coords).any()
|
||||
assert not torch.isinf(uv_coords).any()
|
||||
|
||||
# Outputs should have reasonable variance (not all zeros or ones)
|
||||
assert segmentation.std() > 0.01
|
||||
assert uv_coords.std() > 0.01
|
||||
assert torch.allclose(original_output['segmentation'], new_output['segmentation'], atol=1e-6)
|
||||
assert torch.allclose(original_output['uv_coordinates'], new_output['uv_coordinates'], atol=1e-6)
|
||||
@@ -3,27 +3,27 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
|
||||
|
||||
|
||||
class TestModalityTranslationNetwork:
|
||||
"""Test suite for Modality Translation Network following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_input(self):
|
||||
"""Generate synthetic CSI input tensor for testing"""
|
||||
# Batch size 2, 3 antennas, 56 subcarriers, 100 temporal samples
|
||||
return torch.randn(2, 3, 56, 100)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for modality translation network"""
|
||||
return {
|
||||
'input_channels': 3,
|
||||
'hidden_dim': 256,
|
||||
'output_dim': 512,
|
||||
'num_layers': 3,
|
||||
'dropout_rate': 0.1
|
||||
'input_channels': 6, # Real and imaginary parts for 3 antennas
|
||||
'hidden_channels': [64, 128, 256],
|
||||
'output_channels': 256,
|
||||
'kernel_size': 3,
|
||||
'stride': 1,
|
||||
'padding': 1,
|
||||
'dropout_rate': 0.1,
|
||||
'activation': 'relu',
|
||||
'normalization': 'batch',
|
||||
'use_attention': True,
|
||||
'attention_heads': 8
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
@@ -31,98 +31,263 @@ class TestModalityTranslationNetwork:
|
||||
"""Create modality translation network instance for testing"""
|
||||
return ModalityTranslationNetwork(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_input(self):
|
||||
"""Generate mock CSI input tensor"""
|
||||
batch_size = 4
|
||||
channels = 6 # Real and imaginary parts for 3 antennas
|
||||
height = 56 # Number of subcarriers
|
||||
width = 100 # Time samples
|
||||
return torch.randn(batch_size, channels, height, width)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_features(self):
|
||||
"""Generate mock target feature tensor for training"""
|
||||
batch_size = 4
|
||||
feature_dim = 256
|
||||
spatial_height = 56
|
||||
spatial_width = 100
|
||||
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
|
||||
|
||||
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
||||
"""Test that network initializes with correct architecture"""
|
||||
"""Test that modality translation network initializes with correct architecture"""
|
||||
# Act
|
||||
network = ModalityTranslationNetwork(mock_config)
|
||||
|
||||
# Assert
|
||||
assert network is not None
|
||||
assert isinstance(network, nn.Module)
|
||||
assert network.input_channels == mock_config['input_channels']
|
||||
assert network.output_channels == mock_config['output_channels']
|
||||
assert network.use_attention == mock_config['use_attention']
|
||||
assert hasattr(network, 'encoder')
|
||||
assert hasattr(network, 'decoder')
|
||||
assert network.input_channels == mock_config['input_channels']
|
||||
assert network.hidden_dim == mock_config['hidden_dim']
|
||||
assert network.output_dim == mock_config['output_dim']
|
||||
if mock_config['use_attention']:
|
||||
assert hasattr(network, 'attention')
|
||||
|
||||
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
||||
"""Test that forward pass produces correctly shaped output"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
output = translation_network(mock_csi_input)
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert output is not None
|
||||
assert isinstance(output, torch.Tensor)
|
||||
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
||||
assert output.shape[1] == translation_network.output_dim # Correct output dimension
|
||||
assert len(output.shape) == 4 # Should maintain spatial dimensions
|
||||
assert output.shape[1] == translation_network.output_channels # Correct output channels
|
||||
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
|
||||
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
|
||||
|
||||
def test_forward_pass_handles_different_batch_sizes(self, translation_network):
|
||||
"""Test that network handles different batch sizes correctly"""
|
||||
def test_forward_pass_handles_different_input_sizes(self, translation_network):
|
||||
"""Test that forward pass handles different input sizes"""
|
||||
# Arrange
|
||||
batch_sizes = [1, 4, 8]
|
||||
small_input = torch.randn(2, 6, 28, 50)
|
||||
large_input = torch.randn(8, 6, 112, 200)
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
input_tensor = torch.randn(batch_size, 3, 56, 100)
|
||||
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
output = translation_network(input_tensor)
|
||||
|
||||
# Assert
|
||||
assert output.shape[0] == batch_size
|
||||
assert output.shape[1] == translation_network.output_dim
|
||||
# Act
|
||||
small_output = translation_network(small_input)
|
||||
large_output = translation_network(large_input)
|
||||
|
||||
# Assert
|
||||
assert small_output.shape == (2, 256, 28, 50)
|
||||
assert large_output.shape == (8, 256, 112, 200)
|
||||
|
||||
def test_network_is_trainable(self, translation_network, mock_csi_input):
|
||||
"""Test that network parameters are trainable"""
|
||||
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
|
||||
"""Test that encoder extracts hierarchical features"""
|
||||
# Act
|
||||
features = translation_network.encode(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
assert len(features) == len(translation_network.encoder)
|
||||
|
||||
# Check feature map sizes decrease with depth
|
||||
for i in range(1, len(features)):
|
||||
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
|
||||
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
|
||||
|
||||
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
|
||||
"""Test that decoder reconstructs target feature representation"""
|
||||
# Arrange
|
||||
criterion = nn.MSELoss()
|
||||
encoded_features = translation_network.encode(mock_csi_input)
|
||||
|
||||
# Act
|
||||
decoded_output = translation_network.decode(encoded_features)
|
||||
|
||||
# Assert
|
||||
assert decoded_output is not None
|
||||
assert isinstance(decoded_output, torch.Tensor)
|
||||
assert decoded_output.shape[1] == translation_network.output_channels
|
||||
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
|
||||
|
||||
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
|
||||
"""Test that attention mechanism enhances feature representation"""
|
||||
# Arrange
|
||||
config_with_attention = mock_config.copy()
|
||||
config_with_attention['use_attention'] = True
|
||||
|
||||
config_without_attention = mock_config.copy()
|
||||
config_without_attention['use_attention'] = False
|
||||
|
||||
network_with_attention = ModalityTranslationNetwork(config_with_attention)
|
||||
network_without_attention = ModalityTranslationNetwork(config_without_attention)
|
||||
|
||||
# Act
|
||||
output_with_attention = network_with_attention(mock_csi_input)
|
||||
output_without_attention = network_without_attention(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert output_with_attention.shape == output_without_attention.shape
|
||||
# Outputs should be different due to attention mechanism
|
||||
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
|
||||
|
||||
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
|
||||
"""Test that training mode enables dropout for regularization"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
|
||||
# Act
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - outputs should be different due to dropout
|
||||
assert not torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
|
||||
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||
# Arrange
|
||||
translation_network.eval()
|
||||
|
||||
# Act
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - outputs should be identical in eval mode
|
||||
assert torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that compute_translation_loss measures feature alignment"""
|
||||
# Arrange
|
||||
predicted_features = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that compute_translation_loss handles different loss types"""
|
||||
# Arrange
|
||||
predicted_features = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
|
||||
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
|
||||
|
||||
# Assert
|
||||
assert mse_loss is not None
|
||||
assert l1_loss is not None
|
||||
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
|
||||
|
||||
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
|
||||
"""Test that get_feature_statistics provides feature analysis"""
|
||||
# Arrange
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
stats = translation_network.get_feature_statistics(output)
|
||||
|
||||
# Assert
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
assert 'mean' in stats
|
||||
assert 'std' in stats
|
||||
assert 'min' in stats
|
||||
assert 'max' in stats
|
||||
assert 'sparsity' in stats
|
||||
|
||||
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that network supports gradient computation for training"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
|
||||
|
||||
# Act
|
||||
output = translation_network(mock_csi_input)
|
||||
# Create target with same shape as output
|
||||
target = torch.randn_like(output)
|
||||
loss = criterion(output, target)
|
||||
loss = translation_network.compute_translation_loss(output, mock_target_features)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Assert
|
||||
assert loss.item() > 0
|
||||
# Check that gradients are computed
|
||||
for param in translation_network.parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||
|
||||
def test_network_handles_invalid_input_shape(self, translation_network):
|
||||
"""Test that network handles invalid input shapes gracefully"""
|
||||
def test_network_validates_input_dimensions(self, translation_network):
|
||||
"""Test that network validates input dimensions"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(2, 5, 56, 100) # Wrong number of channels
|
||||
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(ModalityTranslationError):
|
||||
translation_network(invalid_input)
|
||||
|
||||
def test_network_supports_evaluation_mode(self, translation_network, mock_csi_input):
|
||||
"""Test that network supports evaluation mode"""
|
||||
# Act
|
||||
translation_network.eval()
|
||||
def test_network_handles_batch_size_one(self, translation_network):
|
||||
"""Test that network handles single sample inference"""
|
||||
# Arrange
|
||||
single_input = torch.randn(1, 6, 56, 100)
|
||||
|
||||
with torch.no_grad():
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - In eval mode with same input, outputs should be identical
|
||||
assert torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_network_feature_extraction_quality(self, translation_network, mock_csi_input):
|
||||
"""Test that network extracts meaningful features"""
|
||||
# Act
|
||||
with torch.no_grad():
|
||||
output = translation_network(mock_csi_input)
|
||||
output = translation_network(single_input)
|
||||
|
||||
# Assert
|
||||
# Features should have reasonable statistics
|
||||
assert not torch.isnan(output).any()
|
||||
assert not torch.isinf(output).any()
|
||||
assert output.std() > 0.01 # Features should have some variance
|
||||
assert output.std() < 10.0 # But not be too extreme
|
||||
assert output.shape == (1, 256, 56, 100)
|
||||
|
||||
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
|
||||
"""Test that model state can be saved and loaded"""
|
||||
# Arrange
|
||||
original_output = translation_network(mock_csi_input)
|
||||
|
||||
# Act - Save state
|
||||
state_dict = translation_network.state_dict()
|
||||
|
||||
# Create new network and load state
|
||||
new_network = ModalityTranslationNetwork(translation_network.config)
|
||||
new_network.load_state_dict(state_dict)
|
||||
new_output = new_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert torch.allclose(original_output, new_output, atol=1e-6)
|
||||
|
||||
def test_network_configuration_validation(self):
|
||||
"""Test that network validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'input_channels': 0, # Invalid
|
||||
'hidden_channels': [], # Invalid
|
||||
'output_channels': 256
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
ModalityTranslationNetwork(invalid_config)
|
||||
|
||||
def test_feature_visualization_support(self, translation_network, mock_csi_input):
|
||||
"""Test that network supports feature visualization"""
|
||||
# Act
|
||||
features = translation_network.get_intermediate_features(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, dict)
|
||||
assert 'encoder_features' in features
|
||||
assert 'decoder_features' in features
|
||||
if translation_network.use_attention:
|
||||
assert 'attention_weights' in features
|
||||
244
tests/unit/test_router_interface.py
Normal file
244
tests/unit/test_router_interface.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.hardware.router_interface import RouterInterface, RouterConnectionError
|
||||
|
||||
|
||||
class TestRouterInterface:
|
||||
"""Test suite for Router Interface following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for router interface"""
|
||||
return {
|
||||
'router_ip': '192.168.1.1',
|
||||
'username': 'admin',
|
||||
'password': 'password',
|
||||
'ssh_port': 22,
|
||||
'timeout': 30,
|
||||
'max_retries': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_interface(self, mock_config):
|
||||
"""Create router interface instance for testing"""
|
||||
return RouterInterface(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssh_client(self):
|
||||
"""Mock SSH client for testing"""
|
||||
mock_client = Mock()
|
||||
mock_client.connect = Mock()
|
||||
mock_client.exec_command = Mock()
|
||||
mock_client.close = Mock()
|
||||
return mock_client
|
||||
|
||||
def test_interface_initialization_creates_correct_configuration(self, mock_config):
|
||||
"""Test that router interface initializes with correct configuration"""
|
||||
# Act
|
||||
interface = RouterInterface(mock_config)
|
||||
|
||||
# Assert
|
||||
assert interface is not None
|
||||
assert interface.router_ip == mock_config['router_ip']
|
||||
assert interface.username == mock_config['username']
|
||||
assert interface.password == mock_config['password']
|
||||
assert interface.ssh_port == mock_config['ssh_port']
|
||||
assert interface.timeout == mock_config['timeout']
|
||||
assert interface.max_retries == mock_config['max_retries']
|
||||
assert not interface.is_connected
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that connect method establishes SSH connection"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
|
||||
# Act
|
||||
result = router_interface.connect()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert router_interface.is_connected is True
|
||||
mock_ssh_client.set_missing_host_key_policy.assert_called_once()
|
||||
mock_ssh_client.connect.assert_called_once_with(
|
||||
hostname=router_interface.router_ip,
|
||||
port=router_interface.ssh_port,
|
||||
username=router_interface.username,
|
||||
password=router_interface.password,
|
||||
timeout=router_interface.timeout
|
||||
)
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that connect method handles connection failures gracefully"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_ssh_client.connect.side_effect = Exception("Connection failed")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.connect()
|
||||
|
||||
assert router_interface.is_connected is False
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that disconnect method closes SSH connection"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
router_interface.disconnect()
|
||||
|
||||
# Assert
|
||||
assert router_interface.is_connected is False
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that execute_command runs SSH commands correctly"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"command output"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.execute_command("test command")
|
||||
|
||||
# Assert
|
||||
assert result == "command output"
|
||||
mock_ssh_client.exec_command.assert_called_with("test command")
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that execute_command handles command errors"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b""
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b"command error"
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.execute_command("failing command")
|
||||
|
||||
def test_execute_command_requires_connection(self, router_interface):
|
||||
"""Test that execute_command requires active connection"""
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.execute_command("test command")
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that get_router_info retrieves router system information"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Router Model: AC1900\nFirmware: 1.2.3"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
info = router_interface.get_router_info()
|
||||
|
||||
# Assert
|
||||
assert info is not None
|
||||
assert isinstance(info, dict)
|
||||
assert 'model' in info
|
||||
assert 'firmware' in info
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that enable_monitor_mode configures WiFi monitoring"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Monitor mode enabled"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.enable_monitor_mode("wlan0")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_ssh_client.exec_command.assert_called()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that disable_monitor_mode disables WiFi monitoring"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Monitor mode disabled"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.disable_monitor_mode("wlan0")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_ssh_client.exec_command.assert_called()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that router interface supports context manager protocol"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
|
||||
# Act
|
||||
with router_interface as interface:
|
||||
# Assert
|
||||
assert interface.is_connected is True
|
||||
|
||||
# Assert - connection should be closed after context
|
||||
assert router_interface.is_connected is False
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
def test_interface_validates_configuration(self):
|
||||
"""Test that router interface validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'router_ip': '', # Invalid IP
|
||||
'username': 'admin',
|
||||
'password': 'password'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
RouterInterface(invalid_config)
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that interface implements retry logic for failed operations"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None] # Fail once, then succeed
|
||||
|
||||
# Act
|
||||
result = router_interface.connect()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert mock_ssh_client.connect.call_count == 2 # Should retry once
|
||||
Reference in New Issue
Block a user