feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
264
v1/tests/unit/test_csi_extractor.py
Normal file
264
v1/tests/unit/test_csi_extractor.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
|
||||
|
||||
|
||||
class TestCSIExtractor:
|
||||
"""Test suite for CSI Extractor following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for CSI extractor"""
|
||||
return {
|
||||
'interface': 'wlan0',
|
||||
'channel': 6,
|
||||
'bandwidth': 20,
|
||||
'sample_rate': 1000,
|
||||
'buffer_size': 1024,
|
||||
'extraction_timeout': 5.0
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_router_interface(self):
|
||||
"""Mock router interface for testing"""
|
||||
mock_router = Mock()
|
||||
mock_router.is_connected = True
|
||||
mock_router.execute_command = Mock()
|
||||
return mock_router
|
||||
|
||||
@pytest.fixture
|
||||
def csi_extractor(self, mock_config, mock_router_interface):
|
||||
"""Create CSI extractor instance for testing"""
|
||||
return CSIExtractor(mock_config, mock_router_interface)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_data(self):
|
||||
"""Generate synthetic CSI data for testing"""
|
||||
# Simulate CSI data: complex values for multiple subcarriers
|
||||
num_subcarriers = 56
|
||||
num_antennas = 3
|
||||
amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
|
||||
phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
|
||||
return amplitude * np.exp(1j * phase)
|
||||
|
||||
def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
|
||||
"""Test that CSI extractor initializes with correct configuration"""
|
||||
# Act
|
||||
extractor = CSIExtractor(mock_config, mock_router_interface)
|
||||
|
||||
# Assert
|
||||
assert extractor is not None
|
||||
assert extractor.interface == mock_config['interface']
|
||||
assert extractor.channel == mock_config['channel']
|
||||
assert extractor.bandwidth == mock_config['bandwidth']
|
||||
assert extractor.sample_rate == mock_config['sample_rate']
|
||||
assert extractor.buffer_size == mock_config['buffer_size']
|
||||
assert extractor.extraction_timeout == mock_config['extraction_timeout']
|
||||
assert extractor.router_interface == mock_router_interface
|
||||
assert not extractor.is_extracting
|
||||
|
||||
def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
|
||||
"""Test that start_extraction configures monitor mode"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
# Act
|
||||
result = csi_extractor.start_extraction()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert csi_extractor.is_extracting is True
|
||||
mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)
|
||||
|
||||
def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
|
||||
"""Test that start_extraction handles monitor mode configuration failure"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(CSIExtractionError):
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
assert csi_extractor.is_extracting is False
|
||||
|
||||
def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
|
||||
"""Test that stop_extraction disables monitor mode"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.disable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act
|
||||
result = csi_extractor.stop_extraction()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert csi_extractor.is_extracting is False
|
||||
mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)
|
||||
|
||||
def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
|
||||
"""Test that extract_csi_data returns data in valid format"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
# Mock the CSI data extraction
|
||||
with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act
|
||||
csi_data = csi_extractor.extract_csi_data()
|
||||
|
||||
# Assert
|
||||
assert csi_data is not None
|
||||
assert isinstance(csi_data, np.ndarray)
|
||||
assert csi_data.dtype == np.complex128
|
||||
assert csi_data.shape == mock_csi_data.shape
|
||||
|
||||
def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
|
||||
"""Test that extract_csi_data requires active extraction"""
|
||||
# Act & Assert
|
||||
with pytest.raises(CSIExtractionError):
|
||||
csi_extractor.extract_csi_data()
|
||||
|
||||
def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
|
||||
"""Test that extract_csi_data handles extraction timeout"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.side_effect = [
|
||||
"CSI extraction started",
|
||||
Exception("Timeout")
|
||||
]
|
||||
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(CSIExtractionError):
|
||||
csi_extractor.extract_csi_data()
|
||||
|
||||
def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
|
||||
"""Test that convert_to_tensor produces correctly formatted tensor"""
|
||||
# Act
|
||||
tensor = csi_extractor.convert_to_tensor(mock_csi_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
assert tensor.dtype == torch.float32
|
||||
assert tensor.shape[0] == mock_csi_data.shape[0] * 2 # Real and imaginary parts
|
||||
assert tensor.shape[1] == mock_csi_data.shape[1]
|
||||
|
||||
def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
|
||||
"""Test that convert_to_tensor handles invalid input"""
|
||||
# Arrange
|
||||
invalid_data = "not an array"
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
csi_extractor.convert_to_tensor(invalid_data)
|
||||
|
||||
def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
|
||||
"""Test that get_extraction_stats returns valid statistics"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Act
|
||||
stats = csi_extractor.get_extraction_stats()
|
||||
|
||||
# Assert
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
assert 'samples_extracted' in stats
|
||||
assert 'extraction_rate' in stats
|
||||
assert 'buffer_utilization' in stats
|
||||
assert 'last_extraction_time' in stats
|
||||
|
||||
def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
|
||||
"""Test that set_channel configures WiFi channel"""
|
||||
# Arrange
|
||||
new_channel = 11
|
||||
mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"
|
||||
|
||||
# Act
|
||||
result = csi_extractor.set_channel(new_channel)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert csi_extractor.channel == new_channel
|
||||
mock_router_interface.execute_command.assert_called()
|
||||
|
||||
def test_set_channel_validates_channel_range(self, csi_extractor):
|
||||
"""Test that set_channel validates channel range"""
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
csi_extractor.set_channel(0) # Invalid channel
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
csi_extractor.set_channel(15) # Invalid channel
|
||||
|
||||
def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
|
||||
"""Test that CSI extractor supports context manager protocol"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.disable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
# Act
|
||||
with csi_extractor as extractor:
|
||||
# Assert
|
||||
assert extractor.is_extracting is True
|
||||
|
||||
# Assert - extraction should be stopped after context
|
||||
assert csi_extractor.is_extracting is False
|
||||
|
||||
def test_extractor_validates_configuration(self, mock_router_interface):
|
||||
"""Test that CSI extractor validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'interface': '', # Invalid interface
|
||||
'channel': 6,
|
||||
'bandwidth': 20
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
CSIExtractor(invalid_config, mock_router_interface)
|
||||
|
||||
def test_parse_csi_output_processes_raw_data(self, csi_extractor):
|
||||
"""Test that _parse_csi_output processes raw CSI data correctly"""
|
||||
# Arrange
|
||||
raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"
|
||||
|
||||
# Act
|
||||
parsed_data = csi_extractor._parse_csi_output(raw_output)
|
||||
|
||||
# Assert
|
||||
assert parsed_data is not None
|
||||
assert isinstance(parsed_data, np.ndarray)
|
||||
assert parsed_data.dtype == np.complex128
|
||||
|
||||
def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
|
||||
"""Test that buffer management handles overflow correctly"""
|
||||
# Arrange
|
||||
mock_router_interface.enable_monitor_mode.return_value = True
|
||||
mock_router_interface.execute_command.return_value = "CSI extraction started"
|
||||
|
||||
with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
|
||||
csi_extractor.start_extraction()
|
||||
|
||||
# Fill buffer beyond capacity
|
||||
for _ in range(csi_extractor.buffer_size + 10):
|
||||
csi_extractor._add_to_buffer(mock_csi_data)
|
||||
|
||||
# Act
|
||||
stats = csi_extractor.get_extraction_stats()
|
||||
|
||||
# Assert
|
||||
assert stats['buffer_utilization'] <= 1.0 # Should not exceed 100%
|
||||
588
v1/tests/unit/test_csi_extractor_direct.py
Normal file
588
v1/tests/unit/test_csi_extractor_direct.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""Direct tests for CSI extractor avoiding import issues."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add src to path for direct import
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
|
||||
|
||||
# Import the CSI extractor module directly
|
||||
from src.hardware.csi_extractor import (
|
||||
CSIExtractor,
|
||||
CSIParseError,
|
||||
CSIData,
|
||||
ESP32CSIParser,
|
||||
RouterCSIParser,
|
||||
CSIValidationError
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractorDirect:
|
||||
"""Test CSI extractor with direct imports."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def esp32_config(self):
|
||||
"""ESP32 configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_config(self):
|
||||
"""Router configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'router',
|
||||
'sampling_rate': 50,
|
||||
'buffer_size': 512,
|
||||
'timeout': 10.0,
|
||||
'validation_enabled': False,
|
||||
'retry_attempts': 1
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
# Initialization tests
|
||||
def test_should_initialize_with_valid_config(self, esp32_config, mock_logger):
|
||||
"""Should initialize CSI extractor with valid configuration."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
assert extractor.config == esp32_config
|
||||
assert extractor.logger == mock_logger
|
||||
assert extractor.is_connected == False
|
||||
assert extractor.hardware_type == 'esp32'
|
||||
|
||||
def test_should_create_esp32_parser(self, esp32_config, mock_logger):
|
||||
"""Should create ESP32 parser when hardware_type is esp32."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, ESP32CSIParser)
|
||||
|
||||
def test_should_create_router_parser(self, router_config, mock_logger):
|
||||
"""Should create router parser when hardware_type is router."""
|
||||
extractor = CSIExtractor(config=router_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, RouterCSIParser)
|
||||
assert extractor.hardware_type == 'router'
|
||||
|
||||
def test_should_raise_error_for_unsupported_hardware(self, mock_logger):
|
||||
"""Should raise error for unsupported hardware type."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'unsupported',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
# Configuration validation tests
|
||||
def test_config_validation_missing_fields(self, mock_logger):
|
||||
"""Should validate required configuration fields."""
|
||||
invalid_config = {'invalid': 'config'}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_negative_sampling_rate(self, mock_logger):
|
||||
"""Should validate sampling_rate is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': -1,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="sampling_rate must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_zero_buffer_size(self, mock_logger):
|
||||
"""Should validate buffer_size is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 0,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="buffer_size must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_negative_timeout(self, mock_logger):
|
||||
"""Should validate timeout is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': -1.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
# Connection tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_establish_connection_successfully(self, esp32_config, mock_logger):
|
||||
"""Should establish connection to hardware successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.return_value = True
|
||||
|
||||
result = await extractor.connect()
|
||||
|
||||
assert result == True
|
||||
assert extractor.is_connected == True
|
||||
mock_connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_connection_failure(self, esp32_config, mock_logger):
|
||||
"""Should handle connection failure gracefully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.side_effect = ConnectionError("Hardware not found")
|
||||
|
||||
result = await extractor.connect()
|
||||
|
||||
assert result == False
|
||||
assert extractor.is_connected == False
|
||||
extractor.logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_disconnect_properly(self, esp32_config, mock_logger):
|
||||
"""Should disconnect from hardware properly."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_disconnect:
|
||||
await extractor.disconnect()
|
||||
|
||||
assert extractor.is_connected == False
|
||||
mock_disconnect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should handle disconnect when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
# Should not call close when not connected
|
||||
mock_close.assert_not_called()
|
||||
assert extractor.is_connected == False
|
||||
|
||||
# Data extraction tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_extract_csi_data_successfully(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should extract CSI data successfully from hardware."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
|
||||
mock_read.return_value = b"raw_csi_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_read.assert_called_once()
|
||||
mock_parse.assert_called_once_with(b"raw_csi_data")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_extraction_failure_when_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should handle extraction failure when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with pytest.raises(CSIParseError, match="Not connected to hardware"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_retry_on_temporary_failure(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should retry extraction on temporary failure."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse') as mock_parse:
|
||||
# First two calls fail, third succeeds
|
||||
mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
|
||||
mock_parse.return_value = sample_csi_data
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
assert mock_read.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_with_validation_disabled(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should skip validation when disabled."""
|
||||
esp32_config['validation_enabled'] = False
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
|
||||
with patch.object(extractor, 'validate_csi_data') as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_max_retries_exceeded(self, esp32_config, mock_logger):
|
||||
"""Should raise error after max retries exceeded."""
|
||||
esp32_config['retry_attempts'] = 2
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
mock_read.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
assert mock_read.call_count == 2
|
||||
|
||||
# Validation tests
|
||||
def test_should_validate_csi_data_successfully(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should validate CSI data successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
result = extractor.validate_csi_data(sample_csi_data)
|
||||
|
||||
assert result == True
|
||||
|
||||
def test_validation_empty_amplitude(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for empty amplitude."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty amplitude data"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_empty_phase(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for empty phase."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.array([]),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty phase data"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_frequency(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid frequency."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=0,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid frequency"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_bandwidth(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid bandwidth."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=0,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_subcarriers(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid subcarriers."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=0,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_antennas(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid antennas."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=0,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_snr_too_low(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for SNR too low."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=-100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_snr_too_high(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for SNR too high."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
# Streaming tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_start_streaming_successfully(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should start CSI data streaming successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.return_value = sample_csi_data
|
||||
|
||||
# Start streaming with limited iterations to avoid infinite loop
|
||||
streaming_task = asyncio.create_task(extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly
|
||||
extractor.stop_streaming()
|
||||
await streaming_task
|
||||
|
||||
callback.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_stop_streaming_gracefully(self, esp32_config, mock_logger):
|
||||
"""Should stop streaming gracefully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_streaming = True
|
||||
|
||||
extractor.stop_streaming()
|
||||
|
||||
assert extractor.is_streaming == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_with_exception(self, esp32_config, mock_logger):
|
||||
"""Should handle exceptions during streaming."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.side_effect = Exception("Extraction error")
|
||||
|
||||
# Start streaming and let it handle the exception
|
||||
streaming_task = asyncio.create_task(extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly and hit the exception
|
||||
await streaming_task
|
||||
|
||||
# Should log error and stop streaming
|
||||
assert extractor.is_streaming == False
|
||||
extractor.logger.error.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestESP32CSIParserDirect:
|
||||
"""Test ESP32 CSI parser with direct imports."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create ESP32 CSI parser for testing."""
|
||||
return ESP32CSIParser()
|
||||
|
||||
@pytest.fixture
|
||||
def raw_esp32_data(self):
|
||||
"""Sample raw ESP32 CSI data."""
|
||||
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
|
||||
def test_should_parse_valid_esp32_data(self, parser, raw_esp32_data):
|
||||
"""Should parse valid ESP32 CSI data successfully."""
|
||||
result = parser.parse(raw_esp32_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.num_antennas == 3
|
||||
assert result.num_subcarriers == 56
|
||||
assert result.frequency == 2400000000 # 2.4 GHz
|
||||
assert result.bandwidth == 20000000 # 20 MHz
|
||||
assert result.snr == 15.5
|
||||
|
||||
def test_should_handle_malformed_data(self, parser):
|
||||
"""Should handle malformed ESP32 data gracefully."""
|
||||
malformed_data = b"INVALID_DATA"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
|
||||
parser.parse(malformed_data)
|
||||
|
||||
def test_should_handle_empty_data(self, parser):
|
||||
"""Should handle empty data gracefully."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
|
||||
def test_parse_with_value_error(self, parser):
|
||||
"""Should handle ValueError during parsing."""
|
||||
invalid_data = b"CSI_DATA:invalid_timestamp,3,56,2400,20,15.5"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(invalid_data)
|
||||
|
||||
def test_parse_with_index_error(self, parser):
|
||||
"""Should handle IndexError during parsing."""
|
||||
invalid_data = b"CSI_DATA:1234567890" # Missing fields
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(invalid_data)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterCSIParserDirect:
|
||||
"""Test Router CSI parser with direct imports."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create Router CSI parser for testing."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_should_parse_atheros_format(self, parser):
|
||||
"""Should parse Atheros CSI format successfully."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
with patch.object(parser, '_parse_atheros_format', return_value=Mock(spec=CSIData)) as mock_parse:
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
mock_parse.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
def test_should_handle_unknown_format(self, parser):
|
||||
"""Should handle unknown router format gracefully."""
|
||||
unknown_data = b"UNKNOWN_FORMAT:data"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Unknown router CSI format"):
|
||||
parser.parse(unknown_data)
|
||||
|
||||
def test_parse_atheros_format_directly(self, parser):
|
||||
"""Should parse Atheros format directly."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.metadata['source'] == 'atheros_router'
|
||||
|
||||
def test_should_handle_empty_data_router(self, parser):
|
||||
"""Should handle empty data gracefully."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
275
v1/tests/unit/test_csi_extractor_tdd.py
Normal file
275
v1/tests/unit/test_csi_extractor_tdd.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Test-Driven Development tests for CSI extractor using London School approach."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.hardware.csi_extractor import (
|
||||
CSIExtractor,
|
||||
CSIParseError,
|
||||
CSIData,
|
||||
ESP32CSIParser,
|
||||
RouterCSIParser,
|
||||
CSIValidationError
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractor:
|
||||
"""Test CSI extractor using London School TDD - focus on interactions and behavior."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Mock configuration for CSI extractor."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def csi_extractor(self, mock_config, mock_logger):
|
||||
"""Create CSI extractor instance for testing."""
|
||||
return CSIExtractor(config=mock_config, logger=mock_logger)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
def test_should_initialize_with_valid_config(self, mock_config, mock_logger):
|
||||
"""Should initialize CSI extractor with valid configuration."""
|
||||
extractor = CSIExtractor(config=mock_config, logger=mock_logger)
|
||||
|
||||
assert extractor.config == mock_config
|
||||
assert extractor.logger == mock_logger
|
||||
assert extractor.is_connected == False
|
||||
assert extractor.hardware_type == 'esp32'
|
||||
|
||||
def test_should_raise_error_with_invalid_config(self, mock_logger):
|
||||
"""Should raise error when initialized with invalid configuration."""
|
||||
invalid_config = {'invalid': 'config'}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_create_appropriate_parser(self, mock_config, mock_logger):
|
||||
"""Should create appropriate parser based on hardware type."""
|
||||
extractor = CSIExtractor(config=mock_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, ESP32CSIParser)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_establish_connection_successfully(self, csi_extractor):
|
||||
"""Should establish connection to hardware successfully."""
|
||||
with patch.object(csi_extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.return_value = True
|
||||
|
||||
result = await csi_extractor.connect()
|
||||
|
||||
assert result == True
|
||||
assert csi_extractor.is_connected == True
|
||||
mock_connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_connection_failure(self, csi_extractor):
|
||||
"""Should handle connection failure gracefully."""
|
||||
with patch.object(csi_extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.side_effect = ConnectionError("Hardware not found")
|
||||
|
||||
result = await csi_extractor.connect()
|
||||
|
||||
assert result == False
|
||||
assert csi_extractor.is_connected == False
|
||||
csi_extractor.logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_disconnect_properly(self, csi_extractor):
|
||||
"""Should disconnect from hardware properly."""
|
||||
csi_extractor.is_connected = True
|
||||
|
||||
with patch.object(csi_extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_disconnect:
|
||||
await csi_extractor.disconnect()
|
||||
|
||||
assert csi_extractor.is_connected == False
|
||||
mock_disconnect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_extract_csi_data_successfully(self, csi_extractor, sample_csi_data):
|
||||
"""Should extract CSI data successfully from hardware."""
|
||||
csi_extractor.is_connected = True
|
||||
|
||||
with patch.object(csi_extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(csi_extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
|
||||
mock_read.return_value = b"raw_csi_data"
|
||||
|
||||
result = await csi_extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_read.assert_called_once()
|
||||
mock_parse.assert_called_once_with(b"raw_csi_data")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_extraction_failure_when_not_connected(self, csi_extractor):
|
||||
"""Should handle extraction failure when not connected."""
|
||||
csi_extractor.is_connected = False
|
||||
|
||||
with pytest.raises(CSIParseError, match="Not connected to hardware"):
|
||||
await csi_extractor.extract_csi()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_retry_on_temporary_failure(self, csi_extractor, sample_csi_data):
|
||||
"""Should retry extraction on temporary failure."""
|
||||
csi_extractor.is_connected = True
|
||||
|
||||
with patch.object(csi_extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(csi_extractor.parser, 'parse') as mock_parse:
|
||||
# First two calls fail, third succeeds
|
||||
mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
|
||||
mock_parse.return_value = sample_csi_data
|
||||
|
||||
result = await csi_extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
assert mock_read.call_count == 3
|
||||
|
||||
def test_should_validate_csi_data_successfully(self, csi_extractor, sample_csi_data):
|
||||
"""Should validate CSI data successfully."""
|
||||
result = csi_extractor.validate_csi_data(sample_csi_data)
|
||||
|
||||
assert result == True
|
||||
|
||||
def test_should_reject_invalid_csi_data(self, csi_extractor):
|
||||
"""Should reject CSI data with invalid structure."""
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]), # Empty array
|
||||
phase=np.array([]),
|
||||
frequency=0, # Invalid frequency
|
||||
bandwidth=0,
|
||||
num_subcarriers=0,
|
||||
num_antennas=0,
|
||||
snr=-100, # Invalid SNR
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError):
|
||||
csi_extractor.validate_csi_data(invalid_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_start_streaming_successfully(self, csi_extractor, sample_csi_data):
|
||||
"""Should start CSI data streaming successfully."""
|
||||
csi_extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(csi_extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.return_value = sample_csi_data
|
||||
|
||||
# Start streaming with limited iterations to avoid infinite loop
|
||||
streaming_task = asyncio.create_task(csi_extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly
|
||||
csi_extractor.stop_streaming()
|
||||
await streaming_task
|
||||
|
||||
callback.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_stop_streaming_gracefully(self, csi_extractor):
|
||||
"""Should stop streaming gracefully."""
|
||||
csi_extractor.is_streaming = True
|
||||
|
||||
csi_extractor.stop_streaming()
|
||||
|
||||
assert csi_extractor.is_streaming == False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestESP32CSIParser:
|
||||
"""Test ESP32 CSI parser using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create ESP32 CSI parser for testing."""
|
||||
return ESP32CSIParser()
|
||||
|
||||
@pytest.fixture
|
||||
def raw_esp32_data(self):
|
||||
"""Sample raw ESP32 CSI data."""
|
||||
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
|
||||
def test_should_parse_valid_esp32_data(self, parser, raw_esp32_data):
|
||||
"""Should parse valid ESP32 CSI data successfully."""
|
||||
result = parser.parse(raw_esp32_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.num_antennas == 3
|
||||
assert result.num_subcarriers == 56
|
||||
assert result.frequency == 2400000000 # 2.4 GHz
|
||||
assert result.bandwidth == 20000000 # 20 MHz
|
||||
assert result.snr == 15.5
|
||||
|
||||
def test_should_handle_malformed_data(self, parser):
|
||||
"""Should handle malformed ESP32 data gracefully."""
|
||||
malformed_data = b"INVALID_DATA"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
|
||||
parser.parse(malformed_data)
|
||||
|
||||
def test_should_handle_empty_data(self, parser):
|
||||
"""Should handle empty data gracefully."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterCSIParser:
|
||||
"""Test Router CSI parser using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create Router CSI parser for testing."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_should_parse_atheros_format(self, parser):
|
||||
"""Should parse Atheros CSI format successfully."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
with patch.object(parser, '_parse_atheros_format', return_value=Mock(spec=CSIData)) as mock_parse:
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
mock_parse.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
def test_should_handle_unknown_format(self, parser):
|
||||
"""Should handle unknown router format gracefully."""
|
||||
unknown_data = b"UNKNOWN_FORMAT:data"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Unknown router CSI format"):
|
||||
parser.parse(unknown_data)
|
||||
386
v1/tests/unit/test_csi_extractor_tdd_complete.py
Normal file
386
v1/tests/unit/test_csi_extractor_tdd_complete.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""Complete TDD tests for CSI extractor with 100% coverage."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.hardware.csi_extractor import (
|
||||
CSIExtractor,
|
||||
CSIParseError,
|
||||
CSIData,
|
||||
ESP32CSIParser,
|
||||
RouterCSIParser,
|
||||
CSIValidationError
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractorComplete:
|
||||
"""Complete CSI extractor tests for 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def esp32_config(self):
|
||||
"""ESP32 configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_config(self):
|
||||
"""Router configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'router',
|
||||
'sampling_rate': 50,
|
||||
'buffer_size': 512,
|
||||
'timeout': 10.0,
|
||||
'validation_enabled': False,
|
||||
'retry_attempts': 1
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
def test_should_create_router_parser(self, router_config, mock_logger):
|
||||
"""Should create router parser when hardware_type is router."""
|
||||
extractor = CSIExtractor(config=router_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, RouterCSIParser)
|
||||
assert extractor.hardware_type == 'router'
|
||||
|
||||
def test_should_raise_error_for_unsupported_hardware(self, mock_logger):
|
||||
"""Should raise error for unsupported hardware type."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'unsupported',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_negative_sampling_rate(self, mock_logger):
|
||||
"""Should validate sampling_rate is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': -1,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="sampling_rate must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_zero_buffer_size(self, mock_logger):
|
||||
"""Should validate buffer_size is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 0,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="buffer_size must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_negative_timeout(self, mock_logger):
|
||||
"""Should validate timeout is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': -1.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should handle disconnect when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
# Should not call close when not connected
|
||||
mock_close.assert_not_called()
|
||||
assert extractor.is_connected == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_with_validation_disabled(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should skip validation when disabled."""
|
||||
esp32_config['validation_enabled'] = False
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
|
||||
with patch.object(extractor, 'validate_csi_data') as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_max_retries_exceeded(self, esp32_config, mock_logger):
|
||||
"""Should raise error after max retries exceeded."""
|
||||
esp32_config['retry_attempts'] = 2
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
mock_read.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
assert mock_read.call_count == 2
|
||||
|
||||
def test_validation_empty_amplitude(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for empty amplitude."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty amplitude data"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_empty_phase(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for empty phase."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.array([]),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty phase data"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_frequency(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid frequency."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=0,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid frequency"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_bandwidth(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid bandwidth."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=0,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_subcarriers(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid subcarriers."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=0,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_antennas(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid antennas."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=0,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_snr_too_low(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for SNR too low."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=-100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_snr_too_high(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for SNR too high."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_with_exception(self, esp32_config, mock_logger):
|
||||
"""Should handle exceptions during streaming."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.side_effect = Exception("Extraction error")
|
||||
|
||||
# Start streaming and let it handle the exception
|
||||
streaming_task = asyncio.create_task(extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly and hit the exception
|
||||
await streaming_task
|
||||
|
||||
# Should log error and stop streaming
|
||||
assert extractor.is_streaming == False
|
||||
extractor.logger.error.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestESP32CSIParserComplete:
|
||||
"""Complete ESP32 CSI parser tests for 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create ESP32 CSI parser for testing."""
|
||||
return ESP32CSIParser()
|
||||
|
||||
def test_parse_with_value_error(self, parser):
|
||||
"""Should handle ValueError during parsing."""
|
||||
invalid_data = b"CSI_DATA:invalid_timestamp,3,56,2400,20,15.5"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(invalid_data)
|
||||
|
||||
def test_parse_with_index_error(self, parser):
|
||||
"""Should handle IndexError during parsing."""
|
||||
invalid_data = b"CSI_DATA:1234567890" # Missing fields
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(invalid_data)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterCSIParserComplete:
|
||||
"""Complete Router CSI parser tests for 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create Router CSI parser for testing."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_parse_atheros_format_directly(self, parser):
|
||||
"""Should parse Atheros format directly."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.metadata['source'] == 'atheros_router'
|
||||
87
v1/tests/unit/test_csi_processor.py
Normal file
87
v1/tests/unit/test_csi_processor.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
|
||||
|
||||
class TestCSIProcessor:
|
||||
"""Test suite for CSI processor following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_data(self):
|
||||
"""Generate synthetic CSI data for testing"""
|
||||
# Simple raw CSI data array for testing
|
||||
return np.random.uniform(0.1, 2.0, (3, 56, 100))
|
||||
|
||||
@pytest.fixture
|
||||
def csi_processor(self):
|
||||
"""Create CSI processor instance for testing"""
|
||||
return CSIProcessor()
|
||||
|
||||
def test_process_csi_data_returns_normalized_output(self, csi_processor, mock_csi_data):
|
||||
"""Test that CSI processing returns properly normalized output"""
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == mock_csi_data.shape
|
||||
|
||||
# Verify normalization - mean should be close to 0, std close to 1
|
||||
assert abs(result.mean()) < 0.1
|
||||
assert abs(result.std() - 1.0) < 0.1
|
||||
|
||||
def test_process_csi_data_handles_invalid_input(self, csi_processor):
|
||||
"""Test that CSI processor handles invalid input gracefully"""
|
||||
# Arrange
|
||||
invalid_data = np.array([])
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Raw CSI data cannot be empty"):
|
||||
csi_processor.process_raw_csi(invalid_data)
|
||||
|
||||
def test_process_csi_data_removes_nan_values(self, csi_processor, mock_csi_data):
|
||||
"""Test that CSI processor removes NaN values from input"""
|
||||
# Arrange
|
||||
mock_csi_data[0, 0, 0] = np.nan
|
||||
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
|
||||
# Assert
|
||||
assert not np.isnan(result).any()
|
||||
|
||||
def test_process_csi_data_applies_temporal_filtering(self, csi_processor, mock_csi_data):
|
||||
"""Test that temporal filtering is applied to CSI data"""
|
||||
# Arrange - Add noise to make filtering effect visible
|
||||
noisy_data = mock_csi_data + np.random.normal(0, 0.1, mock_csi_data.shape)
|
||||
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(noisy_data)
|
||||
|
||||
# Assert - Result should be normalized
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == noisy_data.shape
|
||||
|
||||
def test_process_csi_data_preserves_metadata(self, csi_processor, mock_csi_data):
|
||||
"""Test that metadata is preserved during processing"""
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
|
||||
# Assert - For now, just verify processing works
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
def test_process_csi_data_performance_requirement(self, csi_processor, mock_csi_data):
|
||||
"""Test that CSI processing meets performance requirements (<10ms)"""
|
||||
import time
|
||||
|
||||
# Act
|
||||
start_time = time.time()
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# Assert
|
||||
assert processing_time < 0.01 # <10ms requirement
|
||||
assert result is not None
|
||||
479
v1/tests/unit/test_csi_processor_tdd.py
Normal file
479
v1/tests/unit/test_csi_processor_tdd.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""TDD tests for CSI processor following London School approach."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
from typing import Dict, List, Any
|
||||
|
||||
# Import the CSI processor module directly
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'csi_processor',
|
||||
'/workspaces/wifi-densepose/src/core/csi_processor.py'
|
||||
)
|
||||
csi_processor_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# Import CSI extractor for dependencies
|
||||
csi_spec = importlib.util.spec_from_file_location(
|
||||
'csi_extractor',
|
||||
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
|
||||
)
|
||||
csi_module = importlib.util.module_from_spec(csi_spec)
|
||||
csi_spec.loader.exec_module(csi_module)
|
||||
|
||||
# Make dependencies available and load the processor
|
||||
csi_processor_module.CSIData = csi_module.CSIData
|
||||
spec.loader.exec_module(csi_processor_module)
|
||||
|
||||
# Get classes from modules
|
||||
CSIProcessor = csi_processor_module.CSIProcessor
|
||||
CSIProcessingError = csi_processor_module.CSIProcessingError
|
||||
HumanDetectionResult = csi_processor_module.HumanDetectionResult
|
||||
CSIFeatures = csi_processor_module.CSIFeatures
|
||||
CSIData = csi_module.CSIData
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIProcessor:
|
||||
"""Test CSI processor using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def processor_config(self):
|
||||
"""CSI processor configuration for testing."""
|
||||
return {
|
||||
'sampling_rate': 100,
|
||||
'window_size': 256,
|
||||
'overlap': 0.5,
|
||||
'noise_threshold': -60.0,
|
||||
'human_detection_threshold': 0.7,
|
||||
'smoothing_factor': 0.8,
|
||||
'max_history_size': 1000,
|
||||
'enable_preprocessing': True,
|
||||
'enable_feature_extraction': True,
|
||||
'enable_human_detection': True
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def csi_processor(self, processor_config, mock_logger):
|
||||
"""Create CSI processor for testing."""
|
||||
return CSIProcessor(config=processor_config, logger=mock_logger)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56) + 1.0, # Ensure positive amplitude
|
||||
phase=np.random.uniform(-np.pi, np.pi, (3, 56)),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'test'}
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_features(self):
|
||||
"""Sample CSI features for testing."""
|
||||
return CSIFeatures(
|
||||
amplitude_mean=np.random.rand(56),
|
||||
amplitude_variance=np.random.rand(56),
|
||||
phase_difference=np.random.rand(56),
|
||||
correlation_matrix=np.random.rand(3, 3),
|
||||
doppler_shift=np.random.rand(10),
|
||||
power_spectral_density=np.random.rand(128),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={'processing_params': {}}
|
||||
)
|
||||
|
||||
# Initialization tests
|
||||
def test_should_initialize_with_valid_config(self, processor_config, mock_logger):
|
||||
"""Should initialize CSI processor with valid configuration."""
|
||||
processor = CSIProcessor(config=processor_config, logger=mock_logger)
|
||||
|
||||
assert processor.config == processor_config
|
||||
assert processor.logger == mock_logger
|
||||
assert processor.sampling_rate == 100
|
||||
assert processor.window_size == 256
|
||||
assert processor.overlap == 0.5
|
||||
assert processor.noise_threshold == -60.0
|
||||
assert processor.human_detection_threshold == 0.7
|
||||
assert processor.smoothing_factor == 0.8
|
||||
assert processor.max_history_size == 1000
|
||||
assert len(processor.csi_history) == 0
|
||||
|
||||
def test_should_raise_error_with_invalid_config(self, mock_logger):
|
||||
"""Should raise error when initialized with invalid configuration."""
|
||||
invalid_config = {'invalid': 'config'}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIProcessor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_validate_required_fields(self, mock_logger):
|
||||
"""Should validate all required configuration fields."""
|
||||
required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
|
||||
base_config = {
|
||||
'sampling_rate': 100,
|
||||
'window_size': 256,
|
||||
'overlap': 0.5,
|
||||
'noise_threshold': -60.0
|
||||
}
|
||||
|
||||
for field in required_fields:
|
||||
config = base_config.copy()
|
||||
del config[field]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIProcessor(config=config, logger=mock_logger)
|
||||
|
||||
def test_should_use_default_values(self, mock_logger):
|
||||
"""Should use default values for optional parameters."""
|
||||
minimal_config = {
|
||||
'sampling_rate': 100,
|
||||
'window_size': 256,
|
||||
'overlap': 0.5,
|
||||
'noise_threshold': -60.0
|
||||
}
|
||||
|
||||
processor = CSIProcessor(config=minimal_config, logger=mock_logger)
|
||||
|
||||
assert processor.human_detection_threshold == 0.8 # default
|
||||
assert processor.smoothing_factor == 0.9 # default
|
||||
assert processor.max_history_size == 500 # default
|
||||
|
||||
def test_should_initialize_without_logger(self, processor_config):
|
||||
"""Should initialize without logger provided."""
|
||||
processor = CSIProcessor(config=processor_config)
|
||||
|
||||
assert processor.logger is not None # Should create default logger
|
||||
|
||||
# Preprocessing tests
|
||||
def test_should_preprocess_csi_data_successfully(self, csi_processor, sample_csi_data):
|
||||
"""Should preprocess CSI data successfully."""
|
||||
with patch.object(csi_processor, '_remove_noise') as mock_noise:
|
||||
with patch.object(csi_processor, '_apply_windowing') as mock_window:
|
||||
with patch.object(csi_processor, '_normalize_amplitude') as mock_normalize:
|
||||
mock_noise.return_value = sample_csi_data
|
||||
mock_window.return_value = sample_csi_data
|
||||
mock_normalize.return_value = sample_csi_data
|
||||
|
||||
result = csi_processor.preprocess_csi_data(sample_csi_data)
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_noise.assert_called_once_with(sample_csi_data)
|
||||
mock_window.assert_called_once()
|
||||
mock_normalize.assert_called_once()
|
||||
|
||||
def test_should_skip_preprocessing_when_disabled(self, processor_config, mock_logger, sample_csi_data):
|
||||
"""Should skip preprocessing when disabled."""
|
||||
processor_config['enable_preprocessing'] = False
|
||||
processor = CSIProcessor(config=processor_config, logger=mock_logger)
|
||||
|
||||
result = processor.preprocess_csi_data(sample_csi_data)
|
||||
|
||||
assert result == sample_csi_data
|
||||
|
||||
def test_should_handle_preprocessing_error(self, csi_processor, sample_csi_data):
|
||||
"""Should handle preprocessing errors gracefully."""
|
||||
with patch.object(csi_processor, '_remove_noise') as mock_noise:
|
||||
mock_noise.side_effect = Exception("Preprocessing error")
|
||||
|
||||
with pytest.raises(CSIProcessingError, match="Failed to preprocess CSI data"):
|
||||
csi_processor.preprocess_csi_data(sample_csi_data)
|
||||
|
||||
# Feature extraction tests
|
||||
def test_should_extract_features_successfully(self, csi_processor, sample_csi_data, sample_features):
|
||||
"""Should extract features from CSI data successfully."""
|
||||
with patch.object(csi_processor, '_extract_amplitude_features') as mock_amp:
|
||||
with patch.object(csi_processor, '_extract_phase_features') as mock_phase:
|
||||
with patch.object(csi_processor, '_extract_correlation_features') as mock_corr:
|
||||
with patch.object(csi_processor, '_extract_doppler_features') as mock_doppler:
|
||||
mock_amp.return_value = (sample_features.amplitude_mean, sample_features.amplitude_variance)
|
||||
mock_phase.return_value = sample_features.phase_difference
|
||||
mock_corr.return_value = sample_features.correlation_matrix
|
||||
mock_doppler.return_value = (sample_features.doppler_shift, sample_features.power_spectral_density)
|
||||
|
||||
result = csi_processor.extract_features(sample_csi_data)
|
||||
|
||||
assert isinstance(result, CSIFeatures)
|
||||
assert np.array_equal(result.amplitude_mean, sample_features.amplitude_mean)
|
||||
assert np.array_equal(result.amplitude_variance, sample_features.amplitude_variance)
|
||||
mock_amp.assert_called_once()
|
||||
mock_phase.assert_called_once()
|
||||
mock_corr.assert_called_once()
|
||||
mock_doppler.assert_called_once()
|
||||
|
||||
def test_should_skip_feature_extraction_when_disabled(self, processor_config, mock_logger, sample_csi_data):
|
||||
"""Should skip feature extraction when disabled."""
|
||||
processor_config['enable_feature_extraction'] = False
|
||||
processor = CSIProcessor(config=processor_config, logger=mock_logger)
|
||||
|
||||
result = processor.extract_features(sample_csi_data)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_should_handle_feature_extraction_error(self, csi_processor, sample_csi_data):
|
||||
"""Should handle feature extraction errors gracefully."""
|
||||
with patch.object(csi_processor, '_extract_amplitude_features') as mock_amp:
|
||||
mock_amp.side_effect = Exception("Feature extraction error")
|
||||
|
||||
with pytest.raises(CSIProcessingError, match="Failed to extract features"):
|
||||
csi_processor.extract_features(sample_csi_data)
|
||||
|
||||
# Human detection tests
|
||||
def test_should_detect_human_presence_successfully(self, csi_processor, sample_features):
|
||||
"""Should detect human presence successfully."""
|
||||
with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
|
||||
with patch.object(csi_processor, '_calculate_detection_confidence') as mock_confidence:
|
||||
with patch.object(csi_processor, '_apply_temporal_smoothing') as mock_smooth:
|
||||
mock_motion.return_value = 0.9
|
||||
mock_confidence.return_value = 0.85
|
||||
mock_smooth.return_value = 0.88
|
||||
|
||||
result = csi_processor.detect_human_presence(sample_features)
|
||||
|
||||
assert isinstance(result, HumanDetectionResult)
|
||||
assert result.human_detected == True
|
||||
assert result.confidence == 0.88
|
||||
assert result.motion_score == 0.9
|
||||
mock_motion.assert_called_once()
|
||||
mock_confidence.assert_called_once()
|
||||
mock_smooth.assert_called_once()
|
||||
|
||||
def test_should_detect_no_human_presence(self, csi_processor, sample_features):
|
||||
"""Should detect no human presence when confidence is low."""
|
||||
with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
|
||||
with patch.object(csi_processor, '_calculate_detection_confidence') as mock_confidence:
|
||||
with patch.object(csi_processor, '_apply_temporal_smoothing') as mock_smooth:
|
||||
mock_motion.return_value = 0.3
|
||||
mock_confidence.return_value = 0.2
|
||||
mock_smooth.return_value = 0.25
|
||||
|
||||
result = csi_processor.detect_human_presence(sample_features)
|
||||
|
||||
assert result.human_detected == False
|
||||
assert result.confidence == 0.25
|
||||
assert result.motion_score == 0.3
|
||||
|
||||
def test_should_skip_human_detection_when_disabled(self, processor_config, mock_logger, sample_features):
|
||||
"""Should skip human detection when disabled."""
|
||||
processor_config['enable_human_detection'] = False
|
||||
processor = CSIProcessor(config=processor_config, logger=mock_logger)
|
||||
|
||||
result = processor.detect_human_presence(sample_features)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_should_handle_human_detection_error(self, csi_processor, sample_features):
|
||||
"""Should handle human detection errors gracefully."""
|
||||
with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
|
||||
mock_motion.side_effect = Exception("Detection error")
|
||||
|
||||
with pytest.raises(CSIProcessingError, match="Failed to detect human presence"):
|
||||
csi_processor.detect_human_presence(sample_features)
|
||||
|
||||
# Processing pipeline tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_process_csi_data_pipeline_successfully(self, csi_processor, sample_csi_data, sample_features):
|
||||
"""Should process CSI data through full pipeline successfully."""
|
||||
expected_detection = HumanDetectionResult(
|
||||
human_detected=True,
|
||||
confidence=0.85,
|
||||
motion_score=0.9,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
features=sample_features,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with patch.object(csi_processor, 'preprocess_csi_data', return_value=sample_csi_data) as mock_preprocess:
|
||||
with patch.object(csi_processor, 'extract_features', return_value=sample_features) as mock_features:
|
||||
with patch.object(csi_processor, 'detect_human_presence', return_value=expected_detection) as mock_detect:
|
||||
|
||||
result = await csi_processor.process_csi_data(sample_csi_data)
|
||||
|
||||
assert result == expected_detection
|
||||
mock_preprocess.assert_called_once_with(sample_csi_data)
|
||||
mock_features.assert_called_once_with(sample_csi_data)
|
||||
mock_detect.assert_called_once_with(sample_features)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_pipeline_processing_error(self, csi_processor, sample_csi_data):
|
||||
"""Should handle pipeline processing errors gracefully."""
|
||||
with patch.object(csi_processor, 'preprocess_csi_data') as mock_preprocess:
|
||||
mock_preprocess.side_effect = CSIProcessingError("Pipeline error")
|
||||
|
||||
with pytest.raises(CSIProcessingError):
|
||||
await csi_processor.process_csi_data(sample_csi_data)
|
||||
|
||||
# History management tests
|
||||
def test_should_add_csi_data_to_history(self, csi_processor, sample_csi_data):
|
||||
"""Should add CSI data to history successfully."""
|
||||
csi_processor.add_to_history(sample_csi_data)
|
||||
|
||||
assert len(csi_processor.csi_history) == 1
|
||||
assert csi_processor.csi_history[0] == sample_csi_data
|
||||
|
||||
def test_should_maintain_history_size_limit(self, processor_config, mock_logger):
|
||||
"""Should maintain history size within limits."""
|
||||
processor_config['max_history_size'] = 2
|
||||
processor = CSIProcessor(config=processor_config, logger=mock_logger)
|
||||
|
||||
# Add 3 items to history of size 2
|
||||
for i in range(3):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'index': i}
|
||||
)
|
||||
processor.add_to_history(csi_data)
|
||||
|
||||
assert len(processor.csi_history) == 2
|
||||
assert processor.csi_history[0].metadata['index'] == 1 # First item removed
|
||||
assert processor.csi_history[1].metadata['index'] == 2
|
||||
|
||||
def test_should_clear_history(self, csi_processor, sample_csi_data):
|
||||
"""Should clear history successfully."""
|
||||
csi_processor.add_to_history(sample_csi_data)
|
||||
assert len(csi_processor.csi_history) > 0
|
||||
|
||||
csi_processor.clear_history()
|
||||
|
||||
assert len(csi_processor.csi_history) == 0
|
||||
|
||||
def test_should_get_recent_history(self, csi_processor):
|
||||
"""Should get recent history entries."""
|
||||
# Add 5 items to history
|
||||
for i in range(5):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'index': i}
|
||||
)
|
||||
csi_processor.add_to_history(csi_data)
|
||||
|
||||
recent = csi_processor.get_recent_history(3)
|
||||
|
||||
assert len(recent) == 3
|
||||
assert recent[0].metadata['index'] == 2 # Most recent first
|
||||
assert recent[1].metadata['index'] == 3
|
||||
assert recent[2].metadata['index'] == 4
|
||||
|
||||
# Statistics and monitoring tests
|
||||
def test_should_get_processing_statistics(self, csi_processor):
|
||||
"""Should get processing statistics."""
|
||||
# Simulate some processing
|
||||
csi_processor._total_processed = 100
|
||||
csi_processor._processing_errors = 5
|
||||
csi_processor._human_detections = 25
|
||||
|
||||
stats = csi_processor.get_processing_statistics()
|
||||
|
||||
assert isinstance(stats, dict)
|
||||
assert stats['total_processed'] == 100
|
||||
assert stats['processing_errors'] == 5
|
||||
assert stats['human_detections'] == 25
|
||||
assert stats['error_rate'] == 0.05
|
||||
assert stats['detection_rate'] == 0.25
|
||||
|
||||
def test_should_reset_statistics(self, csi_processor):
|
||||
"""Should reset processing statistics."""
|
||||
csi_processor._total_processed = 100
|
||||
csi_processor._processing_errors = 5
|
||||
csi_processor._human_detections = 25
|
||||
|
||||
csi_processor.reset_statistics()
|
||||
|
||||
assert csi_processor._total_processed == 0
|
||||
assert csi_processor._processing_errors == 0
|
||||
assert csi_processor._human_detections == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIFeatures:
|
||||
"""Test CSI features data structure."""
|
||||
|
||||
def test_should_create_csi_features(self):
|
||||
"""Should create CSI features successfully."""
|
||||
features = CSIFeatures(
|
||||
amplitude_mean=np.random.rand(56),
|
||||
amplitude_variance=np.random.rand(56),
|
||||
phase_difference=np.random.rand(56),
|
||||
correlation_matrix=np.random.rand(3, 3),
|
||||
doppler_shift=np.random.rand(10),
|
||||
power_spectral_density=np.random.rand(128),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={'test': 'data'}
|
||||
)
|
||||
|
||||
assert features.amplitude_mean.shape == (56,)
|
||||
assert features.amplitude_variance.shape == (56,)
|
||||
assert features.phase_difference.shape == (56,)
|
||||
assert features.correlation_matrix.shape == (3, 3)
|
||||
assert features.doppler_shift.shape == (10,)
|
||||
assert features.power_spectral_density.shape == (128,)
|
||||
assert isinstance(features.timestamp, datetime)
|
||||
assert features.metadata['test'] == 'data'
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestHumanDetectionResult:
|
||||
"""Test human detection result data structure."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_features(self):
|
||||
"""Sample features for testing."""
|
||||
return CSIFeatures(
|
||||
amplitude_mean=np.random.rand(56),
|
||||
amplitude_variance=np.random.rand(56),
|
||||
phase_difference=np.random.rand(56),
|
||||
correlation_matrix=np.random.rand(3, 3),
|
||||
doppler_shift=np.random.rand(10),
|
||||
power_spectral_density=np.random.rand(128),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={}
|
||||
)
|
||||
|
||||
def test_should_create_detection_result(self, sample_features):
|
||||
"""Should create human detection result successfully."""
|
||||
result = HumanDetectionResult(
|
||||
human_detected=True,
|
||||
confidence=0.85,
|
||||
motion_score=0.92,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
features=sample_features,
|
||||
metadata={'test': 'data'}
|
||||
)
|
||||
|
||||
assert result.human_detected == True
|
||||
assert result.confidence == 0.85
|
||||
assert result.motion_score == 0.92
|
||||
assert isinstance(result.timestamp, datetime)
|
||||
assert result.features == sample_features
|
||||
assert result.metadata['test'] == 'data'
|
||||
599
v1/tests/unit/test_csi_standalone.py
Normal file
599
v1/tests/unit/test_csi_standalone.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""Standalone tests for CSI extractor module."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
|
||||
# Import the module directly to avoid circular imports
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'csi_extractor',
|
||||
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
|
||||
)
|
||||
csi_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(csi_module)
|
||||
|
||||
# Get classes from the module
|
||||
CSIExtractor = csi_module.CSIExtractor
|
||||
CSIParseError = csi_module.CSIParseError
|
||||
CSIData = csi_module.CSIData
|
||||
ESP32CSIParser = csi_module.ESP32CSIParser
|
||||
RouterCSIParser = csi_module.RouterCSIParser
|
||||
CSIValidationError = csi_module.CSIValidationError
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractorStandalone:
|
||||
"""Standalone tests for CSI extractor with 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def esp32_config(self):
|
||||
"""ESP32 configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_config(self):
|
||||
"""Router configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'router',
|
||||
'sampling_rate': 50,
|
||||
'buffer_size': 512,
|
||||
'timeout': 10.0,
|
||||
'validation_enabled': False,
|
||||
'retry_attempts': 1
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
# Test all initialization paths
|
||||
def test_init_esp32_config(self, esp32_config, mock_logger):
|
||||
"""Should initialize with ESP32 configuration."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
assert extractor.config == esp32_config
|
||||
assert extractor.logger == mock_logger
|
||||
assert extractor.is_connected == False
|
||||
assert extractor.hardware_type == 'esp32'
|
||||
assert isinstance(extractor.parser, ESP32CSIParser)
|
||||
|
||||
def test_init_router_config(self, router_config, mock_logger):
|
||||
"""Should initialize with router configuration."""
|
||||
extractor = CSIExtractor(config=router_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, RouterCSIParser)
|
||||
assert extractor.hardware_type == 'router'
|
||||
|
||||
def test_init_unsupported_hardware(self, mock_logger):
|
||||
"""Should raise error for unsupported hardware type."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'unsupported',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_init_without_logger(self, esp32_config):
|
||||
"""Should initialize without logger."""
|
||||
extractor = CSIExtractor(config=esp32_config)
|
||||
|
||||
assert extractor.logger is not None # Should create default logger
|
||||
|
||||
# Test all validation paths
|
||||
def test_validation_missing_fields(self, mock_logger):
|
||||
"""Should validate missing required fields."""
|
||||
for missing_field in ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']:
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
del config[missing_field]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
def test_validation_negative_sampling_rate(self, mock_logger):
|
||||
"""Should validate sampling_rate is positive."""
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': -1,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="sampling_rate must be positive"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
def test_validation_zero_buffer_size(self, mock_logger):
|
||||
"""Should validate buffer_size is positive."""
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 0,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="buffer_size must be positive"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
def test_validation_negative_timeout(self, mock_logger):
|
||||
"""Should validate timeout is positive."""
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': -1.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
# Test connection management
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_success(self, esp32_config, mock_logger):
|
||||
"""Should connect successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_conn:
|
||||
mock_conn.return_value = True
|
||||
|
||||
result = await extractor.connect()
|
||||
|
||||
assert result == True
|
||||
assert extractor.is_connected == True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_failure(self, esp32_config, mock_logger):
|
||||
"""Should handle connection failure."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_conn:
|
||||
mock_conn.side_effect = ConnectionError("Failed")
|
||||
|
||||
result = await extractor.connect()
|
||||
|
||||
assert result == False
|
||||
assert extractor.is_connected == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_connected(self, esp32_config, mock_logger):
|
||||
"""Should disconnect when connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
assert extractor.is_connected == False
|
||||
mock_close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should not disconnect when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
mock_close.assert_not_called()
|
||||
|
||||
# Test extraction
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should raise error when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with pytest.raises(CSIParseError, match="Not connected to hardware"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_success_with_validation(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should extract successfully with validation."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
|
||||
with patch.object(extractor, 'validate_csi_data', return_value=True) as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_success_without_validation(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should extract successfully without validation."""
|
||||
esp32_config['validation_enabled'] = False
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
|
||||
with patch.object(extractor, 'validate_csi_data') as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_retry_success(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should retry and succeed."""
|
||||
esp32_config['retry_attempts'] = 3
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
|
||||
# Fail first two attempts, succeed on third
|
||||
mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
assert mock_read.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_retry_failure(self, esp32_config, mock_logger):
|
||||
"""Should fail after max retries."""
|
||||
esp32_config['retry_attempts'] = 2
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
mock_read.side_effect = ConnectionError("Failed")
|
||||
|
||||
with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
# Test validation
|
||||
def test_validate_success(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should validate successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
result = extractor.validate_csi_data(sample_csi_data)
|
||||
|
||||
assert result == True
|
||||
|
||||
def test_validate_empty_amplitude(self, esp32_config, mock_logger):
|
||||
"""Should reject empty amplitude."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty amplitude data"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_empty_phase(self, esp32_config, mock_logger):
|
||||
"""Should reject empty phase."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.array([]),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty phase data"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_invalid_frequency(self, esp32_config, mock_logger):
|
||||
"""Should reject invalid frequency."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=0,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid frequency"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_invalid_bandwidth(self, esp32_config, mock_logger):
|
||||
"""Should reject invalid bandwidth."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=0,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_invalid_subcarriers(self, esp32_config, mock_logger):
|
||||
"""Should reject invalid subcarriers."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=0,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_invalid_antennas(self, esp32_config, mock_logger):
|
||||
"""Should reject invalid antennas."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=0,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_snr_too_low(self, esp32_config, mock_logger):
|
||||
"""Should reject SNR too low."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=-100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_snr_too_high(self, esp32_config, mock_logger):
|
||||
"""Should reject SNR too high."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
# Test streaming
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_success(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should stream successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.return_value = sample_csi_data
|
||||
|
||||
# Start streaming task
|
||||
task = asyncio.create_task(extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly
|
||||
extractor.stop_streaming()
|
||||
await task
|
||||
|
||||
callback.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_exception(self, esp32_config, mock_logger):
|
||||
"""Should handle streaming exceptions."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.side_effect = Exception("Test error")
|
||||
|
||||
# Start streaming and let it handle exception
|
||||
task = asyncio.create_task(extractor.start_streaming(callback))
|
||||
await task # This should complete due to exception
|
||||
|
||||
assert extractor.is_streaming == False
|
||||
|
||||
def test_stop_streaming(self, esp32_config, mock_logger):
|
||||
"""Should stop streaming."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_streaming = True
|
||||
|
||||
extractor.stop_streaming()
|
||||
|
||||
assert extractor.is_streaming == False
|
||||
|
||||
# Test placeholder implementations for 100% coverage
|
||||
@pytest.mark.asyncio
|
||||
async def test_establish_hardware_connection_placeholder(self, esp32_config, mock_logger):
|
||||
"""Should test placeholder hardware connection."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
result = await extractor._establish_hardware_connection()
|
||||
|
||||
assert result == True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_hardware_connection_placeholder(self, esp32_config, mock_logger):
|
||||
"""Should test placeholder hardware disconnection."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
# Should not raise any exception
|
||||
await extractor._close_hardware_connection()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_raw_data_placeholder(self, esp32_config, mock_logger):
|
||||
"""Should test placeholder raw data reading."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
result = await extractor._read_raw_data()
|
||||
|
||||
assert result == b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
class TestESP32CSIParserStandalone:
|
||||
"""Standalone tests for ESP32 CSI parser."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create parser instance."""
|
||||
return ESP32CSIParser()
|
||||
|
||||
def test_parse_valid_data(self, parser):
|
||||
"""Should parse valid ESP32 data."""
|
||||
data = b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
|
||||
result = parser.parse(data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.num_antennas == 3
|
||||
assert result.num_subcarriers == 56
|
||||
assert result.frequency == 2400000000
|
||||
assert result.bandwidth == 20000000
|
||||
assert result.snr == 15.5
|
||||
|
||||
def test_parse_empty_data(self, parser):
|
||||
"""Should reject empty data."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
|
||||
def test_parse_invalid_format(self, parser):
|
||||
"""Should reject invalid format."""
|
||||
with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
|
||||
parser.parse(b"INVALID_DATA")
|
||||
|
||||
def test_parse_value_error(self, parser):
|
||||
"""Should handle ValueError."""
|
||||
data = b"CSI_DATA:invalid_number,3,56,2400,20,15.5"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(data)
|
||||
|
||||
def test_parse_index_error(self, parser):
|
||||
"""Should handle IndexError."""
|
||||
data = b"CSI_DATA:1234567890" # Missing fields
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(data)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
class TestRouterCSIParserStandalone:
|
||||
"""Standalone tests for Router CSI parser."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create parser instance."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_parse_empty_data(self, parser):
|
||||
"""Should reject empty data."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
|
||||
def test_parse_atheros_format(self, parser):
|
||||
"""Should parse Atheros format."""
|
||||
data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
result = parser.parse(data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.metadata['source'] == 'atheros_router'
|
||||
|
||||
def test_parse_unknown_format(self, parser):
|
||||
"""Should reject unknown format."""
|
||||
data = b"UNKNOWN_FORMAT:data"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Unknown router CSI format"):
|
||||
parser.parse(data)
|
||||
367
v1/tests/unit/test_densepose_head.py
Normal file
367
v1/tests/unit/test_densepose_head.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.densepose_head import DensePoseHead, DensePoseError
|
||||
|
||||
|
||||
class TestDensePoseHead:
|
||||
"""Test suite for DensePose Head following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for DensePose head"""
|
||||
return {
|
||||
'input_channels': 256,
|
||||
'num_body_parts': 24,
|
||||
'num_uv_coordinates': 2,
|
||||
'hidden_channels': [128, 64],
|
||||
'kernel_size': 3,
|
||||
'padding': 1,
|
||||
'dropout_rate': 0.1,
|
||||
'use_deformable_conv': False,
|
||||
'use_fpn': True,
|
||||
'fpn_levels': [2, 3, 4, 5],
|
||||
'output_stride': 4
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def densepose_head(self, mock_config):
|
||||
"""Create DensePose head instance for testing"""
|
||||
return DensePoseHead(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_input(self):
|
||||
"""Generate mock feature input tensor"""
|
||||
batch_size = 2
|
||||
channels = 256
|
||||
height = 56
|
||||
width = 56
|
||||
return torch.randn(batch_size, channels, height, width)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_masks(self):
|
||||
"""Generate mock target segmentation masks"""
|
||||
batch_size = 2
|
||||
num_parts = 24
|
||||
height = 224
|
||||
width = 224
|
||||
return torch.randint(0, num_parts + 1, (batch_size, height, width))
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_uv(self):
|
||||
"""Generate mock target UV coordinates"""
|
||||
batch_size = 2
|
||||
num_coords = 2
|
||||
height = 224
|
||||
width = 224
|
||||
return torch.randn(batch_size, num_coords, height, width)
|
||||
|
||||
def test_head_initialization_creates_correct_architecture(self, mock_config):
|
||||
"""Test that DensePose head initializes with correct architecture"""
|
||||
# Act
|
||||
head = DensePoseHead(mock_config)
|
||||
|
||||
# Assert
|
||||
assert head is not None
|
||||
assert isinstance(head, nn.Module)
|
||||
assert head.input_channels == mock_config['input_channels']
|
||||
assert head.num_body_parts == mock_config['num_body_parts']
|
||||
assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
|
||||
assert head.use_fpn == mock_config['use_fpn']
|
||||
assert hasattr(head, 'segmentation_head')
|
||||
assert hasattr(head, 'uv_regression_head')
|
||||
if mock_config['use_fpn']:
|
||||
assert hasattr(head, 'fpn')
|
||||
|
||||
def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
|
||||
"""Test that forward pass produces correctly formatted output"""
|
||||
# Act
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Assert
|
||||
assert output is not None
|
||||
assert isinstance(output, dict)
|
||||
assert 'segmentation' in output
|
||||
assert 'uv_coordinates' in output
|
||||
|
||||
seg_output = output['segmentation']
|
||||
uv_output = output['uv_coordinates']
|
||||
|
||||
assert isinstance(seg_output, torch.Tensor)
|
||||
assert isinstance(uv_output, torch.Tensor)
|
||||
assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||
assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
|
||||
|
||||
def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
|
||||
"""Test that segmentation head produces correct output shape"""
|
||||
# Act
|
||||
output = densepose_head(mock_feature_input)
|
||||
seg_output = output['segmentation']
|
||||
|
||||
# Assert
|
||||
expected_channels = densepose_head.num_body_parts + 1 # +1 for background
|
||||
assert seg_output.shape[1] == expected_channels
|
||||
assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
|
||||
assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
|
||||
|
||||
def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
|
||||
"""Test that UV regression head produces correct output shape"""
|
||||
# Act
|
||||
output = densepose_head(mock_feature_input)
|
||||
uv_output = output['uv_coordinates']
|
||||
|
||||
# Assert
|
||||
assert uv_output.shape[1] == densepose_head.num_uv_coordinates
|
||||
assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
|
||||
assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
|
||||
|
||||
def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
|
||||
"""Test that compute_segmentation_loss measures pixel classification accuracy"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
seg_logits = output['segmentation']
|
||||
|
||||
# Resize target to match output
|
||||
target_resized = torch.nn.functional.interpolate(
|
||||
mock_target_masks.float().unsqueeze(1),
|
||||
size=seg_logits.shape[2:],
|
||||
mode='nearest'
|
||||
).squeeze(1).long()
|
||||
|
||||
# Act
|
||||
loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
|
||||
"""Test that compute_uv_loss measures UV coordinate regression accuracy"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
uv_pred = output['uv_coordinates']
|
||||
|
||||
# Resize target to match output
|
||||
target_resized = torch.nn.functional.interpolate(
|
||||
mock_target_uv,
|
||||
size=uv_pred.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Act
|
||||
loss = densepose_head.compute_uv_loss(uv_pred, target_resized)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
|
||||
"""Test that compute_total_loss combines segmentation and UV losses"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Resize targets to match outputs
|
||||
seg_target = torch.nn.functional.interpolate(
|
||||
mock_target_masks.float().unsqueeze(1),
|
||||
size=output['segmentation'].shape[2:],
|
||||
mode='nearest'
|
||||
).squeeze(1).long()
|
||||
|
||||
uv_target = torch.nn.functional.interpolate(
|
||||
mock_target_uv,
|
||||
size=output['uv_coordinates'].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Act
|
||||
total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
|
||||
seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target)
|
||||
uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target)
|
||||
|
||||
# Assert
|
||||
assert total_loss is not None
|
||||
assert isinstance(total_loss, torch.Tensor)
|
||||
assert total_loss.item() > 0
|
||||
# Total loss should be combination of individual losses
|
||||
expected_total = seg_loss + uv_loss
|
||||
assert torch.allclose(total_loss, expected_total, atol=1e-6)
|
||||
|
||||
def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
|
||||
"""Test that FPN integration enhances multi-scale feature processing"""
|
||||
# Arrange
|
||||
config_with_fpn = mock_config.copy()
|
||||
config_with_fpn['use_fpn'] = True
|
||||
|
||||
config_without_fpn = mock_config.copy()
|
||||
config_without_fpn['use_fpn'] = False
|
||||
|
||||
head_with_fpn = DensePoseHead(config_with_fpn)
|
||||
head_without_fpn = DensePoseHead(config_without_fpn)
|
||||
|
||||
# Act
|
||||
output_with_fpn = head_with_fpn(mock_feature_input)
|
||||
output_without_fpn = head_without_fpn(mock_feature_input)
|
||||
|
||||
# Assert
|
||||
assert output_with_fpn['segmentation'].shape == output_without_fpn['segmentation'].shape
|
||||
assert output_with_fpn['uv_coordinates'].shape == output_without_fpn['uv_coordinates'].shape
|
||||
# Outputs should be different due to FPN
|
||||
assert not torch.allclose(output_with_fpn['segmentation'], output_without_fpn['segmentation'], atol=1e-6)
|
||||
|
||||
def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
|
||||
"""Test that get_prediction_confidence provides uncertainty estimates"""
|
||||
# Arrange
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Act
|
||||
confidence = densepose_head.get_prediction_confidence(output)
|
||||
|
||||
# Assert
|
||||
assert confidence is not None
|
||||
assert isinstance(confidence, dict)
|
||||
assert 'segmentation_confidence' in confidence
|
||||
assert 'uv_confidence' in confidence
|
||||
|
||||
seg_conf = confidence['segmentation_confidence']
|
||||
uv_conf = confidence['uv_confidence']
|
||||
|
||||
assert isinstance(seg_conf, torch.Tensor)
|
||||
assert isinstance(uv_conf, torch.Tensor)
|
||||
assert seg_conf.shape[0] == mock_feature_input.shape[0]
|
||||
assert uv_conf.shape[0] == mock_feature_input.shape[0]
|
||||
|
||||
def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
|
||||
"""Test that post_process_predictions formats output correctly"""
|
||||
# Arrange
|
||||
raw_output = densepose_head(mock_feature_input)
|
||||
|
||||
# Act
|
||||
processed = densepose_head.post_process_predictions(raw_output)
|
||||
|
||||
# Assert
|
||||
assert processed is not None
|
||||
assert isinstance(processed, dict)
|
||||
assert 'body_parts' in processed
|
||||
assert 'uv_coordinates' in processed
|
||||
assert 'confidence_scores' in processed
|
||||
|
||||
def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
|
||||
"""Test that training mode enables dropout for regularization"""
|
||||
# Arrange
|
||||
densepose_head.train()
|
||||
|
||||
# Act
|
||||
output1 = densepose_head(mock_feature_input)
|
||||
output2 = densepose_head(mock_feature_input)
|
||||
|
||||
# Assert - outputs should be different due to dropout
|
||||
assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
|
||||
assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
|
||||
|
||||
def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
|
||||
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||
# Arrange
|
||||
densepose_head.eval()
|
||||
|
||||
# Act
|
||||
output1 = densepose_head(mock_feature_input)
|
||||
output2 = densepose_head(mock_feature_input)
|
||||
|
||||
# Assert - outputs should be identical in eval mode
|
||||
assert torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
|
||||
assert torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
|
||||
|
||||
def test_head_validates_input_dimensions(self, densepose_head):
|
||||
"""Test that head validates input dimensions"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(2, 128, 56, 56) # Wrong number of channels
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DensePoseError):
|
||||
densepose_head(invalid_input)
|
||||
|
||||
def test_head_handles_different_input_sizes(self, densepose_head):
|
||||
"""Test that head handles different input sizes"""
|
||||
# Arrange
|
||||
small_input = torch.randn(1, 256, 28, 28)
|
||||
large_input = torch.randn(1, 256, 112, 112)
|
||||
|
||||
# Act
|
||||
small_output = densepose_head(small_input)
|
||||
large_output = densepose_head(large_input)
|
||||
|
||||
# Assert
|
||||
assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:]
|
||||
assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:]
|
||||
|
||||
def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
|
||||
"""Test that head supports gradient computation for training"""
|
||||
# Arrange
|
||||
densepose_head.train()
|
||||
optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)
|
||||
|
||||
output = densepose_head(mock_feature_input)
|
||||
|
||||
# Resize targets
|
||||
seg_target = torch.nn.functional.interpolate(
|
||||
mock_target_masks.float().unsqueeze(1),
|
||||
size=output['segmentation'].shape[2:],
|
||||
mode='nearest'
|
||||
).squeeze(1).long()
|
||||
|
||||
uv_target = torch.nn.functional.interpolate(
|
||||
mock_target_uv,
|
||||
size=output['uv_coordinates'].shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Act
|
||||
loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Assert
|
||||
for param in densepose_head.parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||
|
||||
def test_head_configuration_validation(self):
|
||||
"""Test that head validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'input_channels': 0, # Invalid
|
||||
'num_body_parts': -1, # Invalid
|
||||
'num_uv_coordinates': 2
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
DensePoseHead(invalid_config)
|
||||
|
||||
def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
|
||||
"""Test that model state can be saved and loaded"""
|
||||
# Arrange
|
||||
original_output = densepose_head(mock_feature_input)
|
||||
|
||||
# Act - Save state
|
||||
state_dict = densepose_head.state_dict()
|
||||
|
||||
# Create new head and load state
|
||||
new_head = DensePoseHead(densepose_head.config)
|
||||
new_head.load_state_dict(state_dict)
|
||||
new_output = new_head(mock_feature_input)
|
||||
|
||||
# Assert
|
||||
assert torch.allclose(original_output['segmentation'], new_output['segmentation'], atol=1e-6)
|
||||
assert torch.allclose(original_output['uv_coordinates'], new_output['uv_coordinates'], atol=1e-6)
|
||||
293
v1/tests/unit/test_modality_translation.py
Normal file
293
v1/tests/unit/test_modality_translation.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
|
||||
|
||||
|
||||
class TestModalityTranslationNetwork:
|
||||
"""Test suite for Modality Translation Network following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for modality translation network"""
|
||||
return {
|
||||
'input_channels': 6, # Real and imaginary parts for 3 antennas
|
||||
'hidden_channels': [64, 128, 256],
|
||||
'output_channels': 256,
|
||||
'kernel_size': 3,
|
||||
'stride': 1,
|
||||
'padding': 1,
|
||||
'dropout_rate': 0.1,
|
||||
'activation': 'relu',
|
||||
'normalization': 'batch',
|
||||
'use_attention': True,
|
||||
'attention_heads': 8
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def translation_network(self, mock_config):
|
||||
"""Create modality translation network instance for testing"""
|
||||
return ModalityTranslationNetwork(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_input(self):
|
||||
"""Generate mock CSI input tensor"""
|
||||
batch_size = 4
|
||||
channels = 6 # Real and imaginary parts for 3 antennas
|
||||
height = 56 # Number of subcarriers
|
||||
width = 100 # Time samples
|
||||
return torch.randn(batch_size, channels, height, width)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_target_features(self):
|
||||
"""Generate mock target feature tensor for training"""
|
||||
batch_size = 4
|
||||
feature_dim = 256
|
||||
spatial_height = 56
|
||||
spatial_width = 100
|
||||
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
|
||||
|
||||
def test_network_initialization_creates_correct_architecture(self, mock_config):
|
||||
"""Test that modality translation network initializes with correct architecture"""
|
||||
# Act
|
||||
network = ModalityTranslationNetwork(mock_config)
|
||||
|
||||
# Assert
|
||||
assert network is not None
|
||||
assert isinstance(network, nn.Module)
|
||||
assert network.input_channels == mock_config['input_channels']
|
||||
assert network.output_channels == mock_config['output_channels']
|
||||
assert network.use_attention == mock_config['use_attention']
|
||||
assert hasattr(network, 'encoder')
|
||||
assert hasattr(network, 'decoder')
|
||||
if mock_config['use_attention']:
|
||||
assert hasattr(network, 'attention')
|
||||
|
||||
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
|
||||
"""Test that forward pass produces correctly shaped output"""
|
||||
# Act
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert output is not None
|
||||
assert isinstance(output, torch.Tensor)
|
||||
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
|
||||
assert output.shape[1] == translation_network.output_channels # Correct output channels
|
||||
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
|
||||
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
|
||||
|
||||
def test_forward_pass_handles_different_input_sizes(self, translation_network):
|
||||
"""Test that forward pass handles different input sizes"""
|
||||
# Arrange
|
||||
small_input = torch.randn(2, 6, 28, 50)
|
||||
large_input = torch.randn(8, 6, 112, 200)
|
||||
|
||||
# Act
|
||||
small_output = translation_network(small_input)
|
||||
large_output = translation_network(large_input)
|
||||
|
||||
# Assert
|
||||
assert small_output.shape == (2, 256, 28, 50)
|
||||
assert large_output.shape == (8, 256, 112, 200)
|
||||
|
||||
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
|
||||
"""Test that encoder extracts hierarchical features"""
|
||||
# Act
|
||||
features = translation_network.encode(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, list)
|
||||
assert len(features) == len(translation_network.encoder)
|
||||
|
||||
# Check feature map sizes decrease with depth
|
||||
for i in range(1, len(features)):
|
||||
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
|
||||
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
|
||||
|
||||
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
|
||||
"""Test that decoder reconstructs target feature representation"""
|
||||
# Arrange
|
||||
encoded_features = translation_network.encode(mock_csi_input)
|
||||
|
||||
# Act
|
||||
decoded_output = translation_network.decode(encoded_features)
|
||||
|
||||
# Assert
|
||||
assert decoded_output is not None
|
||||
assert isinstance(decoded_output, torch.Tensor)
|
||||
assert decoded_output.shape[1] == translation_network.output_channels
|
||||
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
|
||||
|
||||
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
|
||||
"""Test that attention mechanism enhances feature representation"""
|
||||
# Arrange
|
||||
config_with_attention = mock_config.copy()
|
||||
config_with_attention['use_attention'] = True
|
||||
|
||||
config_without_attention = mock_config.copy()
|
||||
config_without_attention['use_attention'] = False
|
||||
|
||||
network_with_attention = ModalityTranslationNetwork(config_with_attention)
|
||||
network_without_attention = ModalityTranslationNetwork(config_without_attention)
|
||||
|
||||
# Act
|
||||
output_with_attention = network_with_attention(mock_csi_input)
|
||||
output_without_attention = network_without_attention(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert output_with_attention.shape == output_without_attention.shape
|
||||
# Outputs should be different due to attention mechanism
|
||||
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
|
||||
|
||||
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
|
||||
"""Test that training mode enables dropout for regularization"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
|
||||
# Act
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - outputs should be different due to dropout
|
||||
assert not torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
|
||||
"""Test that evaluation mode disables dropout for consistent inference"""
|
||||
# Arrange
|
||||
translation_network.eval()
|
||||
|
||||
# Act
|
||||
output1 = translation_network(mock_csi_input)
|
||||
output2 = translation_network(mock_csi_input)
|
||||
|
||||
# Assert - outputs should be identical in eval mode
|
||||
assert torch.allclose(output1, output2, atol=1e-6)
|
||||
|
||||
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that compute_translation_loss measures feature alignment"""
|
||||
# Arrange
|
||||
predicted_features = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
|
||||
|
||||
# Assert
|
||||
assert loss is not None
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.dim() == 0 # Scalar loss
|
||||
assert loss.item() >= 0 # Loss should be non-negative
|
||||
|
||||
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that compute_translation_loss handles different loss types"""
|
||||
# Arrange
|
||||
predicted_features = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
|
||||
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
|
||||
|
||||
# Assert
|
||||
assert mse_loss is not None
|
||||
assert l1_loss is not None
|
||||
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
|
||||
|
||||
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
|
||||
"""Test that get_feature_statistics provides feature analysis"""
|
||||
# Arrange
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
stats = translation_network.get_feature_statistics(output)
|
||||
|
||||
# Assert
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
assert 'mean' in stats
|
||||
assert 'std' in stats
|
||||
assert 'min' in stats
|
||||
assert 'max' in stats
|
||||
assert 'sparsity' in stats
|
||||
|
||||
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that network supports gradient computation for training"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
|
||||
|
||||
# Act
|
||||
output = translation_network(mock_csi_input)
|
||||
loss = translation_network.compute_translation_loss(output, mock_target_features)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Assert
|
||||
for param in translation_network.parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||
|
||||
def test_network_validates_input_dimensions(self, translation_network):
|
||||
"""Test that network validates input dimensions"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ModalityTranslationError):
|
||||
translation_network(invalid_input)
|
||||
|
||||
def test_network_handles_batch_size_one(self, translation_network):
|
||||
"""Test that network handles single sample inference"""
|
||||
# Arrange
|
||||
single_input = torch.randn(1, 6, 56, 100)
|
||||
|
||||
# Act
|
||||
output = translation_network(single_input)
|
||||
|
||||
# Assert
|
||||
assert output.shape == (1, 256, 56, 100)
|
||||
|
||||
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
|
||||
"""Test that model state can be saved and loaded"""
|
||||
# Arrange
|
||||
original_output = translation_network(mock_csi_input)
|
||||
|
||||
# Act - Save state
|
||||
state_dict = translation_network.state_dict()
|
||||
|
||||
# Create new network and load state
|
||||
new_network = ModalityTranslationNetwork(translation_network.config)
|
||||
new_network.load_state_dict(state_dict)
|
||||
new_output = new_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert torch.allclose(original_output, new_output, atol=1e-6)
|
||||
|
||||
def test_network_configuration_validation(self):
|
||||
"""Test that network validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'input_channels': 0, # Invalid
|
||||
'hidden_channels': [], # Invalid
|
||||
'output_channels': 256
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
ModalityTranslationNetwork(invalid_config)
|
||||
|
||||
def test_feature_visualization_support(self, translation_network, mock_csi_input):
|
||||
"""Test that network supports feature visualization"""
|
||||
# Act
|
||||
features = translation_network.get_intermediate_features(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, dict)
|
||||
assert 'encoder_features' in features
|
||||
assert 'decoder_features' in features
|
||||
if translation_network.use_attention:
|
||||
assert 'attention_weights' in features
|
||||
107
v1/tests/unit/test_phase_sanitizer.py
Normal file
107
v1/tests/unit/test_phase_sanitizer.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
|
||||
|
||||
class TestPhaseSanitizer:
|
||||
"""Test suite for Phase Sanitizer following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_phase_data(self):
|
||||
"""Generate synthetic phase data for testing"""
|
||||
# Phase data with unwrapping issues and outliers
|
||||
return np.array([
|
||||
[0.1, 0.2, 6.0, 0.4, 0.5], # Contains phase jump at index 2
|
||||
[-3.0, -0.1, 0.0, 0.1, 0.2], # Contains wrapped phase at index 0
|
||||
[0.0, 0.1, 0.2, 0.3, 0.4] # Clean phase data
|
||||
])
|
||||
|
||||
@pytest.fixture
|
||||
def phase_sanitizer(self):
|
||||
"""Create Phase Sanitizer instance for testing"""
|
||||
return PhaseSanitizer()
|
||||
|
||||
def test_unwrap_phase_removes_discontinuities(self, phase_sanitizer, mock_phase_data):
|
||||
"""Test that phase unwrapping removes 2π discontinuities"""
|
||||
# Act
|
||||
result = phase_sanitizer.unwrap_phase(mock_phase_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == mock_phase_data.shape
|
||||
|
||||
# Check that large jumps are reduced
|
||||
for i in range(result.shape[0]):
|
||||
phase_diffs = np.abs(np.diff(result[i]))
|
||||
assert np.all(phase_diffs < np.pi) # No jumps larger than π
|
||||
|
||||
def test_remove_outliers_filters_anomalous_values(self, phase_sanitizer, mock_phase_data):
|
||||
"""Test that outlier removal filters anomalous phase values"""
|
||||
# Arrange - Add clear outliers
|
||||
outlier_data = mock_phase_data.copy()
|
||||
outlier_data[0, 2] = 100.0 # Clear outlier
|
||||
|
||||
# Act
|
||||
result = phase_sanitizer.remove_outliers(outlier_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == outlier_data.shape
|
||||
assert np.abs(result[0, 2]) < 10.0 # Outlier should be corrected
|
||||
|
||||
def test_smooth_phase_reduces_noise(self, phase_sanitizer, mock_phase_data):
|
||||
"""Test that phase smoothing reduces noise while preserving trends"""
|
||||
# Arrange - Add noise
|
||||
noisy_data = mock_phase_data + np.random.normal(0, 0.1, mock_phase_data.shape)
|
||||
|
||||
# Act
|
||||
result = phase_sanitizer.smooth_phase(noisy_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == noisy_data.shape
|
||||
|
||||
# Smoothed data should have lower variance
|
||||
original_variance = np.var(noisy_data)
|
||||
smoothed_variance = np.var(result)
|
||||
assert smoothed_variance <= original_variance
|
||||
|
||||
def test_sanitize_handles_empty_input(self, phase_sanitizer):
|
||||
"""Test that sanitizer handles empty input gracefully"""
|
||||
# Arrange
|
||||
empty_data = np.array([])
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Phase data cannot be empty"):
|
||||
phase_sanitizer.sanitize(empty_data)
|
||||
|
||||
def test_sanitize_full_pipeline_integration(self, phase_sanitizer, mock_phase_data):
|
||||
"""Test that full sanitization pipeline works correctly"""
|
||||
# Act
|
||||
result = phase_sanitizer.sanitize(mock_phase_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == mock_phase_data.shape
|
||||
|
||||
# Result should be within reasonable phase bounds
|
||||
assert np.all(result >= -2*np.pi)
|
||||
assert np.all(result <= 2*np.pi)
|
||||
|
||||
def test_sanitize_performance_requirement(self, phase_sanitizer, mock_phase_data):
|
||||
"""Test that phase sanitization meets performance requirements (<5ms)"""
|
||||
import time
|
||||
|
||||
# Act
|
||||
start_time = time.time()
|
||||
result = phase_sanitizer.sanitize(mock_phase_data)
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# Assert
|
||||
assert processing_time < 0.005 # <5ms requirement
|
||||
assert result is not None
|
||||
407
v1/tests/unit/test_phase_sanitizer_tdd.py
Normal file
407
v1/tests/unit/test_phase_sanitizer_tdd.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""TDD tests for phase sanitizer following London School approach."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
|
||||
# Import the phase sanitizer module directly
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'phase_sanitizer',
|
||||
'/workspaces/wifi-densepose/src/core/phase_sanitizer.py'
|
||||
)
|
||||
phase_sanitizer_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(phase_sanitizer_module)
|
||||
|
||||
# Get classes from the module
|
||||
PhaseSanitizer = phase_sanitizer_module.PhaseSanitizer
|
||||
PhaseSanitizationError = phase_sanitizer_module.PhaseSanitizationError
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestPhaseSanitizer:
|
||||
"""Test phase sanitizer using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def sanitizer_config(self):
|
||||
"""Phase sanitizer configuration for testing."""
|
||||
return {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 5,
|
||||
'enable_outlier_removal': True,
|
||||
'enable_smoothing': True,
|
||||
'enable_noise_filtering': True,
|
||||
'noise_threshold': 0.1,
|
||||
'phase_range': (-np.pi, np.pi)
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def phase_sanitizer(self, sanitizer_config, mock_logger):
|
||||
"""Create phase sanitizer for testing."""
|
||||
return PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_wrapped_phase(self):
|
||||
"""Sample wrapped phase data with discontinuities."""
|
||||
# Create phase data with wrapping
|
||||
phase = np.linspace(0, 4*np.pi, 100)
|
||||
wrapped_phase = np.angle(np.exp(1j * phase)) # Wrap to [-π, π]
|
||||
return wrapped_phase.reshape(1, -1) # Shape: (1, 100)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_noisy_phase(self):
|
||||
"""Sample phase data with noise and outliers."""
|
||||
clean_phase = np.linspace(-np.pi, np.pi, 50)
|
||||
noise = np.random.normal(0, 0.05, 50)
|
||||
# Add some outliers
|
||||
outliers = np.random.choice(50, 5, replace=False)
|
||||
noisy_phase = clean_phase + noise
|
||||
noisy_phase[outliers] += np.random.uniform(-2, 2, 5) # Add outliers
|
||||
return noisy_phase.reshape(1, -1)
|
||||
|
||||
# Initialization tests
|
||||
def test_should_initialize_with_valid_config(self, sanitizer_config, mock_logger):
|
||||
"""Should initialize phase sanitizer with valid configuration."""
|
||||
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
|
||||
|
||||
assert sanitizer.config == sanitizer_config
|
||||
assert sanitizer.logger == mock_logger
|
||||
assert sanitizer.unwrapping_method == 'numpy'
|
||||
assert sanitizer.outlier_threshold == 3.0
|
||||
assert sanitizer.smoothing_window == 5
|
||||
assert sanitizer.enable_outlier_removal == True
|
||||
assert sanitizer.enable_smoothing == True
|
||||
assert sanitizer.enable_noise_filtering == True
|
||||
assert sanitizer.noise_threshold == 0.1
|
||||
assert sanitizer.phase_range == (-np.pi, np.pi)
|
||||
|
||||
def test_should_raise_error_with_invalid_config(self, mock_logger):
|
||||
"""Should raise error when initialized with invalid configuration."""
|
||||
invalid_config = {'invalid': 'config'}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
PhaseSanitizer(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_validate_required_fields(self, mock_logger):
|
||||
"""Should validate required configuration fields."""
|
||||
required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
|
||||
base_config = {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 5
|
||||
}
|
||||
|
||||
for field in required_fields:
|
||||
config = base_config.copy()
|
||||
del config[field]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
PhaseSanitizer(config=config, logger=mock_logger)
|
||||
|
||||
def test_should_use_default_values(self, mock_logger):
|
||||
"""Should use default values for optional parameters."""
|
||||
minimal_config = {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 5
|
||||
}
|
||||
|
||||
sanitizer = PhaseSanitizer(config=minimal_config, logger=mock_logger)
|
||||
|
||||
assert sanitizer.enable_outlier_removal == True # default
|
||||
assert sanitizer.enable_smoothing == True # default
|
||||
assert sanitizer.enable_noise_filtering == False # default
|
||||
assert sanitizer.noise_threshold == 0.05 # default
|
||||
assert sanitizer.phase_range == (-np.pi, np.pi) # default
|
||||
|
||||
def test_should_initialize_without_logger(self, sanitizer_config):
|
||||
"""Should initialize without logger provided."""
|
||||
sanitizer = PhaseSanitizer(config=sanitizer_config)
|
||||
|
||||
assert sanitizer.logger is not None # Should create default logger
|
||||
|
||||
# Phase unwrapping tests
|
||||
def test_should_unwrap_phase_successfully(self, phase_sanitizer, sample_wrapped_phase):
|
||||
"""Should unwrap phase data successfully."""
|
||||
result = phase_sanitizer.unwrap_phase(sample_wrapped_phase)
|
||||
|
||||
# Check that result has same shape
|
||||
assert result.shape == sample_wrapped_phase.shape
|
||||
|
||||
# Check that unwrapping removed discontinuities
|
||||
phase_diff = np.diff(result.flatten())
|
||||
large_jumps = np.abs(phase_diff) > np.pi
|
||||
assert np.sum(large_jumps) < np.sum(np.abs(np.diff(sample_wrapped_phase.flatten())) > np.pi)
|
||||
|
||||
def test_should_handle_different_unwrapping_methods(self, sanitizer_config, mock_logger):
|
||||
"""Should handle different unwrapping methods."""
|
||||
for method in ['numpy', 'scipy', 'custom']:
|
||||
sanitizer_config['unwrapping_method'] = method
|
||||
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
|
||||
|
||||
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
|
||||
|
||||
with patch.object(sanitizer, f'_unwrap_{method}', return_value=phase_data) as mock_unwrap:
|
||||
result = sanitizer.unwrap_phase(phase_data)
|
||||
|
||||
assert result.shape == phase_data.shape
|
||||
mock_unwrap.assert_called_once()
|
||||
|
||||
def test_should_handle_unwrapping_error(self, phase_sanitizer):
|
||||
"""Should handle phase unwrapping errors gracefully."""
|
||||
invalid_phase = np.array([[]]) # Empty array
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Failed to unwrap phase"):
|
||||
phase_sanitizer.unwrap_phase(invalid_phase)
|
||||
|
||||
# Outlier removal tests
|
||||
def test_should_remove_outliers_successfully(self, phase_sanitizer, sample_noisy_phase):
|
||||
"""Should remove outliers from phase data successfully."""
|
||||
with patch.object(phase_sanitizer, '_detect_outliers') as mock_detect:
|
||||
with patch.object(phase_sanitizer, '_interpolate_outliers') as mock_interpolate:
|
||||
outlier_mask = np.zeros(sample_noisy_phase.shape, dtype=bool)
|
||||
outlier_mask[0, [10, 20, 30]] = True # Mark some outliers
|
||||
clean_phase = sample_noisy_phase.copy()
|
||||
|
||||
mock_detect.return_value = outlier_mask
|
||||
mock_interpolate.return_value = clean_phase
|
||||
|
||||
result = phase_sanitizer.remove_outliers(sample_noisy_phase)
|
||||
|
||||
assert result.shape == sample_noisy_phase.shape
|
||||
mock_detect.assert_called_once_with(sample_noisy_phase)
|
||||
mock_interpolate.assert_called_once()
|
||||
|
||||
def test_should_skip_outlier_removal_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
|
||||
"""Should skip outlier removal when disabled."""
|
||||
sanitizer_config['enable_outlier_removal'] = False
|
||||
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
|
||||
|
||||
result = sanitizer.remove_outliers(sample_noisy_phase)
|
||||
|
||||
assert np.array_equal(result, sample_noisy_phase)
|
||||
|
||||
def test_should_handle_outlier_removal_error(self, phase_sanitizer):
|
||||
"""Should handle outlier removal errors gracefully."""
|
||||
with patch.object(phase_sanitizer, '_detect_outliers') as mock_detect:
|
||||
mock_detect.side_effect = Exception("Detection error")
|
||||
|
||||
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Failed to remove outliers"):
|
||||
phase_sanitizer.remove_outliers(phase_data)
|
||||
|
||||
# Smoothing tests
|
||||
def test_should_smooth_phase_successfully(self, phase_sanitizer, sample_noisy_phase):
|
||||
"""Should smooth phase data successfully."""
|
||||
with patch.object(phase_sanitizer, '_apply_moving_average') as mock_smooth:
|
||||
smoothed_phase = sample_noisy_phase * 0.9 # Simulate smoothing
|
||||
mock_smooth.return_value = smoothed_phase
|
||||
|
||||
result = phase_sanitizer.smooth_phase(sample_noisy_phase)
|
||||
|
||||
assert result.shape == sample_noisy_phase.shape
|
||||
mock_smooth.assert_called_once_with(sample_noisy_phase, phase_sanitizer.smoothing_window)
|
||||
|
||||
def test_should_skip_smoothing_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
|
||||
"""Should skip smoothing when disabled."""
|
||||
sanitizer_config['enable_smoothing'] = False
|
||||
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
|
||||
|
||||
result = sanitizer.smooth_phase(sample_noisy_phase)
|
||||
|
||||
assert np.array_equal(result, sample_noisy_phase)
|
||||
|
||||
def test_should_handle_smoothing_error(self, phase_sanitizer):
|
||||
"""Should handle smoothing errors gracefully."""
|
||||
with patch.object(phase_sanitizer, '_apply_moving_average') as mock_smooth:
|
||||
mock_smooth.side_effect = Exception("Smoothing error")
|
||||
|
||||
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Failed to smooth phase"):
|
||||
phase_sanitizer.smooth_phase(phase_data)
|
||||
|
||||
# Noise filtering tests
|
||||
def test_should_filter_noise_successfully(self, phase_sanitizer, sample_noisy_phase):
|
||||
"""Should filter noise from phase data successfully."""
|
||||
with patch.object(phase_sanitizer, '_apply_low_pass_filter') as mock_filter:
|
||||
filtered_phase = sample_noisy_phase * 0.95 # Simulate filtering
|
||||
mock_filter.return_value = filtered_phase
|
||||
|
||||
result = phase_sanitizer.filter_noise(sample_noisy_phase)
|
||||
|
||||
assert result.shape == sample_noisy_phase.shape
|
||||
mock_filter.assert_called_once_with(sample_noisy_phase, phase_sanitizer.noise_threshold)
|
||||
|
||||
def test_should_skip_noise_filtering_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
|
||||
"""Should skip noise filtering when disabled."""
|
||||
sanitizer_config['enable_noise_filtering'] = False
|
||||
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
|
||||
|
||||
result = sanitizer.filter_noise(sample_noisy_phase)
|
||||
|
||||
assert np.array_equal(result, sample_noisy_phase)
|
||||
|
||||
def test_should_handle_noise_filtering_error(self, phase_sanitizer):
|
||||
"""Should handle noise filtering errors gracefully."""
|
||||
with patch.object(phase_sanitizer, '_apply_low_pass_filter') as mock_filter:
|
||||
mock_filter.side_effect = Exception("Filtering error")
|
||||
|
||||
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Failed to filter noise"):
|
||||
phase_sanitizer.filter_noise(phase_data)
|
||||
|
||||
# Complete sanitization pipeline tests
|
||||
def test_should_sanitize_phase_pipeline_successfully(self, phase_sanitizer, sample_wrapped_phase):
|
||||
"""Should sanitize phase through complete pipeline successfully."""
|
||||
with patch.object(phase_sanitizer, 'unwrap_phase', return_value=sample_wrapped_phase) as mock_unwrap:
|
||||
with patch.object(phase_sanitizer, 'remove_outliers', return_value=sample_wrapped_phase) as mock_outliers:
|
||||
with patch.object(phase_sanitizer, 'smooth_phase', return_value=sample_wrapped_phase) as mock_smooth:
|
||||
with patch.object(phase_sanitizer, 'filter_noise', return_value=sample_wrapped_phase) as mock_filter:
|
||||
|
||||
result = phase_sanitizer.sanitize_phase(sample_wrapped_phase)
|
||||
|
||||
assert result.shape == sample_wrapped_phase.shape
|
||||
mock_unwrap.assert_called_once_with(sample_wrapped_phase)
|
||||
mock_outliers.assert_called_once()
|
||||
mock_smooth.assert_called_once()
|
||||
mock_filter.assert_called_once()
|
||||
|
||||
def test_should_handle_sanitization_pipeline_error(self, phase_sanitizer, sample_wrapped_phase):
|
||||
"""Should handle sanitization pipeline errors gracefully."""
|
||||
with patch.object(phase_sanitizer, 'unwrap_phase') as mock_unwrap:
|
||||
mock_unwrap.side_effect = PhaseSanitizationError("Unwrapping failed")
|
||||
|
||||
with pytest.raises(PhaseSanitizationError):
|
||||
phase_sanitizer.sanitize_phase(sample_wrapped_phase)
|
||||
|
||||
# Phase validation tests
|
||||
def test_should_validate_phase_data_successfully(self, phase_sanitizer):
|
||||
"""Should validate phase data successfully."""
|
||||
valid_phase = np.random.uniform(-np.pi, np.pi, (3, 56))
|
||||
|
||||
result = phase_sanitizer.validate_phase_data(valid_phase)
|
||||
|
||||
assert result == True
|
||||
|
||||
def test_should_reject_invalid_phase_shape(self, phase_sanitizer):
|
||||
"""Should reject phase data with invalid shape."""
|
||||
invalid_phase = np.array([1, 2, 3]) # 1D array
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Phase data must be 2D"):
|
||||
phase_sanitizer.validate_phase_data(invalid_phase)
|
||||
|
||||
def test_should_reject_empty_phase_data(self, phase_sanitizer):
|
||||
"""Should reject empty phase data."""
|
||||
empty_phase = np.array([]).reshape(0, 0)
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Phase data cannot be empty"):
|
||||
phase_sanitizer.validate_phase_data(empty_phase)
|
||||
|
||||
def test_should_reject_phase_out_of_range(self, phase_sanitizer):
|
||||
"""Should reject phase data outside valid range."""
|
||||
invalid_phase = np.array([[10.0, -10.0, 5.0, -5.0]]) # Outside [-π, π]
|
||||
|
||||
with pytest.raises(PhaseSanitizationError, match="Phase values outside valid range"):
|
||||
phase_sanitizer.validate_phase_data(invalid_phase)
|
||||
|
||||
# Statistics and monitoring tests
|
||||
def test_should_get_sanitization_statistics(self, phase_sanitizer):
|
||||
"""Should get sanitization statistics."""
|
||||
# Simulate some processing
|
||||
phase_sanitizer._total_processed = 50
|
||||
phase_sanitizer._outliers_removed = 5
|
||||
phase_sanitizer._sanitization_errors = 2
|
||||
|
||||
stats = phase_sanitizer.get_sanitization_statistics()
|
||||
|
||||
assert isinstance(stats, dict)
|
||||
assert stats['total_processed'] == 50
|
||||
assert stats['outliers_removed'] == 5
|
||||
assert stats['sanitization_errors'] == 2
|
||||
assert stats['outlier_rate'] == 0.1
|
||||
assert stats['error_rate'] == 0.04
|
||||
|
||||
def test_should_reset_statistics(self, phase_sanitizer):
|
||||
"""Should reset sanitization statistics."""
|
||||
phase_sanitizer._total_processed = 50
|
||||
phase_sanitizer._outliers_removed = 5
|
||||
phase_sanitizer._sanitization_errors = 2
|
||||
|
||||
phase_sanitizer.reset_statistics()
|
||||
|
||||
assert phase_sanitizer._total_processed == 0
|
||||
assert phase_sanitizer._outliers_removed == 0
|
||||
assert phase_sanitizer._sanitization_errors == 0
|
||||
|
||||
# Configuration validation tests
|
||||
def test_should_validate_unwrapping_method(self, mock_logger):
|
||||
"""Should validate unwrapping method."""
|
||||
invalid_config = {
|
||||
'unwrapping_method': 'invalid_method',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 5
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid unwrapping method"):
|
||||
PhaseSanitizer(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_validate_outlier_threshold(self, mock_logger):
|
||||
"""Should validate outlier threshold."""
|
||||
invalid_config = {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': -1.0, # Negative threshold
|
||||
'smoothing_window': 5
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="outlier_threshold must be positive"):
|
||||
PhaseSanitizer(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_validate_smoothing_window(self, mock_logger):
|
||||
"""Should validate smoothing window."""
|
||||
invalid_config = {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 0 # Invalid window size
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="smoothing_window must be positive"):
|
||||
PhaseSanitizer(config=invalid_config, logger=mock_logger)
|
||||
|
||||
# Edge case tests
|
||||
def test_should_handle_single_antenna_data(self, phase_sanitizer):
|
||||
"""Should handle single antenna phase data."""
|
||||
single_antenna_phase = np.random.uniform(-np.pi, np.pi, (1, 56))
|
||||
|
||||
result = phase_sanitizer.sanitize_phase(single_antenna_phase)
|
||||
|
||||
assert result.shape == single_antenna_phase.shape
|
||||
|
||||
def test_should_handle_small_phase_arrays(self, phase_sanitizer):
|
||||
"""Should handle small phase arrays."""
|
||||
small_phase = np.random.uniform(-np.pi, np.pi, (2, 5))
|
||||
|
||||
result = phase_sanitizer.sanitize_phase(small_phase)
|
||||
|
||||
assert result.shape == small_phase.shape
|
||||
|
||||
def test_should_handle_constant_phase_data(self, phase_sanitizer):
|
||||
"""Should handle constant phase data."""
|
||||
constant_phase = np.full((3, 20), 0.5)
|
||||
|
||||
result = phase_sanitizer.sanitize_phase(constant_phase)
|
||||
|
||||
assert result.shape == constant_phase.shape
|
||||
244
v1/tests/unit/test_router_interface.py
Normal file
244
v1/tests/unit/test_router_interface.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.hardware.router_interface import RouterInterface, RouterConnectionError
|
||||
|
||||
|
||||
class TestRouterInterface:
|
||||
"""Test suite for Router Interface following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for router interface"""
|
||||
return {
|
||||
'router_ip': '192.168.1.1',
|
||||
'username': 'admin',
|
||||
'password': 'password',
|
||||
'ssh_port': 22,
|
||||
'timeout': 30,
|
||||
'max_retries': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_interface(self, mock_config):
|
||||
"""Create router interface instance for testing"""
|
||||
return RouterInterface(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssh_client(self):
|
||||
"""Mock SSH client for testing"""
|
||||
mock_client = Mock()
|
||||
mock_client.connect = Mock()
|
||||
mock_client.exec_command = Mock()
|
||||
mock_client.close = Mock()
|
||||
return mock_client
|
||||
|
||||
def test_interface_initialization_creates_correct_configuration(self, mock_config):
|
||||
"""Test that router interface initializes with correct configuration"""
|
||||
# Act
|
||||
interface = RouterInterface(mock_config)
|
||||
|
||||
# Assert
|
||||
assert interface is not None
|
||||
assert interface.router_ip == mock_config['router_ip']
|
||||
assert interface.username == mock_config['username']
|
||||
assert interface.password == mock_config['password']
|
||||
assert interface.ssh_port == mock_config['ssh_port']
|
||||
assert interface.timeout == mock_config['timeout']
|
||||
assert interface.max_retries == mock_config['max_retries']
|
||||
assert not interface.is_connected
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that connect method establishes SSH connection"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
|
||||
# Act
|
||||
result = router_interface.connect()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert router_interface.is_connected is True
|
||||
mock_ssh_client.set_missing_host_key_policy.assert_called_once()
|
||||
mock_ssh_client.connect.assert_called_once_with(
|
||||
hostname=router_interface.router_ip,
|
||||
port=router_interface.ssh_port,
|
||||
username=router_interface.username,
|
||||
password=router_interface.password,
|
||||
timeout=router_interface.timeout
|
||||
)
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that connect method handles connection failures gracefully"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_ssh_client.connect.side_effect = Exception("Connection failed")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.connect()
|
||||
|
||||
assert router_interface.is_connected is False
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that disconnect method closes SSH connection"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
router_interface.disconnect()
|
||||
|
||||
# Assert
|
||||
assert router_interface.is_connected is False
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that execute_command runs SSH commands correctly"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"command output"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.execute_command("test command")
|
||||
|
||||
# Assert
|
||||
assert result == "command output"
|
||||
mock_ssh_client.exec_command.assert_called_with("test command")
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that execute_command handles command errors"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b""
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b"command error"
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.execute_command("failing command")
|
||||
|
||||
def test_execute_command_requires_connection(self, router_interface):
|
||||
"""Test that execute_command requires active connection"""
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.execute_command("test command")
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that get_router_info retrieves router system information"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Router Model: AC1900\nFirmware: 1.2.3"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
info = router_interface.get_router_info()
|
||||
|
||||
# Assert
|
||||
assert info is not None
|
||||
assert isinstance(info, dict)
|
||||
assert 'model' in info
|
||||
assert 'firmware' in info
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that enable_monitor_mode configures WiFi monitoring"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Monitor mode enabled"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.enable_monitor_mode("wlan0")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_ssh_client.exec_command.assert_called()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that disable_monitor_mode disables WiFi monitoring"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Monitor mode disabled"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.disable_monitor_mode("wlan0")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_ssh_client.exec_command.assert_called()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that router interface supports context manager protocol"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
|
||||
# Act
|
||||
with router_interface as interface:
|
||||
# Assert
|
||||
assert interface.is_connected is True
|
||||
|
||||
# Assert - connection should be closed after context
|
||||
assert router_interface.is_connected is False
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
def test_interface_validates_configuration(self):
|
||||
"""Test that router interface validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'router_ip': '', # Invalid IP
|
||||
'username': 'admin',
|
||||
'password': 'password'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
RouterInterface(invalid_config)
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that interface implements retry logic for failed operations"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None] # Fail once, then succeed
|
||||
|
||||
# Act
|
||||
result = router_interface.connect()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert mock_ssh_client.connect.call_count == 2 # Should retry once
|
||||
410
v1/tests/unit/test_router_interface_tdd.py
Normal file
410
v1/tests/unit/test_router_interface_tdd.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""TDD tests for router interface following London School approach."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
|
||||
# Import the router interface module directly
|
||||
import unittest.mock
|
||||
|
||||
# Mock asyncssh before importing
|
||||
with unittest.mock.patch.dict('sys.modules', {'asyncssh': unittest.mock.MagicMock()}):
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'router_interface',
|
||||
'/workspaces/wifi-densepose/src/hardware/router_interface.py'
|
||||
)
|
||||
router_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# Import CSI extractor for dependency
|
||||
csi_spec = importlib.util.spec_from_file_location(
|
||||
'csi_extractor',
|
||||
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
|
||||
)
|
||||
csi_module = importlib.util.module_from_spec(csi_spec)
|
||||
csi_spec.loader.exec_module(csi_module)
|
||||
|
||||
# Now load the router interface
|
||||
router_module.CSIData = csi_module.CSIData # Make CSIData available
|
||||
spec.loader.exec_module(router_module)
|
||||
|
||||
# Get classes from modules
|
||||
RouterInterface = router_module.RouterInterface
|
||||
RouterConnectionError = router_module.RouterConnectionError
|
||||
CSIData = csi_module.CSIData
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterInterface:
|
||||
"""Test router interface using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def router_config(self):
|
||||
"""Router configuration for testing."""
|
||||
return {
|
||||
'host': '192.168.1.1',
|
||||
'port': 22,
|
||||
'username': 'admin',
|
||||
'password': 'password',
|
||||
'command_timeout': 30,
|
||||
'connection_timeout': 10,
|
||||
'max_retries': 3,
|
||||
'retry_delay': 1.0
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_interface(self, router_config, mock_logger):
|
||||
"""Create router interface for testing."""
|
||||
return RouterInterface(config=router_config, logger=mock_logger)
|
||||
|
||||
# Initialization tests
|
||||
def test_should_initialize_with_valid_config(self, router_config, mock_logger):
|
||||
"""Should initialize router interface with valid configuration."""
|
||||
interface = RouterInterface(config=router_config, logger=mock_logger)
|
||||
|
||||
assert interface.host == '192.168.1.1'
|
||||
assert interface.port == 22
|
||||
assert interface.username == 'admin'
|
||||
assert interface.password == 'password'
|
||||
assert interface.command_timeout == 30
|
||||
assert interface.connection_timeout == 10
|
||||
assert interface.max_retries == 3
|
||||
assert interface.retry_delay == 1.0
|
||||
assert interface.is_connected == False
|
||||
assert interface.logger == mock_logger
|
||||
|
||||
def test_should_raise_error_with_invalid_config(self, mock_logger):
|
||||
"""Should raise error when initialized with invalid configuration."""
|
||||
invalid_config = {'invalid': 'config'}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
RouterInterface(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_validate_required_fields(self, mock_logger):
|
||||
"""Should validate all required configuration fields."""
|
||||
required_fields = ['host', 'port', 'username', 'password']
|
||||
base_config = {
|
||||
'host': '192.168.1.1',
|
||||
'port': 22,
|
||||
'username': 'admin',
|
||||
'password': 'password'
|
||||
}
|
||||
|
||||
for field in required_fields:
|
||||
config = base_config.copy()
|
||||
del config[field]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
RouterInterface(config=config, logger=mock_logger)
|
||||
|
||||
def test_should_use_default_values(self, mock_logger):
|
||||
"""Should use default values for optional parameters."""
|
||||
minimal_config = {
|
||||
'host': '192.168.1.1',
|
||||
'port': 22,
|
||||
'username': 'admin',
|
||||
'password': 'password'
|
||||
}
|
||||
|
||||
interface = RouterInterface(config=minimal_config, logger=mock_logger)
|
||||
|
||||
assert interface.command_timeout == 30 # default
|
||||
assert interface.connection_timeout == 10 # default
|
||||
assert interface.max_retries == 3 # default
|
||||
assert interface.retry_delay == 1.0 # default
|
||||
|
||||
def test_should_initialize_without_logger(self, router_config):
|
||||
"""Should initialize without logger provided."""
|
||||
interface = RouterInterface(config=router_config)
|
||||
|
||||
assert interface.logger is not None # Should create default logger
|
||||
|
||||
# Connection tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_connect_successfully(self, router_interface):
|
||||
"""Should establish SSH connection successfully."""
|
||||
mock_ssh_client = Mock()
|
||||
|
||||
with patch('src.hardware.router_interface.asyncssh.connect', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.return_value = mock_ssh_client
|
||||
|
||||
result = await router_interface.connect()
|
||||
|
||||
assert result == True
|
||||
assert router_interface.is_connected == True
|
||||
assert router_interface.ssh_client == mock_ssh_client
|
||||
mock_connect.assert_called_once_with(
|
||||
'192.168.1.1',
|
||||
port=22,
|
||||
username='admin',
|
||||
password='password',
|
||||
connect_timeout=10
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_connection_failure(self, router_interface):
|
||||
"""Should handle SSH connection failure gracefully."""
|
||||
with patch('src.hardware.router_interface.asyncssh.connect', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
result = await router_interface.connect()
|
||||
|
||||
assert result == False
|
||||
assert router_interface.is_connected == False
|
||||
assert router_interface.ssh_client is None
|
||||
router_interface.logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_disconnect_when_connected(self, router_interface):
|
||||
"""Should disconnect SSH connection when connected."""
|
||||
mock_ssh_client = Mock()
|
||||
router_interface.is_connected = True
|
||||
router_interface.ssh_client = mock_ssh_client
|
||||
|
||||
await router_interface.disconnect()
|
||||
|
||||
assert router_interface.is_connected == False
|
||||
assert router_interface.ssh_client is None
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_disconnect_when_not_connected(self, router_interface):
|
||||
"""Should handle disconnect when not connected."""
|
||||
router_interface.is_connected = False
|
||||
router_interface.ssh_client = None
|
||||
|
||||
await router_interface.disconnect()
|
||||
|
||||
# Should not raise any exception
|
||||
assert router_interface.is_connected == False
|
||||
|
||||
# Command execution tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_execute_command_successfully(self, router_interface):
|
||||
"""Should execute SSH command successfully."""
|
||||
mock_ssh_client = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.stdout = "command output"
|
||||
mock_result.stderr = ""
|
||||
mock_result.returncode = 0
|
||||
|
||||
router_interface.is_connected = True
|
||||
router_interface.ssh_client = mock_ssh_client
|
||||
|
||||
with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
result = await router_interface.execute_command("test command")
|
||||
|
||||
assert result == "command output"
|
||||
mock_run.assert_called_once_with("test command", timeout=30)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_command_execution_when_not_connected(self, router_interface):
|
||||
"""Should handle command execution when not connected."""
|
||||
router_interface.is_connected = False
|
||||
|
||||
with pytest.raises(RouterConnectionError, match="Not connected to router"):
|
||||
await router_interface.execute_command("test command")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_command_execution_error(self, router_interface):
|
||||
"""Should handle command execution errors."""
|
||||
mock_ssh_client = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.stdout = ""
|
||||
mock_result.stderr = "command error"
|
||||
mock_result.returncode = 1
|
||||
|
||||
router_interface.is_connected = True
|
||||
router_interface.ssh_client = mock_ssh_client
|
||||
|
||||
with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = mock_result
|
||||
|
||||
with pytest.raises(RouterConnectionError, match="Command failed"):
|
||||
await router_interface.execute_command("test command")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_retry_command_execution_on_failure(self, router_interface):
|
||||
"""Should retry command execution on temporary failure."""
|
||||
mock_ssh_client = Mock()
|
||||
mock_success_result = Mock()
|
||||
mock_success_result.stdout = "success output"
|
||||
mock_success_result.stderr = ""
|
||||
mock_success_result.returncode = 0
|
||||
|
||||
router_interface.is_connected = True
|
||||
router_interface.ssh_client = mock_ssh_client
|
||||
|
||||
with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
|
||||
# First two calls fail, third succeeds
|
||||
mock_run.side_effect = [
|
||||
ConnectionError("Network error"),
|
||||
ConnectionError("Network error"),
|
||||
mock_success_result
|
||||
]
|
||||
|
||||
result = await router_interface.execute_command("test command")
|
||||
|
||||
assert result == "success output"
|
||||
assert mock_run.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_fail_after_max_retries(self, router_interface):
|
||||
"""Should fail after maximum retries exceeded."""
|
||||
mock_ssh_client = Mock()
|
||||
|
||||
router_interface.is_connected = True
|
||||
router_interface.ssh_client = mock_ssh_client
|
||||
|
||||
with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
|
||||
mock_run.side_effect = ConnectionError("Network error")
|
||||
|
||||
with pytest.raises(RouterConnectionError, match="Command execution failed after 3 retries"):
|
||||
await router_interface.execute_command("test command")
|
||||
|
||||
assert mock_run.call_count == 3
|
||||
|
||||
# CSI data retrieval tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_get_csi_data_successfully(self, router_interface):
|
||||
"""Should retrieve CSI data successfully."""
|
||||
expected_csi_data = Mock(spec=CSIData)
|
||||
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
with patch.object(router_interface, '_parse_csi_response', return_value=expected_csi_data) as mock_parse:
|
||||
mock_execute.return_value = "csi data response"
|
||||
|
||||
result = await router_interface.get_csi_data()
|
||||
|
||||
assert result == expected_csi_data
|
||||
mock_execute.assert_called_once_with("iwlist scan | grep CSI")
|
||||
mock_parse.assert_called_once_with("csi data response")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_csi_data_retrieval_failure(self, router_interface):
|
||||
"""Should handle CSI data retrieval failure."""
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.side_effect = RouterConnectionError("Command failed")
|
||||
|
||||
with pytest.raises(RouterConnectionError):
|
||||
await router_interface.get_csi_data()
|
||||
|
||||
# Router status tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_get_router_status_successfully(self, router_interface):
|
||||
"""Should get router status successfully."""
|
||||
expected_status = {
|
||||
'cpu_usage': 25.5,
|
||||
'memory_usage': 60.2,
|
||||
'wifi_status': 'active',
|
||||
'uptime': '5 days, 3 hours'
|
||||
}
|
||||
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
with patch.object(router_interface, '_parse_status_response', return_value=expected_status) as mock_parse:
|
||||
mock_execute.return_value = "status response"
|
||||
|
||||
result = await router_interface.get_router_status()
|
||||
|
||||
assert result == expected_status
|
||||
mock_execute.assert_called_once_with("cat /proc/stat && free && iwconfig")
|
||||
mock_parse.assert_called_once_with("status response")
|
||||
|
||||
# Configuration tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_configure_csi_monitoring_successfully(self, router_interface):
|
||||
"""Should configure CSI monitoring successfully."""
|
||||
config = {
|
||||
'channel': 6,
|
||||
'bandwidth': 20,
|
||||
'sample_rate': 100
|
||||
}
|
||||
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = "Configuration applied"
|
||||
|
||||
result = await router_interface.configure_csi_monitoring(config)
|
||||
|
||||
assert result == True
|
||||
mock_execute.assert_called_once_with(
|
||||
"iwconfig wlan0 channel 6 && echo 'CSI monitoring configured'"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_csi_monitoring_configuration_failure(self, router_interface):
|
||||
"""Should handle CSI monitoring configuration failure."""
|
||||
config = {
|
||||
'channel': 6,
|
||||
'bandwidth': 20,
|
||||
'sample_rate': 100
|
||||
}
|
||||
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.side_effect = RouterConnectionError("Command failed")
|
||||
|
||||
result = await router_interface.configure_csi_monitoring(config)
|
||||
|
||||
assert result == False
|
||||
|
||||
# Health check tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_perform_health_check_successfully(self, router_interface):
|
||||
"""Should perform health check successfully."""
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.return_value = "pong"
|
||||
|
||||
result = await router_interface.health_check()
|
||||
|
||||
assert result == True
|
||||
mock_execute.assert_called_once_with("echo 'ping' && echo 'pong'")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_health_check_failure(self, router_interface):
|
||||
"""Should handle health check failure."""
|
||||
with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
|
||||
mock_execute.side_effect = RouterConnectionError("Command failed")
|
||||
|
||||
result = await router_interface.health_check()
|
||||
|
||||
assert result == False
|
||||
|
||||
# Parsing method tests
|
||||
def test_should_parse_csi_response(self, router_interface):
|
||||
"""Should parse CSI response data."""
|
||||
mock_response = "CSI_DATA:timestamp,antennas,subcarriers,frequency,bandwidth"
|
||||
|
||||
with patch('src.hardware.router_interface.CSIData') as mock_csi_data:
|
||||
expected_data = Mock(spec=CSIData)
|
||||
mock_csi_data.return_value = expected_data
|
||||
|
||||
result = router_interface._parse_csi_response(mock_response)
|
||||
|
||||
assert result == expected_data
|
||||
|
||||
def test_should_parse_status_response(self, router_interface):
|
||||
"""Should parse router status response."""
|
||||
mock_response = """
|
||||
cpu 123456 0 78901 234567 0 0 0 0 0 0
|
||||
MemTotal: 1024000 kB
|
||||
MemFree: 512000 kB
|
||||
wlan0 IEEE 802.11 ESSID:"TestNetwork"
|
||||
"""
|
||||
|
||||
result = router_interface._parse_status_response(mock_response)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert 'cpu_usage' in result
|
||||
assert 'memory_usage' in result
|
||||
assert 'wifi_status' in result
|
||||
Reference in New Issue
Block a user