feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

View File

@@ -0,0 +1,264 @@
import pytest
import numpy as np
import torch
from unittest.mock import Mock, patch, MagicMock
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
class TestCSIExtractor:
    """Test suite for CSI Extractor following London School TDD principles"""

    @pytest.fixture
    def mock_config(self):
        """Configuration for CSI extractor"""
        return {
            'interface': 'wlan0',
            'channel': 6,
            'bandwidth': 20,
            'sample_rate': 1000,
            'buffer_size': 1024,
            'extraction_timeout': 5.0
        }

    @pytest.fixture
    def mock_router_interface(self):
        """Mock router interface for testing"""
        mock_router = Mock()
        mock_router.is_connected = True
        mock_router.execute_command = Mock()
        return mock_router

    @pytest.fixture
    def csi_extractor(self, mock_config, mock_router_interface):
        """Create CSI extractor instance for testing"""
        return CSIExtractor(mock_config, mock_router_interface)

    @pytest.fixture
    def mock_csi_data(self):
        """Generate synthetic CSI data for testing"""
        # Simulate CSI data: complex values for multiple subcarriers
        num_subcarriers = 56
        num_antennas = 3
        amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
        phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
        return amplitude * np.exp(1j * phase)

    def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
        """Test that CSI extractor initializes with correct configuration"""
        # Act
        extractor = CSIExtractor(mock_config, mock_router_interface)
        # Assert
        assert extractor is not None
        assert extractor.interface == mock_config['interface']
        assert extractor.channel == mock_config['channel']
        assert extractor.bandwidth == mock_config['bandwidth']
        assert extractor.sample_rate == mock_config['sample_rate']
        assert extractor.buffer_size == mock_config['buffer_size']
        assert extractor.extraction_timeout == mock_config['extraction_timeout']
        assert extractor.router_interface == mock_router_interface
        assert not extractor.is_extracting

    def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that start_extraction configures monitor mode"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        # Act
        result = csi_extractor.start_extraction()
        # Assert
        assert result is True
        assert csi_extractor.is_extracting is True
        mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
        """Test that start_extraction handles monitor mode configuration failure"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = False
        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.start_extraction()
        assert csi_extractor.is_extracting is False

    def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that stop_extraction disables monitor mode"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()
        # Act
        result = csi_extractor.stop_extraction()
        # Assert
        assert result is True
        assert csi_extractor.is_extracting is False
        mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that extract_csi_data returns data in valid format"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        # Mock the CSI data extraction
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            csi_extractor.start_extraction()
            # Act
            csi_data = csi_extractor.extract_csi_data()
            # Assert
            assert csi_data is not None
            assert isinstance(csi_data, np.ndarray)
            assert csi_data.dtype == np.complex128
            assert csi_data.shape == mock_csi_data.shape

    def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
        """Test that extract_csi_data requires active extraction"""
        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
        """Test that extract_csi_data handles extraction timeout"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.side_effect = [
            "CSI extraction started",
            Exception("Timeout")
        ]
        csi_extractor.start_extraction()
        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
        """Test that convert_to_tensor produces correctly formatted tensor"""
        # Act
        tensor = csi_extractor.convert_to_tensor(mock_csi_data)
        # Assert
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.shape[0] == mock_csi_data.shape[0] * 2  # Real and imaginary parts
        assert tensor.shape[1] == mock_csi_data.shape[1]

    def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
        """Test that convert_to_tensor handles invalid input"""
        # Arrange
        invalid_data = "not an array"
        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.convert_to_tensor(invalid_data)

    def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
        """Test that get_extraction_stats returns valid statistics"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()
        # Act
        stats = csi_extractor.get_extraction_stats()
        # Assert
        assert stats is not None
        assert isinstance(stats, dict)
        assert 'samples_extracted' in stats
        assert 'extraction_rate' in stats
        assert 'buffer_utilization' in stats
        assert 'last_extraction_time' in stats

    def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
        """Test that set_channel configures WiFi channel"""
        # Arrange
        new_channel = 11
        mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"
        # Act
        result = csi_extractor.set_channel(new_channel)
        # Assert
        assert result is True
        assert csi_extractor.channel == new_channel
        mock_router_interface.execute_command.assert_called()

    def test_set_channel_validates_channel_range(self, csi_extractor):
        """Test that set_channel validates channel range"""
        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.set_channel(0)  # Invalid channel
        with pytest.raises(ValueError):
            csi_extractor.set_channel(15)  # Invalid channel

    def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
        """Test that CSI extractor supports context manager protocol"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        # Act
        with csi_extractor as extractor:
            # Assert
            assert extractor.is_extracting is True
        # Assert - extraction should be stopped after context
        assert csi_extractor.is_extracting is False

    def test_extractor_validates_configuration(self, mock_router_interface):
        """Test that CSI extractor validates configuration parameters"""
        # Arrange
        invalid_config = {
            'interface': '',  # Invalid interface
            'channel': 6,
            'bandwidth': 20
        }
        # Act & Assert
        with pytest.raises(ValueError):
            CSIExtractor(invalid_config, mock_router_interface)

    def test_parse_csi_output_processes_raw_data(self, csi_extractor):
        """Test that _parse_csi_output processes raw CSI data correctly"""
        # Arrange
        raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"
        # Act
        parsed_data = csi_extractor._parse_csi_output(raw_output)
        # Assert
        assert parsed_data is not None
        assert isinstance(parsed_data, np.ndarray)
        assert parsed_data.dtype == np.complex128

    def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that buffer management handles overflow correctly"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            csi_extractor.start_extraction()
            # Fill buffer beyond capacity
            for _ in range(csi_extractor.buffer_size + 10):
                csi_extractor._add_to_buffer(mock_csi_data)
            # Act
            stats = csi_extractor.get_extraction_stats()
            # Assert
            assert stats['buffer_utilization'] <= 1.0  # Should not exceed 100%

View File

@@ -0,0 +1,588 @@
"""Direct tests for CSI extractor avoiding import issues."""
import pytest
import numpy as np
import sys
import os
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from typing import Dict, Any, Optional
import asyncio
from datetime import datetime, timezone
# Add src to path for direct import
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
# Import the CSI extractor module directly
from src.hardware.csi_extractor import (
CSIExtractor,
CSIParseError,
CSIData,
ESP32CSIParser,
RouterCSIParser,
CSIValidationError
)
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIExtractorDirect:
    """Test CSI extractor with direct imports."""

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def esp32_config(self):
        """ESP32 configuration for testing."""
        return {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0,
            'validation_enabled': True,
            'retry_attempts': 3
        }

    @pytest.fixture
    def router_config(self):
        """Router configuration for testing."""
        return {
            'hardware_type': 'router',
            'sampling_rate': 50,
            'buffer_size': 512,
            'timeout': 10.0,
            'validation_enabled': False,
            'retry_attempts': 1
        }

    @pytest.fixture
    def sample_csi_data(self):
        """Sample CSI data for testing."""
        return CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={'source': 'esp32', 'channel': 6}
        )

    # Initialization tests
    def test_should_initialize_with_valid_config(self, esp32_config, mock_logger):
        """Should initialize CSI extractor with valid configuration."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        assert extractor.config == esp32_config
        assert extractor.logger == mock_logger
        assert extractor.is_connected == False
        assert extractor.hardware_type == 'esp32'

    def test_should_create_esp32_parser(self, esp32_config, mock_logger):
        """Should create ESP32 parser when hardware_type is esp32."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        assert isinstance(extractor.parser, ESP32CSIParser)

    def test_should_create_router_parser(self, router_config, mock_logger):
        """Should create router parser when hardware_type is router."""
        extractor = CSIExtractor(config=router_config, logger=mock_logger)
        assert isinstance(extractor.parser, RouterCSIParser)
        assert extractor.hardware_type == 'router'

    def test_should_raise_error_for_unsupported_hardware(self, mock_logger):
        """Should raise error for unsupported hardware type."""
        invalid_config = {
            'hardware_type': 'unsupported',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    # Configuration validation tests
    def test_config_validation_missing_fields(self, mock_logger):
        """Should validate required configuration fields."""
        invalid_config = {'invalid': 'config'}
        with pytest.raises(ValueError, match="Missing required configuration"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_negative_sampling_rate(self, mock_logger):
        """Should validate sampling_rate is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': -1,
            'buffer_size': 1024,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="sampling_rate must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_zero_buffer_size(self, mock_logger):
        """Should validate buffer_size is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 0,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="buffer_size must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_negative_timeout(self, mock_logger):
        """Should validate timeout is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': -1.0
        }
        with pytest.raises(ValueError, match="timeout must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    # Connection tests
    @pytest.mark.asyncio
    async def test_should_establish_connection_successfully(self, esp32_config, mock_logger):
        """Should establish connection to hardware successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
            mock_connect.return_value = True
            result = await extractor.connect()
            assert result == True
            assert extractor.is_connected == True
            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_should_handle_connection_failure(self, esp32_config, mock_logger):
        """Should handle connection failure gracefully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
            mock_connect.side_effect = ConnectionError("Hardware not found")
            result = await extractor.connect()
            assert result == False
            assert extractor.is_connected == False
            extractor.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_should_disconnect_properly(self, esp32_config, mock_logger):
        """Should disconnect from hardware properly."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_disconnect:
            await extractor.disconnect()
            assert extractor.is_connected == False
            mock_disconnect.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
        """Should handle disconnect when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False
        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
            await extractor.disconnect()
            # Should not call close when not connected
            mock_close.assert_not_called()
            assert extractor.is_connected == False

    # Data extraction tests
    @pytest.mark.asyncio
    async def test_should_extract_csi_data_successfully(self, esp32_config, mock_logger, sample_csi_data):
        """Should extract CSI data successfully from hardware."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
                mock_read.return_value = b"raw_csi_data"
                result = await extractor.extract_csi()
                assert result == sample_csi_data
                mock_read.assert_called_once()
                mock_parse.assert_called_once_with(b"raw_csi_data")

    @pytest.mark.asyncio
    async def test_should_handle_extraction_failure_when_not_connected(self, esp32_config, mock_logger):
        """Should handle extraction failure when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False
        with pytest.raises(CSIParseError, match="Not connected to hardware"):
            await extractor.extract_csi()

    @pytest.mark.asyncio
    async def test_should_retry_on_temporary_failure(self, esp32_config, mock_logger, sample_csi_data):
        """Should retry extraction on temporary failure."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse') as mock_parse:
                # First two calls fail, third succeeds
                mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
                mock_parse.return_value = sample_csi_data
                result = await extractor.extract_csi()
                assert result == sample_csi_data
                assert mock_read.call_count == 3

    @pytest.mark.asyncio
    async def test_extract_with_validation_disabled(self, esp32_config, mock_logger, sample_csi_data):
        """Should skip validation when disabled."""
        esp32_config['validation_enabled'] = False
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
                with patch.object(extractor, 'validate_csi_data') as mock_validate:
                    mock_read.return_value = b"raw_data"
                    result = await extractor.extract_csi()
                    assert result == sample_csi_data
                    mock_validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_extract_max_retries_exceeded(self, esp32_config, mock_logger):
        """Should raise error after max retries exceeded."""
        esp32_config['retry_attempts'] = 2
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            mock_read.side_effect = ConnectionError("Connection failed")
            with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
                await extractor.extract_csi()
            assert mock_read.call_count == 2

    # Validation tests
    def test_should_validate_csi_data_successfully(self, esp32_config, mock_logger, sample_csi_data):
        """Should validate CSI data successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        result = extractor.validate_csi_data(sample_csi_data)
        assert result == True

    def test_validation_empty_amplitude(self, esp32_config, mock_logger):
        """Should raise validation error for empty amplitude."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.array([]),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Empty amplitude data"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_empty_phase(self, esp32_config, mock_logger):
        """Should raise validation error for empty phase."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.array([]),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Empty phase data"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_frequency(self, esp32_config, mock_logger):
        """Should raise validation error for invalid frequency."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=0,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid frequency"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_bandwidth(self, esp32_config, mock_logger):
        """Should raise validation error for invalid bandwidth."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=0,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_subcarriers(self, esp32_config, mock_logger):
        """Should raise validation error for invalid subcarriers."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=0,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_antennas(self, esp32_config, mock_logger):
        """Should raise validation error for invalid antennas."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=0,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_snr_too_low(self, esp32_config, mock_logger):
        """Should raise validation error for SNR too low."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=-100,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_snr_too_high(self, esp32_config, mock_logger):
        """Should raise validation error for SNR too high."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=100,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(invalid_data)

    # Streaming tests
    @pytest.mark.asyncio
    async def test_should_start_streaming_successfully(self, esp32_config, mock_logger, sample_csi_data):
        """Should start CSI data streaming successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()
        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.return_value = sample_csi_data
            # Start streaming with limited iterations to avoid infinite loop
            streaming_task = asyncio.create_task(extractor.start_streaming(callback))
            await asyncio.sleep(0.1)  # Let it run briefly
            extractor.stop_streaming()
            await streaming_task
            callback.assert_called()

    @pytest.mark.asyncio
    async def test_should_stop_streaming_gracefully(self, esp32_config, mock_logger):
        """Should stop streaming gracefully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_streaming = True
        extractor.stop_streaming()
        assert extractor.is_streaming == False

    @pytest.mark.asyncio
    async def test_streaming_with_exception(self, esp32_config, mock_logger):
        """Should handle exceptions during streaming."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()
        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.side_effect = Exception("Extraction error")
            # Start streaming and let it handle the exception
            streaming_task = asyncio.create_task(extractor.start_streaming(callback))
            await asyncio.sleep(0.1)  # Let it run briefly and hit the exception
            await streaming_task
            # Should log error and stop streaming
            assert extractor.is_streaming == False
            extractor.logger.error.assert_called()
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestESP32CSIParserDirect:
    """Test ESP32 CSI parser with direct imports."""

    @pytest.fixture
    def parser(self):
        """Create ESP32 CSI parser for testing."""
        return ESP32CSIParser()

    @pytest.fixture
    def raw_esp32_data(self):
        """Sample raw ESP32 CSI data."""
        return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"

    def test_should_parse_valid_esp32_data(self, parser, raw_esp32_data):
        """Should parse valid ESP32 CSI data successfully."""
        result = parser.parse(raw_esp32_data)
        assert isinstance(result, CSIData)
        assert result.num_antennas == 3
        assert result.num_subcarriers == 56
        assert result.frequency == 2400000000  # 2.4 GHz
        assert result.bandwidth == 20000000  # 20 MHz
        assert result.snr == 15.5

    def test_should_handle_malformed_data(self, parser):
        """Should handle malformed ESP32 data gracefully."""
        malformed_data = b"INVALID_DATA"
        with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
            parser.parse(malformed_data)

    def test_should_handle_empty_data(self, parser):
        """Should handle empty data gracefully."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

    def test_parse_with_value_error(self, parser):
        """Should handle ValueError during parsing."""
        invalid_data = b"CSI_DATA:invalid_timestamp,3,56,2400,20,15.5"
        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(invalid_data)

    def test_parse_with_index_error(self, parser):
        """Should handle IndexError during parsing."""
        invalid_data = b"CSI_DATA:1234567890"  # Missing fields
        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(invalid_data)
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestRouterCSIParserDirect:
    """Test Router CSI parser with direct imports."""

    @pytest.fixture
    def parser(self):
        """Create Router CSI parser for testing."""
        return RouterCSIParser()

    def test_should_parse_atheros_format(self, parser):
        """Should parse Atheros CSI format successfully."""
        raw_data = b"ATHEROS_CSI:mock_data"
        with patch.object(parser, '_parse_atheros_format', return_value=Mock(spec=CSIData)) as mock_parse:
            result = parser.parse(raw_data)
            mock_parse.assert_called_once()
            assert result is not None

    def test_should_handle_unknown_format(self, parser):
        """Should handle unknown router format gracefully."""
        unknown_data = b"UNKNOWN_FORMAT:data"
        with pytest.raises(CSIParseError, match="Unknown router CSI format"):
            parser.parse(unknown_data)

    def test_parse_atheros_format_directly(self, parser):
        """Should parse Atheros format directly."""
        raw_data = b"ATHEROS_CSI:mock_data"
        result = parser.parse(raw_data)
        assert isinstance(result, CSIData)
        assert result.metadata['source'] == 'atheros_router'

    def test_should_handle_empty_data_router(self, parser):
        """Should handle empty data gracefully."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

View File

@@ -0,0 +1,275 @@
"""Test-Driven Development tests for CSI extractor using London School approach."""
import pytest
import numpy as np
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from typing import Dict, Any, Optional
import asyncio
from datetime import datetime, timezone
from src.hardware.csi_extractor import (
CSIExtractor,
CSIParseError,
CSIData,
ESP32CSIParser,
RouterCSIParser,
CSIValidationError
)
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIExtractor:
"""Test CSI extractor using London School TDD - focus on interactions and behavior."""
@pytest.fixture
def mock_logger(self):
"""Mock logger for testing."""
return Mock()
@pytest.fixture
def mock_config(self):
"""Mock configuration for CSI extractor."""
return {
'hardware_type': 'esp32',
'sampling_rate': 100,
'buffer_size': 1024,
'timeout': 5.0,
'validation_enabled': True,
'retry_attempts': 3
}
@pytest.fixture
def csi_extractor(self, mock_config, mock_logger):
"""Create CSI extractor instance for testing."""
return CSIExtractor(config=mock_config, logger=mock_logger)
@pytest.fixture
def sample_csi_data(self):
"""Sample CSI data for testing."""
return CSIData(
timestamp=datetime.now(timezone.utc),
amplitude=np.random.rand(3, 56),
phase=np.random.rand(3, 56),
frequency=2.4e9,
bandwidth=20e6,
num_subcarriers=56,
num_antennas=3,
snr=15.5,
metadata={'source': 'esp32', 'channel': 6}
)
def test_should_initialize_with_valid_config(self, mock_config, mock_logger):
"""Should initialize CSI extractor with valid configuration."""
extractor = CSIExtractor(config=mock_config, logger=mock_logger)
assert extractor.config == mock_config
assert extractor.logger == mock_logger
assert extractor.is_connected == False
assert extractor.hardware_type == 'esp32'
def test_should_raise_error_with_invalid_config(self, mock_logger):
"""Should raise error when initialized with invalid configuration."""
invalid_config = {'invalid': 'config'}
with pytest.raises(ValueError, match="Missing required configuration"):
CSIExtractor(config=invalid_config, logger=mock_logger)
def test_should_create_appropriate_parser(self, mock_config, mock_logger):
"""Should create appropriate parser based on hardware type."""
extractor = CSIExtractor(config=mock_config, logger=mock_logger)
assert isinstance(extractor.parser, ESP32CSIParser)
@pytest.mark.asyncio
async def test_should_establish_connection_successfully(self, csi_extractor):
"""Should establish connection to hardware successfully."""
with patch.object(csi_extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
mock_connect.return_value = True
result = await csi_extractor.connect()
assert result == True
assert csi_extractor.is_connected == True
mock_connect.assert_called_once()
@pytest.mark.asyncio
async def test_should_handle_connection_failure(self, csi_extractor):
"""Should handle connection failure gracefully."""
with patch.object(csi_extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
mock_connect.side_effect = ConnectionError("Hardware not found")
result = await csi_extractor.connect()
assert result == False
assert csi_extractor.is_connected == False
csi_extractor.logger.error.assert_called()
@pytest.mark.asyncio
async def test_should_disconnect_properly(self, csi_extractor):
"""Should disconnect from hardware properly."""
csi_extractor.is_connected = True
with patch.object(csi_extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_disconnect:
await csi_extractor.disconnect()
assert csi_extractor.is_connected == False
mock_disconnect.assert_called_once()
@pytest.mark.asyncio
async def test_should_extract_csi_data_successfully(self, csi_extractor, sample_csi_data):
"""Should extract CSI data successfully from hardware."""
csi_extractor.is_connected = True
with patch.object(csi_extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
with patch.object(csi_extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
mock_read.return_value = b"raw_csi_data"
result = await csi_extractor.extract_csi()
assert result == sample_csi_data
mock_read.assert_called_once()
mock_parse.assert_called_once_with(b"raw_csi_data")
@pytest.mark.asyncio
async def test_should_handle_extraction_failure_when_not_connected(self, csi_extractor):
"""Should handle extraction failure when not connected."""
csi_extractor.is_connected = False
with pytest.raises(CSIParseError, match="Not connected to hardware"):
await csi_extractor.extract_csi()
@pytest.mark.asyncio
async def test_should_retry_on_temporary_failure(self, csi_extractor, sample_csi_data):
"""Should retry extraction on temporary failure."""
csi_extractor.is_connected = True
with patch.object(csi_extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
with patch.object(csi_extractor.parser, 'parse') as mock_parse:
# First two calls fail, third succeeds
mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
mock_parse.return_value = sample_csi_data
result = await csi_extractor.extract_csi()
assert result == sample_csi_data
assert mock_read.call_count == 3
def test_should_validate_csi_data_successfully(self, csi_extractor, sample_csi_data):
"""Should validate CSI data successfully."""
result = csi_extractor.validate_csi_data(sample_csi_data)
assert result == True
def test_should_reject_invalid_csi_data(self, csi_extractor):
"""Should reject CSI data with invalid structure."""
invalid_data = CSIData(
timestamp=datetime.now(timezone.utc),
amplitude=np.array([]), # Empty array
phase=np.array([]),
frequency=0, # Invalid frequency
bandwidth=0,
num_subcarriers=0,
num_antennas=0,
snr=-100, # Invalid SNR
metadata={}
)
with pytest.raises(CSIValidationError):
csi_extractor.validate_csi_data(invalid_data)
@pytest.mark.asyncio
async def test_should_start_streaming_successfully(self, csi_extractor, sample_csi_data):
"""Should start CSI data streaming successfully."""
csi_extractor.is_connected = True
callback = Mock()
with patch.object(csi_extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
mock_extract.return_value = sample_csi_data
# Start streaming with limited iterations to avoid infinite loop
streaming_task = asyncio.create_task(csi_extractor.start_streaming(callback))
await asyncio.sleep(0.1) # Let it run briefly
csi_extractor.stop_streaming()
await streaming_task
callback.assert_called()
@pytest.mark.asyncio
async def test_should_stop_streaming_gracefully(self, csi_extractor):
"""Should stop streaming gracefully."""
csi_extractor.is_streaming = True
csi_extractor.stop_streaming()
assert csi_extractor.is_streaming == False
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestESP32CSIParser:
    """Test ESP32 CSI parser using London School TDD."""

    @pytest.fixture
    def parser(self):
        """Create ESP32 CSI parser for testing."""
        return ESP32CSIParser()

    @pytest.fixture
    def raw_esp32_data(self):
        """Sample raw ESP32 CSI data."""
        return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"

    def test_should_parse_valid_esp32_data(self, parser, raw_esp32_data):
        """Should parse valid ESP32 CSI data successfully."""
        parsed = parser.parse(raw_esp32_data)

        assert isinstance(parsed, CSIData)
        # Header fields: antennas, subcarriers, centre frequency, bandwidth, SNR.
        assert parsed.num_antennas == 3
        assert parsed.num_subcarriers == 56
        assert parsed.frequency == 2400000000  # 2.4 GHz
        assert parsed.bandwidth == 20000000  # 20 MHz
        assert parsed.snr == 15.5

    def test_should_handle_malformed_data(self, parser):
        """Should handle malformed ESP32 data gracefully."""
        with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
            parser.parse(b"INVALID_DATA")

    def test_should_handle_empty_data(self, parser):
        """Should handle empty data gracefully."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestRouterCSIParser:
    """Test Router CSI parser using London School TDD."""

    @pytest.fixture
    def parser(self):
        """Create Router CSI parser for testing."""
        return RouterCSIParser()

    def test_should_parse_atheros_format(self, parser):
        """Should parse Atheros CSI format successfully."""
        atheros_payload = b"ATHEROS_CSI:mock_data"
        with patch.object(
            parser, '_parse_atheros_format', return_value=Mock(spec=CSIData)
        ) as atheros_parse:
            parsed = parser.parse(atheros_payload)

            # The ATHEROS_CSI prefix must route to the Atheros-specific parser.
            atheros_parse.assert_called_once()
            assert parsed is not None

    def test_should_handle_unknown_format(self, parser):
        """Should handle unknown router format gracefully."""
        with pytest.raises(CSIParseError, match="Unknown router CSI format"):
            parser.parse(b"UNKNOWN_FORMAT:data")

View File

@@ -0,0 +1,386 @@
"""Complete TDD tests for CSI extractor with 100% coverage."""
import pytest
import numpy as np
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from typing import Dict, Any, Optional
import asyncio
from datetime import datetime, timezone
from src.hardware.csi_extractor import (
CSIExtractor,
CSIParseError,
CSIData,
ESP32CSIParser,
RouterCSIParser,
CSIValidationError
)
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIExtractorComplete:
    """Complete CSI extractor tests for 100% coverage.

    Exercises the paths the basic suite misses: the router back-end,
    per-field configuration validation, disconnect/extract edge cases,
    every validate_csi_data failure branch, and the streaming error path.
    """

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def esp32_config(self):
        """ESP32 configuration for testing."""
        return {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0,
            'validation_enabled': True,
            'retry_attempts': 3
        }

    @pytest.fixture
    def router_config(self):
        """Router configuration for testing (validation off, single attempt)."""
        return {
            'hardware_type': 'router',
            'sampling_rate': 50,
            'buffer_size': 512,
            'timeout': 10.0,
            'validation_enabled': False,
            'retry_attempts': 1
        }

    @pytest.fixture
    def sample_csi_data(self):
        """Sample CSI data for testing (3 antennas x 56 subcarriers)."""
        return CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={'source': 'esp32', 'channel': 6}
        )

    def test_should_create_router_parser(self, router_config, mock_logger):
        """Should create router parser when hardware_type is router."""
        extractor = CSIExtractor(config=router_config, logger=mock_logger)
        assert isinstance(extractor.parser, RouterCSIParser)
        assert extractor.hardware_type == 'router'

    def test_should_raise_error_for_unsupported_hardware(self, mock_logger):
        """Should raise error for unsupported hardware type."""
        invalid_config = {
            'hardware_type': 'unsupported',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_negative_sampling_rate(self, mock_logger):
        """Should validate sampling_rate is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': -1,
            'buffer_size': 1024,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="sampling_rate must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_zero_buffer_size(self, mock_logger):
        """Should validate buffer_size is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 0,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="buffer_size must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_negative_timeout(self, mock_logger):
        """Should validate timeout is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': -1.0
        }
        with pytest.raises(ValueError, match="timeout must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    @pytest.mark.asyncio
    async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
        """Should handle disconnect when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False
        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
            await extractor.disconnect()
            # Should not call close when not connected
            mock_close.assert_not_called()
            assert extractor.is_connected == False

    @pytest.mark.asyncio
    async def test_extract_with_validation_disabled(self, esp32_config, mock_logger, sample_csi_data):
        """Should skip validation when disabled."""
        esp32_config['validation_enabled'] = False
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
                with patch.object(extractor, 'validate_csi_data') as mock_validate:
                    mock_read.return_value = b"raw_data"
                    result = await extractor.extract_csi()
                    assert result == sample_csi_data
                    # validate_csi_data must be bypassed entirely when disabled
                    mock_validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_extract_max_retries_exceeded(self, esp32_config, mock_logger):
        """Should raise error after max retries exceeded."""
        esp32_config['retry_attempts'] = 2
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            mock_read.side_effect = ConnectionError("Connection failed")
            with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
                await extractor.extract_csi()
            # Exactly one read per configured attempt, no more
            assert mock_read.call_count == 2

    def test_validation_empty_amplitude(self, esp32_config, mock_logger):
        """Should raise validation error for empty amplitude."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.array([]),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Empty amplitude data"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_empty_phase(self, esp32_config, mock_logger):
        """Should raise validation error for empty phase."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.array([]),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Empty phase data"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_frequency(self, esp32_config, mock_logger):
        """Should raise validation error for invalid frequency."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=0,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid frequency"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_bandwidth(self, esp32_config, mock_logger):
        """Should raise validation error for invalid bandwidth."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=0,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_subcarriers(self, esp32_config, mock_logger):
        """Should raise validation error for invalid subcarriers."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=0,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_antennas(self, esp32_config, mock_logger):
        """Should raise validation error for invalid antennas."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=0,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_snr_too_low(self, esp32_config, mock_logger):
        """Should raise validation error for SNR too low."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=-100,  # below the accepted SNR range
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_snr_too_high(self, esp32_config, mock_logger):
        """Should raise validation error for SNR too high."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=100,  # above the accepted SNR range
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(invalid_data)

    @pytest.mark.asyncio
    async def test_streaming_with_exception(self, esp32_config, mock_logger):
        """Should handle exceptions during streaming."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()
        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.side_effect = Exception("Extraction error")
            # Start streaming and let it handle the exception
            streaming_task = asyncio.create_task(extractor.start_streaming(callback))
            await asyncio.sleep(0.1)  # Let it run briefly and hit the exception
            await streaming_task
            # Should log error and stop streaming
            assert extractor.is_streaming == False
            extractor.logger.error.assert_called()
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestESP32CSIParserComplete:
"""Complete ESP32 CSI parser tests for 100% coverage."""
@pytest.fixture
def parser(self):
"""Create ESP32 CSI parser for testing."""
return ESP32CSIParser()
def test_parse_with_value_error(self, parser):
"""Should handle ValueError during parsing."""
invalid_data = b"CSI_DATA:invalid_timestamp,3,56,2400,20,15.5"
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
parser.parse(invalid_data)
def test_parse_with_index_error(self, parser):
"""Should handle IndexError during parsing."""
invalid_data = b"CSI_DATA:1234567890" # Missing fields
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
parser.parse(invalid_data)
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestRouterCSIParserComplete:
"""Complete Router CSI parser tests for 100% coverage."""
@pytest.fixture
def parser(self):
"""Create Router CSI parser for testing."""
return RouterCSIParser()
def test_parse_atheros_format_directly(self, parser):
"""Should parse Atheros format directly."""
raw_data = b"ATHEROS_CSI:mock_data"
result = parser.parse(raw_data)
assert isinstance(result, CSIData)
assert result.metadata['source'] == 'atheros_router'

View File

@@ -0,0 +1,87 @@
import pytest
import numpy as np
from unittest.mock import Mock, patch
from src.core.csi_processor import CSIProcessor
class TestCSIProcessor:
    """Test suite for CSI processor following London School TDD principles"""

    @pytest.fixture
    def mock_csi_data(self):
        """Generate synthetic CSI data for testing"""
        # Simple raw CSI data array for testing: (antennas, subcarriers, samples)
        return np.random.uniform(0.1, 2.0, (3, 56, 100))

    @pytest.fixture
    def csi_processor(self):
        """Create CSI processor instance for testing"""
        return CSIProcessor()

    def test_process_csi_data_returns_normalized_output(self, csi_processor, mock_csi_data):
        """Test that CSI processing returns properly normalized output"""
        # Act
        result = csi_processor.process_raw_csi(mock_csi_data)
        # Assert
        assert result is not None
        assert isinstance(result, np.ndarray)
        assert result.shape == mock_csi_data.shape
        # Verify normalization - mean should be close to 0, std close to 1
        assert abs(result.mean()) < 0.1
        assert abs(result.std() - 1.0) < 0.1

    def test_process_csi_data_handles_invalid_input(self, csi_processor):
        """Test that CSI processor handles invalid input gracefully"""
        # Arrange
        invalid_data = np.array([])
        # Act & Assert
        with pytest.raises(ValueError, match="Raw CSI data cannot be empty"):
            csi_processor.process_raw_csi(invalid_data)

    def test_process_csi_data_removes_nan_values(self, csi_processor, mock_csi_data):
        """Test that CSI processor removes NaN values from input"""
        # Arrange - poison one sample; the processor must sanitize it
        mock_csi_data[0, 0, 0] = np.nan
        # Act
        result = csi_processor.process_raw_csi(mock_csi_data)
        # Assert
        assert not np.isnan(result).any()

    def test_process_csi_data_applies_temporal_filtering(self, csi_processor, mock_csi_data):
        """Test that temporal filtering is applied to CSI data"""
        # Arrange - Add noise to make filtering effect visible
        noisy_data = mock_csi_data + np.random.normal(0, 0.1, mock_csi_data.shape)
        # Act
        result = csi_processor.process_raw_csi(noisy_data)
        # Assert - Result should be normalized
        assert isinstance(result, np.ndarray)
        assert result.shape == noisy_data.shape

    def test_process_csi_data_preserves_metadata(self, csi_processor, mock_csi_data):
        """Test that metadata is preserved during processing"""
        # Act
        result = csi_processor.process_raw_csi(mock_csi_data)
        # Assert - For now, just verify processing works
        assert result is not None
        assert isinstance(result, np.ndarray)

    def test_process_csi_data_performance_requirement(self, csi_processor, mock_csi_data):
        """Test that CSI processing meets performance requirements (<10ms)"""
        import time
        # Act - use perf_counter(): it is monotonic and high-resolution,
        # whereas time.time() is wall-clock and can jump (e.g. NTP adjustments),
        # which would make this timing assertion flaky or even negative.
        start_time = time.perf_counter()
        result = csi_processor.process_raw_csi(mock_csi_data)
        processing_time = time.perf_counter() - start_time
        # Assert
        assert processing_time < 0.01  # <10ms requirement
        assert result is not None

View File

@@ -0,0 +1,479 @@
"""TDD tests for CSI processor following London School approach."""
import pytest
import numpy as np
import sys
import os
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from datetime import datetime, timezone
import importlib.util
from typing import Dict, List, Any
def _find_project_file(relative_path):
    """Locate *relative_path* by walking up from this test file.

    Replaces the previous hard-coded '/workspaces/wifi-densepose/...' paths,
    which broke the suite on any checkout outside that dev container. The
    historical absolute path is kept as a last-resort fallback so a failed
    lookup still produces an error message pointing at a concrete location.
    """
    current = os.path.dirname(os.path.abspath(__file__))
    while True:
        candidate = os.path.join(current, relative_path)
        if os.path.exists(candidate):
            return candidate
        parent = os.path.dirname(current)
        if parent == current:  # reached filesystem root without a hit
            return os.path.join('/workspaces/wifi-densepose', relative_path)
        current = parent


# Import the CSI processor module directly (bypasses package __init__ side effects)
spec = importlib.util.spec_from_file_location(
    'csi_processor',
    _find_project_file(os.path.join('src', 'core', 'csi_processor.py'))
)
csi_processor_module = importlib.util.module_from_spec(spec)

# Import CSI extractor for dependencies
csi_spec = importlib.util.spec_from_file_location(
    'csi_extractor',
    _find_project_file(os.path.join('src', 'hardware', 'csi_extractor.py'))
)
csi_module = importlib.util.module_from_spec(csi_spec)
csi_spec.loader.exec_module(csi_module)

# Make dependencies available and load the processor
csi_processor_module.CSIData = csi_module.CSIData
spec.loader.exec_module(csi_processor_module)

# Get classes from modules
CSIProcessor = csi_processor_module.CSIProcessor
CSIProcessingError = csi_processor_module.CSIProcessingError
HumanDetectionResult = csi_processor_module.HumanDetectionResult
CSIFeatures = csi_processor_module.CSIFeatures
CSIData = csi_module.CSIData
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIProcessor:
    """Test CSI processor using London School TDD.

    Covers initialization/config validation, each pipeline stage in
    isolation (preprocessing, feature extraction, human detection),
    the full async pipeline, history management, and statistics.
    """

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def processor_config(self):
        """CSI processor configuration for testing."""
        return {
            'sampling_rate': 100,
            'window_size': 256,
            'overlap': 0.5,
            'noise_threshold': -60.0,
            'human_detection_threshold': 0.7,
            'smoothing_factor': 0.8,
            'max_history_size': 1000,
            'enable_preprocessing': True,
            'enable_feature_extraction': True,
            'enable_human_detection': True
        }

    @pytest.fixture
    def csi_processor(self, processor_config, mock_logger):
        """Create CSI processor for testing."""
        return CSIProcessor(config=processor_config, logger=mock_logger)

    @pytest.fixture
    def sample_csi_data(self):
        """Sample CSI data for testing."""
        return CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56) + 1.0,  # Ensure positive amplitude
            phase=np.random.uniform(-np.pi, np.pi, (3, 56)),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={'source': 'test'}
        )

    @pytest.fixture
    def sample_features(self):
        """Sample CSI features for testing."""
        return CSIFeatures(
            amplitude_mean=np.random.rand(56),
            amplitude_variance=np.random.rand(56),
            phase_difference=np.random.rand(56),
            correlation_matrix=np.random.rand(3, 3),
            doppler_shift=np.random.rand(10),
            power_spectral_density=np.random.rand(128),
            timestamp=datetime.now(timezone.utc),
            metadata={'processing_params': {}}
        )

    # Initialization tests
    def test_should_initialize_with_valid_config(self, processor_config, mock_logger):
        """Should initialize CSI processor with valid configuration."""
        processor = CSIProcessor(config=processor_config, logger=mock_logger)
        assert processor.config == processor_config
        assert processor.logger == mock_logger
        assert processor.sampling_rate == 100
        assert processor.window_size == 256
        assert processor.overlap == 0.5
        assert processor.noise_threshold == -60.0
        assert processor.human_detection_threshold == 0.7
        assert processor.smoothing_factor == 0.8
        assert processor.max_history_size == 1000
        assert len(processor.csi_history) == 0

    def test_should_raise_error_with_invalid_config(self, mock_logger):
        """Should raise error when initialized with invalid configuration."""
        invalid_config = {'invalid': 'config'}
        with pytest.raises(ValueError, match="Missing required configuration"):
            CSIProcessor(config=invalid_config, logger=mock_logger)

    def test_should_validate_required_fields(self, mock_logger):
        """Should validate all required configuration fields."""
        required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
        base_config = {
            'sampling_rate': 100,
            'window_size': 256,
            'overlap': 0.5,
            'noise_threshold': -60.0
        }
        # Drop each required field in turn; every omission must be rejected.
        for field in required_fields:
            config = base_config.copy()
            del config[field]
            with pytest.raises(ValueError, match="Missing required configuration"):
                CSIProcessor(config=config, logger=mock_logger)

    def test_should_use_default_values(self, mock_logger):
        """Should use default values for optional parameters."""
        minimal_config = {
            'sampling_rate': 100,
            'window_size': 256,
            'overlap': 0.5,
            'noise_threshold': -60.0
        }
        processor = CSIProcessor(config=minimal_config, logger=mock_logger)
        assert processor.human_detection_threshold == 0.8  # default
        assert processor.smoothing_factor == 0.9  # default
        assert processor.max_history_size == 500  # default

    def test_should_initialize_without_logger(self, processor_config):
        """Should initialize without logger provided."""
        processor = CSIProcessor(config=processor_config)
        assert processor.logger is not None  # Should create default logger

    # Preprocessing tests
    def test_should_preprocess_csi_data_successfully(self, csi_processor, sample_csi_data):
        """Should preprocess CSI data successfully."""
        with patch.object(csi_processor, '_remove_noise') as mock_noise:
            with patch.object(csi_processor, '_apply_windowing') as mock_window:
                with patch.object(csi_processor, '_normalize_amplitude') as mock_normalize:
                    mock_noise.return_value = sample_csi_data
                    mock_window.return_value = sample_csi_data
                    mock_normalize.return_value = sample_csi_data
                    result = csi_processor.preprocess_csi_data(sample_csi_data)
                    assert result == sample_csi_data
                    # Stage order: noise removal -> windowing -> normalization
                    mock_noise.assert_called_once_with(sample_csi_data)
                    mock_window.assert_called_once()
                    mock_normalize.assert_called_once()

    def test_should_skip_preprocessing_when_disabled(self, processor_config, mock_logger, sample_csi_data):
        """Should skip preprocessing when disabled."""
        processor_config['enable_preprocessing'] = False
        processor = CSIProcessor(config=processor_config, logger=mock_logger)
        result = processor.preprocess_csi_data(sample_csi_data)
        assert result == sample_csi_data

    def test_should_handle_preprocessing_error(self, csi_processor, sample_csi_data):
        """Should handle preprocessing errors gracefully."""
        with patch.object(csi_processor, '_remove_noise') as mock_noise:
            mock_noise.side_effect = Exception("Preprocessing error")
            with pytest.raises(CSIProcessingError, match="Failed to preprocess CSI data"):
                csi_processor.preprocess_csi_data(sample_csi_data)

    # Feature extraction tests
    def test_should_extract_features_successfully(self, csi_processor, sample_csi_data, sample_features):
        """Should extract features from CSI data successfully."""
        with patch.object(csi_processor, '_extract_amplitude_features') as mock_amp:
            with patch.object(csi_processor, '_extract_phase_features') as mock_phase:
                with patch.object(csi_processor, '_extract_correlation_features') as mock_corr:
                    with patch.object(csi_processor, '_extract_doppler_features') as mock_doppler:
                        mock_amp.return_value = (sample_features.amplitude_mean, sample_features.amplitude_variance)
                        mock_phase.return_value = sample_features.phase_difference
                        mock_corr.return_value = sample_features.correlation_matrix
                        mock_doppler.return_value = (sample_features.doppler_shift, sample_features.power_spectral_density)
                        result = csi_processor.extract_features(sample_csi_data)
                        assert isinstance(result, CSIFeatures)
                        assert np.array_equal(result.amplitude_mean, sample_features.amplitude_mean)
                        assert np.array_equal(result.amplitude_variance, sample_features.amplitude_variance)
                        mock_amp.assert_called_once()
                        mock_phase.assert_called_once()
                        mock_corr.assert_called_once()
                        mock_doppler.assert_called_once()

    def test_should_skip_feature_extraction_when_disabled(self, processor_config, mock_logger, sample_csi_data):
        """Should skip feature extraction when disabled."""
        processor_config['enable_feature_extraction'] = False
        processor = CSIProcessor(config=processor_config, logger=mock_logger)
        result = processor.extract_features(sample_csi_data)
        assert result is None

    def test_should_handle_feature_extraction_error(self, csi_processor, sample_csi_data):
        """Should handle feature extraction errors gracefully."""
        with patch.object(csi_processor, '_extract_amplitude_features') as mock_amp:
            mock_amp.side_effect = Exception("Feature extraction error")
            with pytest.raises(CSIProcessingError, match="Failed to extract features"):
                csi_processor.extract_features(sample_csi_data)

    # Human detection tests
    def test_should_detect_human_presence_successfully(self, csi_processor, sample_features):
        """Should detect human presence successfully."""
        with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
            with patch.object(csi_processor, '_calculate_detection_confidence') as mock_confidence:
                with patch.object(csi_processor, '_apply_temporal_smoothing') as mock_smooth:
                    # Smoothed confidence (0.88) above threshold (0.7) -> detected
                    mock_motion.return_value = 0.9
                    mock_confidence.return_value = 0.85
                    mock_smooth.return_value = 0.88
                    result = csi_processor.detect_human_presence(sample_features)
                    assert isinstance(result, HumanDetectionResult)
                    assert result.human_detected == True
                    assert result.confidence == 0.88
                    assert result.motion_score == 0.9
                    mock_motion.assert_called_once()
                    mock_confidence.assert_called_once()
                    mock_smooth.assert_called_once()

    def test_should_detect_no_human_presence(self, csi_processor, sample_features):
        """Should detect no human presence when confidence is low."""
        with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
            with patch.object(csi_processor, '_calculate_detection_confidence') as mock_confidence:
                with patch.object(csi_processor, '_apply_temporal_smoothing') as mock_smooth:
                    # Smoothed confidence (0.25) below threshold (0.7) -> not detected
                    mock_motion.return_value = 0.3
                    mock_confidence.return_value = 0.2
                    mock_smooth.return_value = 0.25
                    result = csi_processor.detect_human_presence(sample_features)
                    assert result.human_detected == False
                    assert result.confidence == 0.25
                    assert result.motion_score == 0.3

    def test_should_skip_human_detection_when_disabled(self, processor_config, mock_logger, sample_features):
        """Should skip human detection when disabled."""
        processor_config['enable_human_detection'] = False
        processor = CSIProcessor(config=processor_config, logger=mock_logger)
        result = processor.detect_human_presence(sample_features)
        assert result is None

    def test_should_handle_human_detection_error(self, csi_processor, sample_features):
        """Should handle human detection errors gracefully."""
        with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
            mock_motion.side_effect = Exception("Detection error")
            with pytest.raises(CSIProcessingError, match="Failed to detect human presence"):
                csi_processor.detect_human_presence(sample_features)

    # Processing pipeline tests
    @pytest.mark.asyncio
    async def test_should_process_csi_data_pipeline_successfully(self, csi_processor, sample_csi_data, sample_features):
        """Should process CSI data through full pipeline successfully."""
        expected_detection = HumanDetectionResult(
            human_detected=True,
            confidence=0.85,
            motion_score=0.9,
            timestamp=datetime.now(timezone.utc),
            features=sample_features,
            metadata={}
        )
        with patch.object(csi_processor, 'preprocess_csi_data', return_value=sample_csi_data) as mock_preprocess:
            with patch.object(csi_processor, 'extract_features', return_value=sample_features) as mock_features:
                with patch.object(csi_processor, 'detect_human_presence', return_value=expected_detection) as mock_detect:
                    result = await csi_processor.process_csi_data(sample_csi_data)
                    assert result == expected_detection
                    # Pipeline order: preprocess -> extract -> detect
                    mock_preprocess.assert_called_once_with(sample_csi_data)
                    mock_features.assert_called_once_with(sample_csi_data)
                    mock_detect.assert_called_once_with(sample_features)

    @pytest.mark.asyncio
    async def test_should_handle_pipeline_processing_error(self, csi_processor, sample_csi_data):
        """Should handle pipeline processing errors gracefully."""
        with patch.object(csi_processor, 'preprocess_csi_data') as mock_preprocess:
            mock_preprocess.side_effect = CSIProcessingError("Pipeline error")
            with pytest.raises(CSIProcessingError):
                await csi_processor.process_csi_data(sample_csi_data)

    # History management tests
    def test_should_add_csi_data_to_history(self, csi_processor, sample_csi_data):
        """Should add CSI data to history successfully."""
        csi_processor.add_to_history(sample_csi_data)
        assert len(csi_processor.csi_history) == 1
        assert csi_processor.csi_history[0] == sample_csi_data

    def test_should_maintain_history_size_limit(self, processor_config, mock_logger):
        """Should maintain history size within limits."""
        processor_config['max_history_size'] = 2
        processor = CSIProcessor(config=processor_config, logger=mock_logger)
        # Add 3 items to history of size 2
        for i in range(3):
            csi_data = CSIData(
                timestamp=datetime.now(timezone.utc),
                amplitude=np.random.rand(3, 56),
                phase=np.random.rand(3, 56),
                frequency=2.4e9,
                bandwidth=20e6,
                num_subcarriers=56,
                num_antennas=3,
                snr=15.5,
                metadata={'index': i}
            )
            processor.add_to_history(csi_data)
        assert len(processor.csi_history) == 2
        assert processor.csi_history[0].metadata['index'] == 1  # First item removed
        assert processor.csi_history[1].metadata['index'] == 2

    def test_should_clear_history(self, csi_processor, sample_csi_data):
        """Should clear history successfully."""
        csi_processor.add_to_history(sample_csi_data)
        assert len(csi_processor.csi_history) > 0
        csi_processor.clear_history()
        assert len(csi_processor.csi_history) == 0

    def test_should_get_recent_history(self, csi_processor):
        """Should get recent history entries."""
        # Add 5 items to history
        for i in range(5):
            csi_data = CSIData(
                timestamp=datetime.now(timezone.utc),
                amplitude=np.random.rand(3, 56),
                phase=np.random.rand(3, 56),
                frequency=2.4e9,
                bandwidth=20e6,
                num_subcarriers=56,
                num_antennas=3,
                snr=15.5,
                metadata={'index': i}
            )
            csi_processor.add_to_history(csi_data)
        recent = csi_processor.get_recent_history(3)
        assert len(recent) == 3
        # The three newest entries in chronological order (oldest of the
        # window first, most recent last).
        assert recent[0].metadata['index'] == 2
        assert recent[1].metadata['index'] == 3
        assert recent[2].metadata['index'] == 4

    # Statistics and monitoring tests
    def test_should_get_processing_statistics(self, csi_processor):
        """Should get processing statistics."""
        # Simulate some processing
        csi_processor._total_processed = 100
        csi_processor._processing_errors = 5
        csi_processor._human_detections = 25
        stats = csi_processor.get_processing_statistics()
        assert isinstance(stats, dict)
        assert stats['total_processed'] == 100
        assert stats['processing_errors'] == 5
        assert stats['human_detections'] == 25
        # Rates are derived from the raw counters: errors/total, detections/total
        assert stats['error_rate'] == 0.05
        assert stats['detection_rate'] == 0.25

    def test_should_reset_statistics(self, csi_processor):
        """Should reset processing statistics."""
        csi_processor._total_processed = 100
        csi_processor._processing_errors = 5
        csi_processor._human_detections = 25
        csi_processor.reset_statistics()
        assert csi_processor._total_processed == 0
        assert csi_processor._processing_errors == 0
        assert csi_processor._human_detections == 0
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIFeatures:
    """Tests for the CSIFeatures value object."""

    def test_should_create_csi_features(self):
        """CSIFeatures should retain every field it is constructed with."""
        created_at = datetime.now(timezone.utc)
        features = CSIFeatures(
            amplitude_mean=np.random.rand(56),
            amplitude_variance=np.random.rand(56),
            phase_difference=np.random.rand(56),
            correlation_matrix=np.random.rand(3, 3),
            doppler_shift=np.random.rand(10),
            power_spectral_density=np.random.rand(128),
            timestamp=created_at,
            metadata={'test': 'data'},
        )

        # Each array attribute keeps the shape it was given.
        expected_shapes = {
            'amplitude_mean': (56,),
            'amplitude_variance': (56,),
            'phase_difference': (56,),
            'correlation_matrix': (3, 3),
            'doppler_shift': (10,),
            'power_spectral_density': (128,),
        }
        for name, shape in expected_shapes.items():
            assert getattr(features, name).shape == shape
        assert isinstance(features.timestamp, datetime)
        assert features.metadata['test'] == 'data'
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestHumanDetectionResult:
    """Test human detection result data structure."""

    @pytest.fixture
    def sample_features(self):
        """Sample features for testing."""
        return CSIFeatures(
            amplitude_mean=np.random.rand(56),
            amplitude_variance=np.random.rand(56),
            phase_difference=np.random.rand(56),
            correlation_matrix=np.random.rand(3, 3),
            doppler_shift=np.random.rand(10),
            power_spectral_density=np.random.rand(128),
            timestamp=datetime.now(timezone.utc),
            metadata={}
        )

    def test_should_create_detection_result(self, sample_features):
        """Should create human detection result successfully."""
        result = HumanDetectionResult(
            human_detected=True,
            confidence=0.85,
            motion_score=0.92,
            timestamp=datetime.now(timezone.utc),
            features=sample_features,
            metadata={'test': 'data'}
        )
        # Fix: compare booleans with `is`, not `==` (ruff E712 / PEP 8).
        assert result.human_detected is True
        assert result.confidence == 0.85
        assert result.motion_score == 0.92
        assert isinstance(result.timestamp, datetime)
        assert result.features == sample_features
        assert result.metadata['test'] == 'data'

View File

@@ -0,0 +1,599 @@
"""Standalone tests for CSI extractor module."""
import pytest
import numpy as np
import sys
import os
from unittest.mock import Mock, patch, AsyncMock
import asyncio
from datetime import datetime, timezone
import importlib.util
# Import the module directly to avoid circular imports
spec = importlib.util.spec_from_file_location(
'csi_extractor',
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
)
csi_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(csi_module)
# Get classes from the module
CSIExtractor = csi_module.CSIExtractor
CSIParseError = csi_module.CSIParseError
CSIData = csi_module.CSIData
ESP32CSIParser = csi_module.ESP32CSIParser
RouterCSIParser = csi_module.RouterCSIParser
CSIValidationError = csi_module.CSIValidationError
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIExtractorStandalone:
    """Standalone tests for CSI extractor with 100% coverage.

    Exercises every code path of CSIExtractor: construction/validation,
    connection lifecycle, extraction (with retries), CSI data validation,
    streaming, and the placeholder hardware hooks. Hardware I/O is always
    patched out, so no real device is needed.
    """

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def esp32_config(self):
        """ESP32 configuration for testing."""
        return {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0,
            'validation_enabled': True,
            'retry_attempts': 3
        }

    @pytest.fixture
    def router_config(self):
        """Router configuration for testing (validation disabled, single attempt)."""
        return {
            'hardware_type': 'router',
            'sampling_rate': 50,
            'buffer_size': 512,
            'timeout': 10.0,
            'validation_enabled': False,
            'retry_attempts': 1
        }

    @pytest.fixture
    def sample_csi_data(self):
        """Sample CSI data for testing (3 antennas x 56 subcarriers)."""
        return CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={'source': 'esp32', 'channel': 6}
        )

    # Test all initialization paths
    def test_init_esp32_config(self, esp32_config, mock_logger):
        """Should initialize with ESP32 configuration."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        assert extractor.config == esp32_config
        assert extractor.logger == mock_logger
        assert extractor.is_connected == False
        assert extractor.hardware_type == 'esp32'
        # hardware_type selects the parser implementation.
        assert isinstance(extractor.parser, ESP32CSIParser)

    def test_init_router_config(self, router_config, mock_logger):
        """Should initialize with router configuration."""
        extractor = CSIExtractor(config=router_config, logger=mock_logger)
        assert isinstance(extractor.parser, RouterCSIParser)
        assert extractor.hardware_type == 'router'

    def test_init_unsupported_hardware(self, mock_logger):
        """Should raise error for unsupported hardware type."""
        invalid_config = {
            'hardware_type': 'unsupported',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_init_without_logger(self, esp32_config):
        """Should initialize without logger."""
        extractor = CSIExtractor(config=esp32_config)
        assert extractor.logger is not None  # Should create default logger

    # Test all validation paths
    def test_validation_missing_fields(self, mock_logger):
        """Should validate missing required fields."""
        # Remove each required key in turn and expect construction to fail.
        for missing_field in ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']:
            config = {
                'hardware_type': 'esp32',
                'sampling_rate': 100,
                'buffer_size': 1024,
                'timeout': 5.0
            }
            del config[missing_field]
            with pytest.raises(ValueError, match="Missing required configuration"):
                CSIExtractor(config=config, logger=mock_logger)

    def test_validation_negative_sampling_rate(self, mock_logger):
        """Should validate sampling_rate is positive."""
        config = {
            'hardware_type': 'esp32',
            'sampling_rate': -1,
            'buffer_size': 1024,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="sampling_rate must be positive"):
            CSIExtractor(config=config, logger=mock_logger)

    def test_validation_zero_buffer_size(self, mock_logger):
        """Should validate buffer_size is positive."""
        config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 0,
            'timeout': 5.0
        }
        with pytest.raises(ValueError, match="buffer_size must be positive"):
            CSIExtractor(config=config, logger=mock_logger)

    def test_validation_negative_timeout(self, mock_logger):
        """Should validate timeout is positive."""
        config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': -1.0
        }
        with pytest.raises(ValueError, match="timeout must be positive"):
            CSIExtractor(config=config, logger=mock_logger)

    # Test connection management
    @pytest.mark.asyncio
    async def test_connect_success(self, esp32_config, mock_logger):
        """Should connect successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_conn:
            mock_conn.return_value = True
            result = await extractor.connect()
            assert result == True
            assert extractor.is_connected == True

    @pytest.mark.asyncio
    async def test_connect_failure(self, esp32_config, mock_logger):
        """Should handle connection failure."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_conn:
            mock_conn.side_effect = ConnectionError("Failed")
            # connect() swallows the ConnectionError and reports False.
            result = await extractor.connect()
            assert result == False
            assert extractor.is_connected == False

    @pytest.mark.asyncio
    async def test_disconnect_when_connected(self, esp32_config, mock_logger):
        """Should disconnect when connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
            await extractor.disconnect()
            assert extractor.is_connected == False
            mock_close.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
        """Should not disconnect when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False
        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
            await extractor.disconnect()
            mock_close.assert_not_called()

    # Test extraction
    @pytest.mark.asyncio
    async def test_extract_not_connected(self, esp32_config, mock_logger):
        """Should raise error when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False
        with pytest.raises(CSIParseError, match="Not connected to hardware"):
            await extractor.extract_csi()

    @pytest.mark.asyncio
    async def test_extract_success_with_validation(self, esp32_config, mock_logger, sample_csi_data):
        """Should extract successfully with validation."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        # Pipeline under test: _read_raw_data -> parser.parse -> validate_csi_data.
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
                with patch.object(extractor, 'validate_csi_data', return_value=True) as mock_validate:
                    mock_read.return_value = b"raw_data"
                    result = await extractor.extract_csi()
                    assert result == sample_csi_data
                    mock_validate.assert_called_once()

    @pytest.mark.asyncio
    async def test_extract_success_without_validation(self, esp32_config, mock_logger, sample_csi_data):
        """Should extract successfully without validation."""
        esp32_config['validation_enabled'] = False
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
                with patch.object(extractor, 'validate_csi_data') as mock_validate:
                    mock_read.return_value = b"raw_data"
                    result = await extractor.extract_csi()
                    assert result == sample_csi_data
                    # validation_enabled=False must skip the validation step.
                    mock_validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_extract_retry_success(self, esp32_config, mock_logger, sample_csi_data):
        """Should retry and succeed."""
        esp32_config['retry_attempts'] = 3
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
                # Fail first two attempts, succeed on third
                mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
                result = await extractor.extract_csi()
                assert result == sample_csi_data
                assert mock_read.call_count == 3

    @pytest.mark.asyncio
    async def test_extract_retry_failure(self, esp32_config, mock_logger):
        """Should fail after max retries."""
        esp32_config['retry_attempts'] = 2
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            mock_read.side_effect = ConnectionError("Failed")
            with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
                await extractor.extract_csi()

    # Test validation
    def test_validate_success(self, esp32_config, mock_logger, sample_csi_data):
        """Should validate successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        result = extractor.validate_csi_data(sample_csi_data)
        assert result == True

    def test_validate_empty_amplitude(self, esp32_config, mock_logger):
        """Should reject empty amplitude."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.array([]),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Empty amplitude data"):
            extractor.validate_csi_data(data)

    def test_validate_empty_phase(self, esp32_config, mock_logger):
        """Should reject empty phase."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.array([]),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Empty phase data"):
            extractor.validate_csi_data(data)

    def test_validate_invalid_frequency(self, esp32_config, mock_logger):
        """Should reject invalid frequency."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=0,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid frequency"):
            extractor.validate_csi_data(data)

    def test_validate_invalid_bandwidth(self, esp32_config, mock_logger):
        """Should reject invalid bandwidth."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=0,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
            extractor.validate_csi_data(data)

    def test_validate_invalid_subcarriers(self, esp32_config, mock_logger):
        """Should reject invalid subcarriers."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=0,
            num_antennas=3,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
            extractor.validate_csi_data(data)

    def test_validate_invalid_antennas(self, esp32_config, mock_logger):
        """Should reject invalid antennas."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=0,
            snr=15.5,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
            extractor.validate_csi_data(data)

    def test_validate_snr_too_low(self, esp32_config, mock_logger):
        """Should reject SNR too low."""
        # NOTE(review): -100 / 100 dB appear to bracket the accepted SNR range;
        # the exact bounds live in validate_csi_data — confirm there.
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=-100,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(data)

    def test_validate_snr_too_high(self, esp32_config, mock_logger):
        """Should reject SNR too high."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        data = CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56),
            phase=np.random.rand(3, 56),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=100,
            metadata={}
        )
        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(data)

    # Test streaming
    @pytest.mark.asyncio
    async def test_streaming_success(self, esp32_config, mock_logger, sample_csi_data):
        """Should stream successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()
        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.return_value = sample_csi_data
            # Start streaming task
            task = asyncio.create_task(extractor.start_streaming(callback))
            await asyncio.sleep(0.1)  # Let it run briefly
            extractor.stop_streaming()
            await task
            callback.assert_called()

    @pytest.mark.asyncio
    async def test_streaming_exception(self, esp32_config, mock_logger):
        """Should handle streaming exceptions."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()
        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.side_effect = Exception("Test error")
            # Start streaming and let it handle exception
            task = asyncio.create_task(extractor.start_streaming(callback))
            await task  # This should complete due to exception
            assert extractor.is_streaming == False

    def test_stop_streaming(self, esp32_config, mock_logger):
        """Should stop streaming."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_streaming = True
        extractor.stop_streaming()
        assert extractor.is_streaming == False

    # Test placeholder implementations for 100% coverage
    @pytest.mark.asyncio
    async def test_establish_hardware_connection_placeholder(self, esp32_config, mock_logger):
        """Should test placeholder hardware connection."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        result = await extractor._establish_hardware_connection()
        assert result == True

    @pytest.mark.asyncio
    async def test_close_hardware_connection_placeholder(self, esp32_config, mock_logger):
        """Should test placeholder hardware disconnection."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        # Should not raise any exception
        await extractor._close_hardware_connection()

    @pytest.mark.asyncio
    async def test_read_raw_data_placeholder(self, esp32_config, mock_logger):
        """Should test placeholder raw data reading."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        result = await extractor._read_raw_data()
        # Pinned to the canned frame emitted by the placeholder implementation.
        assert result == b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
@pytest.mark.unit
@pytest.mark.tdd
class TestESP32CSIParserStandalone:
    """Standalone tests for ESP32 CSI parser."""

    @pytest.fixture
    def parser(self):
        """Fresh parser instance per test."""
        return ESP32CSIParser()

    def test_parse_valid_data(self, parser):
        """A well-formed CSI_DATA line should parse into a CSIData record."""
        payload = b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
        parsed = parser.parse(payload)
        assert isinstance(parsed, CSIData)
        # MHz fields in the frame become Hz on the parsed record.
        expected_fields = {
            'num_antennas': 3,
            'num_subcarriers': 56,
            'frequency': 2400000000,
            'bandwidth': 20000000,
            'snr': 15.5,
        }
        for field, value in expected_fields.items():
            assert getattr(parsed, field) == value

    def test_parse_empty_data(self, parser):
        """Zero-length input is rejected outright."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

    def test_parse_invalid_format(self, parser):
        """Input without the CSI_DATA prefix is rejected."""
        with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
            parser.parse(b"INVALID_DATA")

    def test_parse_value_error(self, parser):
        """Non-numeric fields surface as CSIParseError, not a raw ValueError."""
        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(b"CSI_DATA:invalid_number,3,56,2400,20,15.5")

    def test_parse_index_error(self, parser):
        """Truncated payloads surface as CSIParseError, not a raw IndexError."""
        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(b"CSI_DATA:1234567890")  # missing fields
@pytest.mark.unit
@pytest.mark.tdd
class TestRouterCSIParserStandalone:
    """Standalone tests for Router CSI parser."""

    @pytest.fixture
    def parser(self):
        """Fresh parser instance per test."""
        return RouterCSIParser()

    def test_parse_empty_data(self, parser):
        """Zero-length input is rejected outright."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

    def test_parse_atheros_format(self, parser):
        """An ATHEROS_CSI-prefixed frame parses and is tagged with its source."""
        parsed = parser.parse(b"ATHEROS_CSI:mock_data")
        assert isinstance(parsed, CSIData)
        assert parsed.metadata['source'] == 'atheros_router'

    def test_parse_unknown_format(self, parser):
        """Frames with an unrecognized prefix are rejected."""
        with pytest.raises(CSIParseError, match="Unknown router CSI format"):
            parser.parse(b"UNKNOWN_FORMAT:data")

View File

@@ -0,0 +1,367 @@
import pytest
import torch
import torch.nn as nn
import numpy as np
from unittest.mock import Mock, patch
from src.models.densepose_head import DensePoseHead, DensePoseError
class TestDensePoseHead:
"""Test suite for DensePose Head following London School TDD principles"""
@pytest.fixture
def mock_config(self):
"""Configuration for DensePose head"""
return {
'input_channels': 256,
'num_body_parts': 24,
'num_uv_coordinates': 2,
'hidden_channels': [128, 64],
'kernel_size': 3,
'padding': 1,
'dropout_rate': 0.1,
'use_deformable_conv': False,
'use_fpn': True,
'fpn_levels': [2, 3, 4, 5],
'output_stride': 4
}
@pytest.fixture
def densepose_head(self, mock_config):
"""Create DensePose head instance for testing"""
return DensePoseHead(mock_config)
@pytest.fixture
def mock_feature_input(self):
"""Generate mock feature input tensor"""
batch_size = 2
channels = 256
height = 56
width = 56
return torch.randn(batch_size, channels, height, width)
@pytest.fixture
def mock_target_masks(self):
"""Generate mock target segmentation masks"""
batch_size = 2
num_parts = 24
height = 224
width = 224
return torch.randint(0, num_parts + 1, (batch_size, height, width))
@pytest.fixture
def mock_target_uv(self):
"""Generate mock target UV coordinates"""
batch_size = 2
num_coords = 2
height = 224
width = 224
return torch.randn(batch_size, num_coords, height, width)
def test_head_initialization_creates_correct_architecture(self, mock_config):
"""Test that DensePose head initializes with correct architecture"""
# Act
head = DensePoseHead(mock_config)
# Assert
assert head is not None
assert isinstance(head, nn.Module)
assert head.input_channels == mock_config['input_channels']
assert head.num_body_parts == mock_config['num_body_parts']
assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
assert head.use_fpn == mock_config['use_fpn']
assert hasattr(head, 'segmentation_head')
assert hasattr(head, 'uv_regression_head')
if mock_config['use_fpn']:
assert hasattr(head, 'fpn')
def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
"""Test that forward pass produces correctly formatted output"""
# Act
output = densepose_head(mock_feature_input)
# Assert
assert output is not None
assert isinstance(output, dict)
assert 'segmentation' in output
assert 'uv_coordinates' in output
seg_output = output['segmentation']
uv_output = output['uv_coordinates']
assert isinstance(seg_output, torch.Tensor)
assert isinstance(uv_output, torch.Tensor)
assert seg_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
assert uv_output.shape[0] == mock_feature_input.shape[0] # Batch size preserved
def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
"""Test that segmentation head produces correct output shape"""
# Act
output = densepose_head(mock_feature_input)
seg_output = output['segmentation']
# Assert
expected_channels = densepose_head.num_body_parts + 1 # +1 for background
assert seg_output.shape[1] == expected_channels
assert seg_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
assert seg_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
"""Test that UV regression head produces correct output shape"""
# Act
output = densepose_head(mock_feature_input)
uv_output = output['uv_coordinates']
# Assert
assert uv_output.shape[1] == densepose_head.num_uv_coordinates
assert uv_output.shape[2] >= mock_feature_input.shape[2] # Height upsampled
assert uv_output.shape[3] >= mock_feature_input.shape[3] # Width upsampled
def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
"""Test that compute_segmentation_loss measures pixel classification accuracy"""
# Arrange
output = densepose_head(mock_feature_input)
seg_logits = output['segmentation']
# Resize target to match output
target_resized = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=seg_logits.shape[2:],
mode='nearest'
).squeeze(1).long()
# Act
loss = densepose_head.compute_segmentation_loss(seg_logits, target_resized)
# Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
"""Test that compute_uv_loss measures UV coordinate regression accuracy"""
# Arrange
output = densepose_head(mock_feature_input)
uv_pred = output['uv_coordinates']
# Resize target to match output
target_resized = torch.nn.functional.interpolate(
mock_target_uv,
size=uv_pred.shape[2:],
mode='bilinear',
align_corners=False
)
# Act
loss = densepose_head.compute_uv_loss(uv_pred, target_resized)
# Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
"""Test that compute_total_loss combines segmentation and UV losses"""
# Arrange
output = densepose_head(mock_feature_input)
# Resize targets to match outputs
seg_target = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=output['segmentation'].shape[2:],
mode='nearest'
).squeeze(1).long()
uv_target = torch.nn.functional.interpolate(
mock_target_uv,
size=output['uv_coordinates'].shape[2:],
mode='bilinear',
align_corners=False
)
# Act
total_loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
seg_loss = densepose_head.compute_segmentation_loss(output['segmentation'], seg_target)
uv_loss = densepose_head.compute_uv_loss(output['uv_coordinates'], uv_target)
# Assert
assert total_loss is not None
assert isinstance(total_loss, torch.Tensor)
assert total_loss.item() > 0
# Total loss should be combination of individual losses
expected_total = seg_loss + uv_loss
assert torch.allclose(total_loss, expected_total, atol=1e-6)
def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
"""Test that FPN integration enhances multi-scale feature processing"""
# Arrange
config_with_fpn = mock_config.copy()
config_with_fpn['use_fpn'] = True
config_without_fpn = mock_config.copy()
config_without_fpn['use_fpn'] = False
head_with_fpn = DensePoseHead(config_with_fpn)
head_without_fpn = DensePoseHead(config_without_fpn)
# Act
output_with_fpn = head_with_fpn(mock_feature_input)
output_without_fpn = head_without_fpn(mock_feature_input)
# Assert
assert output_with_fpn['segmentation'].shape == output_without_fpn['segmentation'].shape
assert output_with_fpn['uv_coordinates'].shape == output_without_fpn['uv_coordinates'].shape
# Outputs should be different due to FPN
assert not torch.allclose(output_with_fpn['segmentation'], output_without_fpn['segmentation'], atol=1e-6)
def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
"""Test that get_prediction_confidence provides uncertainty estimates"""
# Arrange
output = densepose_head(mock_feature_input)
# Act
confidence = densepose_head.get_prediction_confidence(output)
# Assert
assert confidence is not None
assert isinstance(confidence, dict)
assert 'segmentation_confidence' in confidence
assert 'uv_confidence' in confidence
seg_conf = confidence['segmentation_confidence']
uv_conf = confidence['uv_confidence']
assert isinstance(seg_conf, torch.Tensor)
assert isinstance(uv_conf, torch.Tensor)
assert seg_conf.shape[0] == mock_feature_input.shape[0]
assert uv_conf.shape[0] == mock_feature_input.shape[0]
def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
"""Test that post_process_predictions formats output correctly"""
# Arrange
raw_output = densepose_head(mock_feature_input)
# Act
processed = densepose_head.post_process_predictions(raw_output)
# Assert
assert processed is not None
assert isinstance(processed, dict)
assert 'body_parts' in processed
assert 'uv_coordinates' in processed
assert 'confidence_scores' in processed
def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
"""Test that training mode enables dropout for regularization"""
# Arrange
densepose_head.train()
# Act
output1 = densepose_head(mock_feature_input)
output2 = densepose_head(mock_feature_input)
# Assert - outputs should be different due to dropout
assert not torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
assert not torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
"""Test that evaluation mode disables dropout for consistent inference"""
# Arrange
densepose_head.eval()
# Act
output1 = densepose_head(mock_feature_input)
output2 = densepose_head(mock_feature_input)
# Assert - outputs should be identical in eval mode
assert torch.allclose(output1['segmentation'], output2['segmentation'], atol=1e-6)
assert torch.allclose(output1['uv_coordinates'], output2['uv_coordinates'], atol=1e-6)
def test_head_validates_input_dimensions(self, densepose_head):
"""Test that head validates input dimensions"""
# Arrange
invalid_input = torch.randn(2, 128, 56, 56) # Wrong number of channels
# Act & Assert
with pytest.raises(DensePoseError):
densepose_head(invalid_input)
def test_head_handles_different_input_sizes(self, densepose_head):
"""Test that head handles different input sizes"""
# Arrange
small_input = torch.randn(1, 256, 28, 28)
large_input = torch.randn(1, 256, 112, 112)
# Act
small_output = densepose_head(small_input)
large_output = densepose_head(large_input)
# Assert
assert small_output['segmentation'].shape[2:] != large_output['segmentation'].shape[2:]
assert small_output['uv_coordinates'].shape[2:] != large_output['uv_coordinates'].shape[2:]
def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
"""Test that head supports gradient computation for training"""
# Arrange
densepose_head.train()
optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)
output = densepose_head(mock_feature_input)
# Resize targets
seg_target = torch.nn.functional.interpolate(
mock_target_masks.float().unsqueeze(1),
size=output['segmentation'].shape[2:],
mode='nearest'
).squeeze(1).long()
uv_target = torch.nn.functional.interpolate(
mock_target_uv,
size=output['uv_coordinates'].shape[2:],
mode='bilinear',
align_corners=False
)
# Act
loss = densepose_head.compute_total_loss(output, seg_target, uv_target)
optimizer.zero_grad()
loss.backward()
# Assert
for param in densepose_head.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_head_configuration_validation(self):
    """Constructing a head from an invalid config must raise ValueError."""
    bad_config = {
        'input_channels': 0,    # invalid: must be positive
        'num_body_parts': -1,   # invalid: must be positive
        'num_uv_coordinates': 2,
    }
    with pytest.raises(ValueError):
        DensePoseHead(bad_config)
def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
    """Weights round-tripped through state_dict must reproduce the outputs."""
    reference = densepose_head(mock_feature_input)
    # Rebuild a fresh head from the same config and restore the saved weights.
    restored_head = DensePoseHead(densepose_head.config)
    restored_head.load_state_dict(densepose_head.state_dict())
    restored = restored_head(mock_feature_input)
    for key in ('segmentation', 'uv_coordinates'):
        assert torch.allclose(reference[key], restored[key], atol=1e-6)

View File

@@ -0,0 +1,293 @@
import pytest
import torch
import torch.nn as nn
import numpy as np
from unittest.mock import Mock, patch
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
class TestModalityTranslationNetwork:
"""Test suite for Modality Translation Network following London School TDD principles"""
# Fixtures supply the network configuration plus synthetic CSI input and
# target-feature tensors. Shapes mirror 3-antenna CSI: 6 channels
# (real + imaginary per antenna) over 56 subcarriers x 100 time samples.
@pytest.fixture
def mock_config(self):
"""Configuration for modality translation network"""
return {
'input_channels': 6, # Real and imaginary parts for 3 antennas
'hidden_channels': [64, 128, 256],
'output_channels': 256,
'kernel_size': 3,
'stride': 1,
'padding': 1,
'dropout_rate': 0.1,
'activation': 'relu',
'normalization': 'batch',
'use_attention': True,
'attention_heads': 8
}
@pytest.fixture
def translation_network(self, mock_config):
"""Create modality translation network instance for testing"""
return ModalityTranslationNetwork(mock_config)
@pytest.fixture
def mock_csi_input(self):
"""Generate mock CSI input tensor"""
batch_size = 4
channels = 6 # Real and imaginary parts for 3 antennas
height = 56 # Number of subcarriers
width = 100 # Time samples
return torch.randn(batch_size, channels, height, width)
@pytest.fixture
def mock_target_features(self):
"""Generate mock target feature tensor for training"""
batch_size = 4
feature_dim = 256
spatial_height = 56
spatial_width = 100
return torch.randn(batch_size, feature_dim, spatial_height, spatial_width)
def test_network_initialization_creates_correct_architecture(self, mock_config):
"""Test that modality translation network initializes with correct architecture"""
# Act
network = ModalityTranslationNetwork(mock_config)
# Assert
assert network is not None
assert isinstance(network, nn.Module)
assert network.input_channels == mock_config['input_channels']
assert network.output_channels == mock_config['output_channels']
assert network.use_attention == mock_config['use_attention']
assert hasattr(network, 'encoder')
assert hasattr(network, 'decoder')
if mock_config['use_attention']:
assert hasattr(network, 'attention')
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
"""Test that forward pass produces correctly shaped output"""
# Act
output = translation_network(mock_csi_input)
# Assert
assert output is not None
assert isinstance(output, torch.Tensor)
assert output.shape[0] == mock_csi_input.shape[0] # Batch size preserved
assert output.shape[1] == translation_network.output_channels # Correct output channels
assert output.shape[2] == mock_csi_input.shape[2] # Spatial height preserved
assert output.shape[3] == mock_csi_input.shape[3] # Spatial width preserved
def test_forward_pass_handles_different_input_sizes(self, translation_network):
"""Test that forward pass handles different input sizes"""
# Arrange
small_input = torch.randn(2, 6, 28, 50)
large_input = torch.randn(8, 6, 112, 200)
# Act
small_output = translation_network(small_input)
large_output = translation_network(large_input)
# Assert
assert small_output.shape == (2, 256, 28, 50)
assert large_output.shape == (8, 256, 112, 200)
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
"""Test that encoder extracts hierarchical features"""
# Act
features = translation_network.encode(mock_csi_input)
# Assert
assert features is not None
assert isinstance(features, list)
assert len(features) == len(translation_network.encoder)
# Check feature map sizes decrease with depth
for i in range(1, len(features)):
assert features[i].shape[2] <= features[i-1].shape[2] # Height decreases or stays same
assert features[i].shape[3] <= features[i-1].shape[3] # Width decreases or stays same
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
"""Test that decoder reconstructs target feature representation"""
# Arrange
encoded_features = translation_network.encode(mock_csi_input)
# Act
decoded_output = translation_network.decode(encoded_features)
# Assert
assert decoded_output is not None
assert isinstance(decoded_output, torch.Tensor)
assert decoded_output.shape[1] == translation_network.output_channels
assert decoded_output.shape[2:] == mock_csi_input.shape[2:]
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
"""Test that attention mechanism enhances feature representation"""
# Arrange
config_with_attention = mock_config.copy()
config_with_attention['use_attention'] = True
config_without_attention = mock_config.copy()
config_without_attention['use_attention'] = False
network_with_attention = ModalityTranslationNetwork(config_with_attention)
network_without_attention = ModalityTranslationNetwork(config_without_attention)
# Act
output_with_attention = network_with_attention(mock_csi_input)
output_without_attention = network_without_attention(mock_csi_input)
# Assert
assert output_with_attention.shape == output_without_attention.shape
# Outputs should be different due to attention mechanism
# NOTE(review): the two networks also have independently initialized
# weights, so this inequality holds regardless of attention itself.
assert not torch.allclose(output_with_attention, output_without_attention, atol=1e-6)
# Dropout-dependent tests below rely on mock_config setting dropout_rate=0.1.
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
"""Test that training mode enables dropout for regularization"""
# Arrange
translation_network.train()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be different due to dropout
assert not torch.allclose(output1, output2, atol=1e-6)
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
"""Test that evaluation mode disables dropout for consistent inference"""
# Arrange
translation_network.eval()
# Act
output1 = translation_network(mock_csi_input)
output2 = translation_network(mock_csi_input)
# Assert - outputs should be identical in eval mode
assert torch.allclose(output1, output2, atol=1e-6)
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss measures feature alignment"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
loss = translation_network.compute_translation_loss(predicted_features, mock_target_features)
# Assert
assert loss is not None
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0 # Scalar loss
assert loss.item() >= 0 # Loss should be non-negative
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
"""Test that compute_translation_loss handles different loss types"""
# Arrange
predicted_features = translation_network(mock_csi_input)
# Act
mse_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='mse')
l1_loss = translation_network.compute_translation_loss(predicted_features, mock_target_features, loss_type='l1')
# Assert
assert mse_loss is not None
assert l1_loss is not None
assert mse_loss.item() != l1_loss.item() # Different loss types should give different values
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
"""Test that get_feature_statistics provides feature analysis"""
# Arrange
output = translation_network(mock_csi_input)
# Act
stats = translation_network.get_feature_statistics(output)
# Assert
assert stats is not None
assert isinstance(stats, dict)
assert 'mean' in stats
assert 'std' in stats
assert 'min' in stats
assert 'max' in stats
assert 'sparsity' in stats
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
"""Test that network supports gradient computation for training"""
# Arrange
translation_network.train()
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
# Act
output = translation_network(mock_csi_input)
loss = translation_network.compute_translation_loss(output, mock_target_features)
optimizer.zero_grad()
loss.backward()
# Assert
for param in translation_network.parameters():
if param.requires_grad:
assert param.grad is not None
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_network_validates_input_dimensions(self, translation_network):
"""Test that network validates input dimensions"""
# Arrange
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
# Act & Assert
with pytest.raises(ModalityTranslationError):
translation_network(invalid_input)
def test_network_handles_batch_size_one(self, translation_network):
"""Test that network handles single sample inference"""
# Arrange
single_input = torch.randn(1, 6, 56, 100)
# Act
output = translation_network(single_input)
# Assert
assert output.shape == (1, 256, 56, 100)
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
"""Test that model state can be saved and loaded"""
# Arrange
# NOTE(review): assumes the network is in eval mode or deterministic here;
# with dropout active in train mode the equality check could flake — confirm.
original_output = translation_network(mock_csi_input)
# Act - Save state
state_dict = translation_network.state_dict()
# Create new network and load state
new_network = ModalityTranslationNetwork(translation_network.config)
new_network.load_state_dict(state_dict)
new_output = new_network(mock_csi_input)
# Assert
assert torch.allclose(original_output, new_output, atol=1e-6)
def test_network_configuration_validation(self):
"""Test that network validates configuration parameters"""
# Arrange
invalid_config = {
'input_channels': 0, # Invalid
'hidden_channels': [], # Invalid
'output_channels': 256
}
# Act & Assert
with pytest.raises(ValueError):
ModalityTranslationNetwork(invalid_config)
def test_feature_visualization_support(self, translation_network, mock_csi_input):
"""Test that network supports feature visualization"""
# Act
features = translation_network.get_intermediate_features(mock_csi_input)
# Assert
assert features is not None
assert isinstance(features, dict)
assert 'encoder_features' in features
assert 'decoder_features' in features
if translation_network.use_attention:
assert 'attention_weights' in features

View File

@@ -0,0 +1,107 @@
import pytest
import numpy as np
from unittest.mock import Mock, patch
from src.core.phase_sanitizer import PhaseSanitizer
class TestPhaseSanitizer:
    """Test suite for Phase Sanitizer following London School TDD principles"""

    @pytest.fixture
    def mock_phase_data(self):
        """Generate synthetic phase data for testing"""
        # Row 0 has a 2π-scale jump at index 2, row 1 starts near the wrap
        # point, row 2 is clean.
        return np.array([
            [0.1, 0.2, 6.0, 0.4, 0.5],
            [-3.0, -0.1, 0.0, 0.1, 0.2],
            [0.0, 0.1, 0.2, 0.3, 0.4],
        ])

    @pytest.fixture
    def phase_sanitizer(self):
        """Create Phase Sanitizer instance for testing"""
        return PhaseSanitizer()

    def test_unwrap_phase_removes_discontinuities(self, phase_sanitizer, mock_phase_data):
        """Test that phase unwrapping removes 2π discontinuities"""
        unwrapped = phase_sanitizer.unwrap_phase(mock_phase_data)
        assert unwrapped is not None
        assert isinstance(unwrapped, np.ndarray)
        assert unwrapped.shape == mock_phase_data.shape
        # After unwrapping, no two consecutive samples may differ by π or more.
        for row in unwrapped:
            assert np.all(np.abs(np.diff(row)) < np.pi)

    def test_remove_outliers_filters_anomalous_values(self, phase_sanitizer, mock_phase_data):
        """Test that outlier removal filters anomalous phase values"""
        contaminated = mock_phase_data.copy()
        contaminated[0, 2] = 100.0  # inject an unmistakable outlier
        cleaned = phase_sanitizer.remove_outliers(contaminated)
        assert cleaned is not None
        assert isinstance(cleaned, np.ndarray)
        assert cleaned.shape == contaminated.shape
        # The injected outlier must be pulled back toward a plausible value.
        assert np.abs(cleaned[0, 2]) < 10.0

    def test_smooth_phase_reduces_noise(self, phase_sanitizer, mock_phase_data):
        """Test that phase smoothing reduces noise while preserving trends"""
        noisy = mock_phase_data + np.random.normal(0, 0.1, mock_phase_data.shape)
        smoothed = phase_sanitizer.smooth_phase(noisy)
        assert smoothed is not None
        assert isinstance(smoothed, np.ndarray)
        assert smoothed.shape == noisy.shape
        # Smoothing must not increase the overall variance.
        assert np.var(smoothed) <= np.var(noisy)

    def test_sanitize_handles_empty_input(self, phase_sanitizer):
        """Test that sanitizer handles empty input gracefully"""
        with pytest.raises(ValueError, match="Phase data cannot be empty"):
            phase_sanitizer.sanitize(np.array([]))

    def test_sanitize_full_pipeline_integration(self, phase_sanitizer, mock_phase_data):
        """Test that full sanitization pipeline works correctly"""
        sanitized = phase_sanitizer.sanitize(mock_phase_data)
        assert sanitized is not None
        assert isinstance(sanitized, np.ndarray)
        assert sanitized.shape == mock_phase_data.shape
        # Sanitized phases stay inside the ±2π envelope.
        assert np.all(sanitized >= -2 * np.pi)
        assert np.all(sanitized <= 2 * np.pi)

    def test_sanitize_performance_requirement(self, phase_sanitizer, mock_phase_data):
        """Test that phase sanitization meets performance requirements (<5ms)"""
        import time
        started = time.time()
        sanitized = phase_sanitizer.sanitize(mock_phase_data)
        elapsed = time.time() - started
        assert elapsed < 0.005  # <5ms budget
        assert sanitized is not None

View File

@@ -0,0 +1,407 @@
"""TDD tests for phase sanitizer following London School approach."""
import pytest
import numpy as np
import sys
import os
from unittest.mock import Mock, patch, AsyncMock
from datetime import datetime, timezone
import importlib.util
# Import the phase sanitizer module directly from its source file so the tests
# do not require the package to be installed.
# The previous hard-coded dev-container path is kept as the default, but it can
# now be overridden (e.g. on CI) via the PHASE_SANITIZER_PATH env var.
_PHASE_SANITIZER_PATH = os.environ.get(
    'PHASE_SANITIZER_PATH',
    '/workspaces/wifi-densepose/src/core/phase_sanitizer.py',
)
spec = importlib.util.spec_from_file_location('phase_sanitizer', _PHASE_SANITIZER_PATH)
if spec is None or spec.loader is None:
    # Fail with a clear message instead of an opaque AttributeError below.
    raise ImportError(f"Cannot load phase_sanitizer from {_PHASE_SANITIZER_PATH}")
phase_sanitizer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(phase_sanitizer_module)
# Re-export the classes under test at module scope.
PhaseSanitizer = phase_sanitizer_module.PhaseSanitizer
PhaseSanitizationError = phase_sanitizer_module.PhaseSanitizationError
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestPhaseSanitizer:
"""Test phase sanitizer using London School TDD."""
# Internal collaborators (_detect_outliers, _apply_moving_average, ...) are
# patched with mocks in most tests, so this suite pins the orchestration
# contract of PhaseSanitizer rather than the numeric algorithms themselves.
@pytest.fixture
def mock_logger(self):
"""Mock logger for testing."""
return Mock()
@pytest.fixture
def sanitizer_config(self):
"""Phase sanitizer configuration for testing."""
return {
'unwrapping_method': 'numpy',
'outlier_threshold': 3.0,
'smoothing_window': 5,
'enable_outlier_removal': True,
'enable_smoothing': True,
'enable_noise_filtering': True,
'noise_threshold': 0.1,
'phase_range': (-np.pi, np.pi)
}
@pytest.fixture
def phase_sanitizer(self, sanitizer_config, mock_logger):
"""Create phase sanitizer for testing."""
return PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
@pytest.fixture
def sample_wrapped_phase(self):
"""Sample wrapped phase data with discontinuities."""
# Create phase data with wrapping
phase = np.linspace(0, 4*np.pi, 100)
wrapped_phase = np.angle(np.exp(1j * phase)) # Wrap to [-π, π]
return wrapped_phase.reshape(1, -1) # Shape: (1, 100)
@pytest.fixture
def sample_noisy_phase(self):
"""Sample phase data with noise and outliers."""
clean_phase = np.linspace(-np.pi, np.pi, 50)
noise = np.random.normal(0, 0.05, 50)
# Add some outliers
outliers = np.random.choice(50, 5, replace=False)
noisy_phase = clean_phase + noise
noisy_phase[outliers] += np.random.uniform(-2, 2, 5) # Add outliers
return noisy_phase.reshape(1, -1)
# Initialization tests
def test_should_initialize_with_valid_config(self, sanitizer_config, mock_logger):
"""Should initialize phase sanitizer with valid configuration."""
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
assert sanitizer.config == sanitizer_config
assert sanitizer.logger == mock_logger
assert sanitizer.unwrapping_method == 'numpy'
assert sanitizer.outlier_threshold == 3.0
assert sanitizer.smoothing_window == 5
assert sanitizer.enable_outlier_removal == True
assert sanitizer.enable_smoothing == True
assert sanitizer.enable_noise_filtering == True
assert sanitizer.noise_threshold == 0.1
assert sanitizer.phase_range == (-np.pi, np.pi)
def test_should_raise_error_with_invalid_config(self, mock_logger):
"""Should raise error when initialized with invalid configuration."""
invalid_config = {'invalid': 'config'}
with pytest.raises(ValueError, match="Missing required configuration"):
PhaseSanitizer(config=invalid_config, logger=mock_logger)
def test_should_validate_required_fields(self, mock_logger):
"""Should validate required configuration fields."""
# Dropping any one required field must be rejected.
required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
base_config = {
'unwrapping_method': 'numpy',
'outlier_threshold': 3.0,
'smoothing_window': 5
}
for field in required_fields:
config = base_config.copy()
del config[field]
with pytest.raises(ValueError, match="Missing required configuration"):
PhaseSanitizer(config=config, logger=mock_logger)
def test_should_use_default_values(self, mock_logger):
"""Should use default values for optional parameters."""
minimal_config = {
'unwrapping_method': 'numpy',
'outlier_threshold': 3.0,
'smoothing_window': 5
}
sanitizer = PhaseSanitizer(config=minimal_config, logger=mock_logger)
assert sanitizer.enable_outlier_removal == True # default
assert sanitizer.enable_smoothing == True # default
assert sanitizer.enable_noise_filtering == False # default
assert sanitizer.noise_threshold == 0.05 # default
assert sanitizer.phase_range == (-np.pi, np.pi) # default
def test_should_initialize_without_logger(self, sanitizer_config):
"""Should initialize without logger provided."""
sanitizer = PhaseSanitizer(config=sanitizer_config)
assert sanitizer.logger is not None # Should create default logger
# Phase unwrapping tests
def test_should_unwrap_phase_successfully(self, phase_sanitizer, sample_wrapped_phase):
"""Should unwrap phase data successfully."""
result = phase_sanitizer.unwrap_phase(sample_wrapped_phase)
# Check that result has same shape
assert result.shape == sample_wrapped_phase.shape
# Check that unwrapping removed discontinuities
phase_diff = np.diff(result.flatten())
large_jumps = np.abs(phase_diff) > np.pi
assert np.sum(large_jumps) < np.sum(np.abs(np.diff(sample_wrapped_phase.flatten())) > np.pi)
def test_should_handle_different_unwrapping_methods(self, sanitizer_config, mock_logger):
"""Should handle different unwrapping methods."""
# NOTE(review): relies on PhaseSanitizer dispatching to private methods
# named _unwrap_<method> — confirm against the implementation.
for method in ['numpy', 'scipy', 'custom']:
sanitizer_config['unwrapping_method'] = method
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
with patch.object(sanitizer, f'_unwrap_{method}', return_value=phase_data) as mock_unwrap:
result = sanitizer.unwrap_phase(phase_data)
assert result.shape == phase_data.shape
mock_unwrap.assert_called_once()
def test_should_handle_unwrapping_error(self, phase_sanitizer):
"""Should handle phase unwrapping errors gracefully."""
invalid_phase = np.array([[]]) # Empty array
with pytest.raises(PhaseSanitizationError, match="Failed to unwrap phase"):
phase_sanitizer.unwrap_phase(invalid_phase)
# Outlier removal tests
def test_should_remove_outliers_successfully(self, phase_sanitizer, sample_noisy_phase):
"""Should remove outliers from phase data successfully."""
with patch.object(phase_sanitizer, '_detect_outliers') as mock_detect:
with patch.object(phase_sanitizer, '_interpolate_outliers') as mock_interpolate:
outlier_mask = np.zeros(sample_noisy_phase.shape, dtype=bool)
outlier_mask[0, [10, 20, 30]] = True # Mark some outliers
clean_phase = sample_noisy_phase.copy()
mock_detect.return_value = outlier_mask
mock_interpolate.return_value = clean_phase
result = phase_sanitizer.remove_outliers(sample_noisy_phase)
assert result.shape == sample_noisy_phase.shape
mock_detect.assert_called_once_with(sample_noisy_phase)
mock_interpolate.assert_called_once()
def test_should_skip_outlier_removal_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
"""Should skip outlier removal when disabled."""
sanitizer_config['enable_outlier_removal'] = False
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
result = sanitizer.remove_outliers(sample_noisy_phase)
assert np.array_equal(result, sample_noisy_phase)
def test_should_handle_outlier_removal_error(self, phase_sanitizer):
"""Should handle outlier removal errors gracefully."""
with patch.object(phase_sanitizer, '_detect_outliers') as mock_detect:
mock_detect.side_effect = Exception("Detection error")
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
with pytest.raises(PhaseSanitizationError, match="Failed to remove outliers"):
phase_sanitizer.remove_outliers(phase_data)
# Smoothing tests
def test_should_smooth_phase_successfully(self, phase_sanitizer, sample_noisy_phase):
"""Should smooth phase data successfully."""
with patch.object(phase_sanitizer, '_apply_moving_average') as mock_smooth:
smoothed_phase = sample_noisy_phase * 0.9 # Simulate smoothing
mock_smooth.return_value = smoothed_phase
result = phase_sanitizer.smooth_phase(sample_noisy_phase)
assert result.shape == sample_noisy_phase.shape
mock_smooth.assert_called_once_with(sample_noisy_phase, phase_sanitizer.smoothing_window)
def test_should_skip_smoothing_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
"""Should skip smoothing when disabled."""
sanitizer_config['enable_smoothing'] = False
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
result = sanitizer.smooth_phase(sample_noisy_phase)
assert np.array_equal(result, sample_noisy_phase)
def test_should_handle_smoothing_error(self, phase_sanitizer):
"""Should handle smoothing errors gracefully."""
with patch.object(phase_sanitizer, '_apply_moving_average') as mock_smooth:
mock_smooth.side_effect = Exception("Smoothing error")
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
with pytest.raises(PhaseSanitizationError, match="Failed to smooth phase"):
phase_sanitizer.smooth_phase(phase_data)
# Noise filtering tests
def test_should_filter_noise_successfully(self, phase_sanitizer, sample_noisy_phase):
"""Should filter noise from phase data successfully."""
with patch.object(phase_sanitizer, '_apply_low_pass_filter') as mock_filter:
filtered_phase = sample_noisy_phase * 0.95 # Simulate filtering
mock_filter.return_value = filtered_phase
result = phase_sanitizer.filter_noise(sample_noisy_phase)
assert result.shape == sample_noisy_phase.shape
mock_filter.assert_called_once_with(sample_noisy_phase, phase_sanitizer.noise_threshold)
def test_should_skip_noise_filtering_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
"""Should skip noise filtering when disabled."""
sanitizer_config['enable_noise_filtering'] = False
sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)
result = sanitizer.filter_noise(sample_noisy_phase)
assert np.array_equal(result, sample_noisy_phase)
def test_should_handle_noise_filtering_error(self, phase_sanitizer):
"""Should handle noise filtering errors gracefully."""
with patch.object(phase_sanitizer, '_apply_low_pass_filter') as mock_filter:
mock_filter.side_effect = Exception("Filtering error")
phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))
with pytest.raises(PhaseSanitizationError, match="Failed to filter noise"):
phase_sanitizer.filter_noise(phase_data)
# Complete sanitization pipeline tests
def test_should_sanitize_phase_pipeline_successfully(self, phase_sanitizer, sample_wrapped_phase):
"""Should sanitize phase through complete pipeline successfully."""
# Pipeline order pinned here: unwrap -> outliers -> smooth -> filter.
with patch.object(phase_sanitizer, 'unwrap_phase', return_value=sample_wrapped_phase) as mock_unwrap:
with patch.object(phase_sanitizer, 'remove_outliers', return_value=sample_wrapped_phase) as mock_outliers:
with patch.object(phase_sanitizer, 'smooth_phase', return_value=sample_wrapped_phase) as mock_smooth:
with patch.object(phase_sanitizer, 'filter_noise', return_value=sample_wrapped_phase) as mock_filter:
result = phase_sanitizer.sanitize_phase(sample_wrapped_phase)
assert result.shape == sample_wrapped_phase.shape
mock_unwrap.assert_called_once_with(sample_wrapped_phase)
mock_outliers.assert_called_once()
mock_smooth.assert_called_once()
mock_filter.assert_called_once()
def test_should_handle_sanitization_pipeline_error(self, phase_sanitizer, sample_wrapped_phase):
"""Should handle sanitization pipeline errors gracefully."""
with patch.object(phase_sanitizer, 'unwrap_phase') as mock_unwrap:
mock_unwrap.side_effect = PhaseSanitizationError("Unwrapping failed")
with pytest.raises(PhaseSanitizationError):
phase_sanitizer.sanitize_phase(sample_wrapped_phase)
# Phase validation tests
def test_should_validate_phase_data_successfully(self, phase_sanitizer):
"""Should validate phase data successfully."""
valid_phase = np.random.uniform(-np.pi, np.pi, (3, 56))
result = phase_sanitizer.validate_phase_data(valid_phase)
assert result == True
def test_should_reject_invalid_phase_shape(self, phase_sanitizer):
"""Should reject phase data with invalid shape."""
invalid_phase = np.array([1, 2, 3]) # 1D array
with pytest.raises(PhaseSanitizationError, match="Phase data must be 2D"):
phase_sanitizer.validate_phase_data(invalid_phase)
def test_should_reject_empty_phase_data(self, phase_sanitizer):
"""Should reject empty phase data."""
empty_phase = np.array([]).reshape(0, 0)
with pytest.raises(PhaseSanitizationError, match="Phase data cannot be empty"):
phase_sanitizer.validate_phase_data(empty_phase)
def test_should_reject_phase_out_of_range(self, phase_sanitizer):
"""Should reject phase data outside valid range."""
invalid_phase = np.array([[10.0, -10.0, 5.0, -5.0]]) # Outside [-π, π]
with pytest.raises(PhaseSanitizationError, match="Phase values outside valid range"):
phase_sanitizer.validate_phase_data(invalid_phase)
# Statistics and monitoring tests
# NOTE(review): these poke the private counters (_total_processed etc.)
# directly, coupling the tests to the implementation's attribute names.
def test_should_get_sanitization_statistics(self, phase_sanitizer):
"""Should get sanitization statistics."""
# Simulate some processing
phase_sanitizer._total_processed = 50
phase_sanitizer._outliers_removed = 5
phase_sanitizer._sanitization_errors = 2
stats = phase_sanitizer.get_sanitization_statistics()
assert isinstance(stats, dict)
assert stats['total_processed'] == 50
assert stats['outliers_removed'] == 5
assert stats['sanitization_errors'] == 2
assert stats['outlier_rate'] == 0.1
assert stats['error_rate'] == 0.04
def test_should_reset_statistics(self, phase_sanitizer):
"""Should reset sanitization statistics."""
phase_sanitizer._total_processed = 50
phase_sanitizer._outliers_removed = 5
phase_sanitizer._sanitization_errors = 2
phase_sanitizer.reset_statistics()
assert phase_sanitizer._total_processed == 0
assert phase_sanitizer._outliers_removed == 0
assert phase_sanitizer._sanitization_errors == 0
# Configuration validation tests
def test_should_validate_unwrapping_method(self, mock_logger):
"""Should validate unwrapping method."""
invalid_config = {
'unwrapping_method': 'invalid_method',
'outlier_threshold': 3.0,
'smoothing_window': 5
}
with pytest.raises(ValueError, match="Invalid unwrapping method"):
PhaseSanitizer(config=invalid_config, logger=mock_logger)
def test_should_validate_outlier_threshold(self, mock_logger):
"""Should validate outlier threshold."""
invalid_config = {
'unwrapping_method': 'numpy',
'outlier_threshold': -1.0, # Negative threshold
'smoothing_window': 5
}
with pytest.raises(ValueError, match="outlier_threshold must be positive"):
PhaseSanitizer(config=invalid_config, logger=mock_logger)
def test_should_validate_smoothing_window(self, mock_logger):
"""Should validate smoothing window."""
invalid_config = {
'unwrapping_method': 'numpy',
'outlier_threshold': 3.0,
'smoothing_window': 0 # Invalid window size
}
with pytest.raises(ValueError, match="smoothing_window must be positive"):
PhaseSanitizer(config=invalid_config, logger=mock_logger)
# Edge case tests
def test_should_handle_single_antenna_data(self, phase_sanitizer):
"""Should handle single antenna phase data."""
single_antenna_phase = np.random.uniform(-np.pi, np.pi, (1, 56))
result = phase_sanitizer.sanitize_phase(single_antenna_phase)
assert result.shape == single_antenna_phase.shape
def test_should_handle_small_phase_arrays(self, phase_sanitizer):
"""Should handle small phase arrays."""
small_phase = np.random.uniform(-np.pi, np.pi, (2, 5))
result = phase_sanitizer.sanitize_phase(small_phase)
assert result.shape == small_phase.shape
def test_should_handle_constant_phase_data(self, phase_sanitizer):
"""Should handle constant phase data."""
constant_phase = np.full((3, 20), 0.5)
result = phase_sanitizer.sanitize_phase(constant_phase)
assert result.shape == constant_phase.shape

View File

@@ -0,0 +1,244 @@
import pytest
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from src.hardware.router_interface import RouterInterface, RouterConnectionError
class TestRouterInterface:
"""Test suite for Router Interface following London School TDD principles"""
@pytest.fixture
def mock_config(self):
    """Configuration for router interface"""
    # Credentials and timeouts for a typical home-router SSH session.
    return dict(
        router_ip='192.168.1.1',
        username='admin',
        password='password',
        ssh_port=22,
        timeout=30,
        max_retries=3,
    )
@pytest.fixture
def router_interface(self, mock_config):
"""Create router interface instance for testing"""
# Fresh, unconnected RouterInterface per test (mock_config fixture above).
return RouterInterface(mock_config)
@pytest.fixture
def mock_ssh_client(self):
    """Mock SSH client for testing"""
    client = Mock()
    # Stub the parts of the paramiko surface the interface touches.
    client.connect = Mock()
    client.exec_command = Mock()
    client.close = Mock()
    return client
def test_interface_initialization_creates_correct_configuration(self, mock_config):
    """Test that router interface initializes with correct configuration"""
    interface = RouterInterface(mock_config)
    assert interface is not None
    # Every config field must land on the attribute of the same name.
    for field in ('router_ip', 'username', 'password', 'ssh_port', 'timeout', 'max_retries'):
        assert getattr(interface, field) == mock_config[field]
    # A freshly built interface starts disconnected.
    assert not interface.is_connected
@patch('paramiko.SSHClient')
def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
    """Test that connect method establishes SSH connection"""
    mock_ssh_class.return_value = mock_ssh_client
    assert router_interface.connect() is True
    assert router_interface.is_connected is True
    # The client must be configured and connected with the fixture's settings.
    mock_ssh_client.set_missing_host_key_policy.assert_called_once()
    mock_ssh_client.connect.assert_called_once_with(
        hostname=router_interface.router_ip,
        port=router_interface.ssh_port,
        username=router_interface.username,
        password=router_interface.password,
        timeout=router_interface.timeout,
    )
@patch('paramiko.SSHClient')
def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
    """Test that connect method handles connection failures gracefully"""
    mock_ssh_class.return_value = mock_ssh_client
    mock_ssh_client.connect.side_effect = Exception("Connection failed")
    # A raw transport failure must surface as RouterConnectionError...
    with pytest.raises(RouterConnectionError):
        router_interface.connect()
    # ...and the interface must not report itself connected.
    assert router_interface.is_connected is False
@patch('paramiko.SSHClient')
def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
    """disconnect() tears down the SSH session and clears the flag."""
    mock_ssh_class.return_value = mock_ssh_client
    router_interface.connect()

    router_interface.disconnect()

    mock_ssh_client.close.assert_called_once()
    assert router_interface.is_connected is False
@patch('paramiko.SSHClient')
def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
    """execute_command() returns the decoded stdout of the remote command."""
    out_stream = Mock(**{'read.return_value': b"command output"})
    err_stream = Mock(**{'read.return_value': b""})
    mock_ssh_client.exec_command.return_value = (None, out_stream, err_stream)
    mock_ssh_class.return_value = mock_ssh_client
    router_interface.connect()

    output = router_interface.execute_command("test command")

    assert output == "command output"
    mock_ssh_client.exec_command.assert_called_with("test command")
@patch('paramiko.SSHClient')
def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
    """Any stderr output from the remote command raises RouterConnectionError."""
    out_stream = Mock(**{'read.return_value': b""})
    err_stream = Mock(**{'read.return_value': b"command error"})
    mock_ssh_client.exec_command.return_value = (None, out_stream, err_stream)
    mock_ssh_class.return_value = mock_ssh_client
    router_interface.connect()

    with pytest.raises(RouterConnectionError):
        router_interface.execute_command("failing command")
def test_execute_command_requires_connection(self, router_interface):
    """Running a command before connecting is rejected."""
    with pytest.raises(RouterConnectionError):
        router_interface.execute_command("test command")
@patch('paramiko.SSHClient')
def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
    """get_router_info() parses model and firmware details into a dict."""
    out_stream = Mock(**{'read.return_value': b"Router Model: AC1900\nFirmware: 1.2.3"})
    err_stream = Mock(**{'read.return_value': b""})
    mock_ssh_client.exec_command.return_value = (None, out_stream, err_stream)
    mock_ssh_class.return_value = mock_ssh_client
    router_interface.connect()

    info = router_interface.get_router_info()

    assert info is not None
    assert isinstance(info, dict)
    for field in ('model', 'firmware'):
        assert field in info
@patch('paramiko.SSHClient')
def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
    """enable_monitor_mode() issues the monitoring command and reports success."""
    out_stream = Mock(**{'read.return_value': b"Monitor mode enabled"})
    err_stream = Mock(**{'read.return_value': b""})
    mock_ssh_client.exec_command.return_value = (None, out_stream, err_stream)
    mock_ssh_class.return_value = mock_ssh_client
    router_interface.connect()

    enabled = router_interface.enable_monitor_mode("wlan0")

    assert enabled is True
    mock_ssh_client.exec_command.assert_called()
@patch('paramiko.SSHClient')
def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
    """disable_monitor_mode() issues the teardown command and reports success."""
    out_stream = Mock(**{'read.return_value': b"Monitor mode disabled"})
    err_stream = Mock(**{'read.return_value': b""})
    mock_ssh_client.exec_command.return_value = (None, out_stream, err_stream)
    mock_ssh_class.return_value = mock_ssh_client
    router_interface.connect()

    disabled = router_interface.disable_monitor_mode("wlan0")

    assert disabled is True
    mock_ssh_client.exec_command.assert_called()
@patch('paramiko.SSHClient')
def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
    """`with` opens the connection on entry and closes it on exit."""
    mock_ssh_class.return_value = mock_ssh_client

    with router_interface as interface:
        assert interface.is_connected is True

    # Leaving the context must have torn the session down.
    assert router_interface.is_connected is False
    mock_ssh_client.close.assert_called_once()
def test_interface_validates_configuration(self):
    """An empty router IP is rejected at construction time."""
    bad_config = {
        'router_ip': '',  # invalid address
        'username': 'admin',
        'password': 'password',
    }

    with pytest.raises(ValueError):
        RouterInterface(bad_config)
@patch('paramiko.SSHClient')
def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
    """connect() retries once after a transient failure and then succeeds."""
    # First attempt raises, second attempt succeeds.
    mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None]
    mock_ssh_class.return_value = mock_ssh_client

    assert router_interface.connect() is True
    assert mock_ssh_client.connect.call_count == 2  # exactly one retry

View File

@@ -0,0 +1,410 @@
"""TDD tests for router interface following London School approach."""
import pytest
import asyncio
import sys
import os
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from datetime import datetime, timezone
import importlib.util
# Import the router interface module directly
import unittest.mock
# Mock asyncssh before importing
with unittest.mock.patch.dict('sys.modules', {'asyncssh': unittest.mock.MagicMock()}):
spec = importlib.util.spec_from_file_location(
'router_interface',
'/workspaces/wifi-densepose/src/hardware/router_interface.py'
)
router_module = importlib.util.module_from_spec(spec)
# Import CSI extractor for dependency
csi_spec = importlib.util.spec_from_file_location(
'csi_extractor',
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
)
csi_module = importlib.util.module_from_spec(csi_spec)
csi_spec.loader.exec_module(csi_module)
# Now load the router interface
router_module.CSIData = csi_module.CSIData # Make CSIData available
spec.loader.exec_module(router_module)
# Get classes from modules
RouterInterface = router_module.RouterInterface
RouterConnectionError = router_module.RouterConnectionError
CSIData = csi_module.CSIData
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestRouterInterface:
"""Test router interface using London School TDD."""
@pytest.fixture
def mock_logger(self):
    """Stand-in logger whose calls can be inspected by the tests."""
    return Mock()
@pytest.fixture
def router_config(self):
    """Complete router configuration shared by the tests below."""
    return dict(
        host='192.168.1.1',
        port=22,
        username='admin',
        password='password',
        command_timeout=30,
        connection_timeout=10,
        max_retries=3,
        retry_delay=1.0,
    )
@pytest.fixture
def router_interface(self, router_config, mock_logger):
    """RouterInterface wired to the test config and mock logger."""
    return RouterInterface(logger=mock_logger, config=router_config)
# Initialization tests
def test_should_initialize_with_valid_config(self, router_config, mock_logger):
    """A valid config populates every attribute and starts disconnected."""
    interface = RouterInterface(config=router_config, logger=mock_logger)

    expected = {
        'host': '192.168.1.1',
        'port': 22,
        'username': 'admin',
        'password': 'password',
        'command_timeout': 30,
        'connection_timeout': 10,
        'max_retries': 3,
        'retry_delay': 1.0,
    }
    for attr, value in expected.items():
        assert getattr(interface, attr) == value
    assert interface.is_connected == False
    assert interface.logger == mock_logger
def test_should_raise_error_with_invalid_config(self, mock_logger):
    """A config missing the required keys is rejected with ValueError."""
    with pytest.raises(ValueError, match="Missing required configuration"):
        RouterInterface(config={'invalid': 'config'}, logger=mock_logger)
def test_should_validate_required_fields(self, mock_logger):
    """Dropping any single required field triggers a ValueError."""
    base_config = {
        'host': '192.168.1.1',
        'port': 22,
        'username': 'admin',
        'password': 'password'
    }

    for missing in ['host', 'port', 'username', 'password']:
        # Rebuild the config without exactly one required key.
        config = {k: v for k, v in base_config.items() if k != missing}
        with pytest.raises(ValueError, match="Missing required configuration"):
            RouterInterface(config=config, logger=mock_logger)
def test_should_use_default_values(self, mock_logger):
    """Optional settings fall back to their documented defaults."""
    interface = RouterInterface(
        config={
            'host': '192.168.1.1',
            'port': 22,
            'username': 'admin',
            'password': 'password'
        },
        logger=mock_logger,
    )

    defaults = {
        'command_timeout': 30,
        'connection_timeout': 10,
        'max_retries': 3,
        'retry_delay': 1.0,
    }
    for attr, value in defaults.items():
        assert getattr(interface, attr) == value
def test_should_initialize_without_logger(self, router_config):
    """Omitting the logger argument yields a default logger instance."""
    assert RouterInterface(config=router_config).logger is not None
# Connection tests
@pytest.mark.asyncio
async def test_should_connect_successfully(self, router_interface):
    """A successful asyncssh.connect marks the interface connected."""
    ssh_session = Mock()
    # NOTE(review): patch target assumes the module is importable as
    # src.hardware.router_interface, while the bootstrap loads it under the
    # name 'router_interface' -- verify the target resolves in CI.
    with patch('src.hardware.router_interface.asyncssh.connect', new_callable=AsyncMock) as mock_connect:
        mock_connect.return_value = ssh_session

        assert await router_interface.connect() == True

        assert router_interface.is_connected == True
        assert router_interface.ssh_client == ssh_session
        mock_connect.assert_called_once_with(
            '192.168.1.1',
            port=22,
            username='admin',
            password='password',
            connect_timeout=10
        )
@pytest.mark.asyncio
async def test_should_handle_connection_failure(self, router_interface):
    """A ConnectionError from asyncssh leaves the interface disconnected."""
    with patch('src.hardware.router_interface.asyncssh.connect', new_callable=AsyncMock) as mock_connect:
        mock_connect.side_effect = ConnectionError("Connection failed")

        assert await router_interface.connect() == False

        assert router_interface.is_connected == False
        assert router_interface.ssh_client is None
        router_interface.logger.error.assert_called()
@pytest.mark.asyncio
async def test_should_disconnect_when_connected(self, router_interface):
    """disconnect() closes and drops an open SSH session."""
    ssh_session = Mock()
    router_interface.ssh_client = ssh_session
    router_interface.is_connected = True

    await router_interface.disconnect()

    ssh_session.close.assert_called_once()
    assert router_interface.ssh_client is None
    assert router_interface.is_connected == False
@pytest.mark.asyncio
async def test_should_handle_disconnect_when_not_connected(self, router_interface):
    """disconnect() is a silent no-op when already disconnected."""
    router_interface.ssh_client = None
    router_interface.is_connected = False

    await router_interface.disconnect()  # must not raise

    assert router_interface.is_connected == False
# Command execution tests
@pytest.mark.asyncio
async def test_should_execute_command_successfully(self, router_interface):
    """execute_command() returns the stdout of a zero-exit command."""
    run_result = Mock(stdout="command output", stderr="", returncode=0)
    ssh_session = Mock()
    router_interface.is_connected = True
    router_interface.ssh_client = ssh_session

    with patch.object(ssh_session, 'run', AsyncMock(return_value=run_result)) as mock_run:
        output = await router_interface.execute_command("test command")

    assert output == "command output"
    mock_run.assert_called_once_with("test command", timeout=30)
@pytest.mark.asyncio
async def test_should_handle_command_execution_when_not_connected(self, router_interface):
    """Commands are refused while the interface is disconnected."""
    router_interface.is_connected = False

    with pytest.raises(RouterConnectionError, match="Not connected to router"):
        await router_interface.execute_command("test command")
@pytest.mark.asyncio
async def test_should_handle_command_execution_error(self, router_interface):
    """A non-zero exit code is surfaced as RouterConnectionError."""
    failure = Mock(stdout="", stderr="command error", returncode=1)
    router_interface.ssh_client = Mock()
    router_interface.is_connected = True

    with patch.object(router_interface.ssh_client, 'run', AsyncMock(return_value=failure)):
        with pytest.raises(RouterConnectionError, match="Command failed"):
            await router_interface.execute_command("test command")
@pytest.mark.asyncio
async def test_should_retry_command_execution_on_failure(self, router_interface):
    """Transient network errors are retried until the command succeeds."""
    success = Mock(stdout="success output", stderr="", returncode=0)
    ssh_session = Mock()
    router_interface.is_connected = True
    router_interface.ssh_client = ssh_session

    # Two transient failures followed by a successful run.
    outcomes = [
        ConnectionError("Network error"),
        ConnectionError("Network error"),
        success,
    ]
    with patch.object(ssh_session, 'run', AsyncMock(side_effect=outcomes)) as mock_run:
        assert await router_interface.execute_command("test command") == "success output"
    assert mock_run.call_count == 3
@pytest.mark.asyncio
async def test_should_fail_after_max_retries(self, router_interface):
    """Persistent failures give up after max_retries attempts."""
    ssh_session = Mock()
    router_interface.is_connected = True
    router_interface.ssh_client = ssh_session

    with patch.object(ssh_session, 'run', AsyncMock(side_effect=ConnectionError("Network error"))) as mock_run:
        with pytest.raises(RouterConnectionError, match="Command execution failed after 3 retries"):
            await router_interface.execute_command("test command")
    assert mock_run.call_count == 3
# CSI data retrieval tests
@pytest.mark.asyncio
async def test_should_get_csi_data_successfully(self, router_interface):
    """get_csi_data() runs the scan command and parses the response."""
    parsed = Mock(spec=CSIData)
    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute, \
         patch.object(router_interface, '_parse_csi_response', return_value=parsed) as mock_parse:
        mock_execute.return_value = "csi data response"

        assert await router_interface.get_csi_data() == parsed

        mock_execute.assert_called_once_with("iwlist scan | grep CSI")
        mock_parse.assert_called_once_with("csi data response")
@pytest.mark.asyncio
async def test_should_handle_csi_data_retrieval_failure(self, router_interface):
    """A command failure during CSI retrieval propagates unchanged."""
    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
        mock_execute.side_effect = RouterConnectionError("Command failed")

        with pytest.raises(RouterConnectionError):
            await router_interface.get_csi_data()
# Router status tests
@pytest.mark.asyncio
async def test_should_get_router_status_successfully(self, router_interface):
    """get_router_status() parses the combined status command output."""
    status = {
        'cpu_usage': 25.5,
        'memory_usage': 60.2,
        'wifi_status': 'active',
        'uptime': '5 days, 3 hours'
    }
    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute, \
         patch.object(router_interface, '_parse_status_response', return_value=status) as mock_parse:
        mock_execute.return_value = "status response"

        assert await router_interface.get_router_status() == status

        mock_execute.assert_called_once_with("cat /proc/stat && free && iwconfig")
        mock_parse.assert_called_once_with("status response")
# Configuration tests
@pytest.mark.asyncio
async def test_should_configure_csi_monitoring_successfully(self, router_interface):
    """configure_csi_monitoring() applies the channel via iwconfig."""
    monitoring = {'channel': 6, 'bandwidth': 20, 'sample_rate': 100}

    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
        mock_execute.return_value = "Configuration applied"

        assert await router_interface.configure_csi_monitoring(monitoring) == True

        mock_execute.assert_called_once_with(
            "iwconfig wlan0 channel 6 && echo 'CSI monitoring configured'"
        )
@pytest.mark.asyncio
async def test_should_handle_csi_monitoring_configuration_failure(self, router_interface):
    """Configuration reports False (not an exception) when the command fails."""
    monitoring = {'channel': 6, 'bandwidth': 20, 'sample_rate': 100}

    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
        mock_execute.side_effect = RouterConnectionError("Command failed")

        assert await router_interface.configure_csi_monitoring(monitoring) == False
# Health check tests
@pytest.mark.asyncio
async def test_should_perform_health_check_successfully(self, router_interface):
    """health_check() returns True when the ping command answers."""
    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
        mock_execute.return_value = "pong"

        assert await router_interface.health_check() == True

        mock_execute.assert_called_once_with("echo 'ping' && echo 'pong'")
@pytest.mark.asyncio
async def test_should_handle_health_check_failure(self, router_interface):
    """health_check() returns False instead of raising on failure."""
    with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
        mock_execute.side_effect = RouterConnectionError("Command failed")

        assert await router_interface.health_check() == False
# Parsing method tests
def test_should_parse_csi_response(self, router_interface):
    """_parse_csi_response() builds a CSIData object from the raw response."""
    raw = "CSI_DATA:timestamp,antennas,subcarriers,frequency,bandwidth"

    with patch('src.hardware.router_interface.CSIData') as mock_csi_data:
        parsed = Mock(spec=CSIData)
        mock_csi_data.return_value = parsed

        assert router_interface._parse_csi_response(raw) == parsed
def test_should_parse_status_response(self, router_interface):
    """Should parse router status response."""
    # Sample output mixing /proc/stat, free, and iwconfig formats, matching
    # the combined status command used by get_router_status().
    mock_response = """
    cpu 123456 0 78901 234567 0 0 0 0 0 0
    MemTotal: 1024000 kB
    MemFree: 512000 kB
    wlan0 IEEE 802.11 ESSID:"TestNetwork"
    """
    result = router_interface._parse_status_response(mock_response)
    # Only presence of the expected keys is asserted; the parsed values
    # themselves are not pinned down by this test.
    assert isinstance(result, dict)
    assert 'cpu_usage' in result
    assert 'memory_usage' in result
    assert 'wifi_status' in result