feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
736
v1/tests/e2e/test_healthcare_scenario.py
Normal file
736
v1/tests/e2e/test_healthcare_scenario.py
Normal file
@@ -0,0 +1,736 @@
|
||||
"""
|
||||
End-to-end tests for healthcare fall detection scenario.
|
||||
|
||||
Tests complete workflow from CSI data collection to fall alert generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AlertSeverity(Enum):
    """Alert severity levels, ordered from least to most urgent.

    The notification system's escalation rules key off these values to
    decide which channels receive an alert.
    """

    LOW = "low"            # informational; lowest-priority routing
    MEDIUM = "medium"      # needs attention soon (e.g. prolonged inactivity)
    HIGH = "high"          # urgent (e.g. movement instability)
    CRITICAL = "critical"  # immediate response required (e.g. fall detected)
|
||||
|
||||
|
||||
@dataclass
class HealthcareAlert:
    """Healthcare alert data structure.

    Produced by MockPatientMonitor and consumed by the notification
    system; carries both the classification of the event and the raw
    pose payload that triggered it.
    """

    alert_id: str              # unique id: "<prefix>_<patient_id>_<unix-ts>"
    timestamp: datetime        # when the alert was generated (naive UTC)
    alert_type: str            # e.g. "fall_detected", "prolonged_inactivity"
    severity: AlertSeverity    # drives the notification escalation rules
    patient_id: str            # patient the alert concerns
    location: str              # room identifier the patient is monitored in
    confidence: float          # detection confidence carried from pose data
    description: str           # human-readable summary of the event
    metadata: Dict[str, Any]   # type-specific details plus the raw pose payload
|
||||
|
||||
|
||||
class MockPatientMonitor:
    """Mock patient monitoring system.

    Consumes pose-estimation payloads for a single patient in a single
    room, derives per-sample activity metrics, keeps a bounded history of
    those metrics, and emits HealthcareAlert objects for falls, prolonged
    inactivity and movement instability.
    """

    # Maximum number of activity samples retained in history.
    HISTORY_LIMIT = 100
    # Number of trailing samples averaged by the inactivity check.
    INACTIVITY_WINDOW = 10
    # Mean activity level below which prolonged inactivity is reported.
    INACTIVITY_THRESHOLD = 0.1
    # Stability score below which movement instability is reported.
    INSTABILITY_THRESHOLD = 0.4
    # Aspect-ratio factor: a bounding box this much wider than tall is a fall.
    FALL_ASPECT_FACTOR = 1.5

    # Per-posture (activity_level, movement_speed, stability_score).
    _POSTURE_METRICS = {
        "fallen": (0.1, 0.0, 0.2),
        "walking": (0.8, 1.5, 0.7),
        "sitting": (0.3, 0.1, 0.9),
    }
    # Used for "standing" and any unrecognized posture.
    _DEFAULT_POSTURE_METRICS = (0.5, 0.2, 0.8)

    def __init__(self, patient_id: str, room_id: str):
        self.patient_id = patient_id
        self.room_id = room_id
        self.is_monitoring = False
        self.baseline_activity = None   # reserved; never populated by this mock
        self.activity_history = []      # most recent HISTORY_LIMIT metric dicts
        self.alerts_generated = []      # every HealthcareAlert emitted so far
        self.fall_detection_enabled = True
        self.sensitivity_level = "medium"

    async def start_monitoring(self) -> bool:
        """Start patient monitoring; return False if already running."""
        if self.is_monitoring:
            return False
        self.is_monitoring = True
        return True

    async def stop_monitoring(self) -> bool:
        """Stop patient monitoring; return False if not currently running."""
        if not self.is_monitoring:
            return False
        self.is_monitoring = False
        return True

    async def process_pose_data(self, pose_data: Dict[str, Any]) -> Optional[HealthcareAlert]:
        """Process one pose payload and detect potential issues.

        Returns the generated alert (also recorded in ``alerts_generated``)
        or None when monitoring is off or nothing anomalous was found.
        """
        if not self.is_monitoring:
            return None

        # Extract activity metrics and record them.
        activity_metrics = self._extract_activity_metrics(pose_data)
        self.activity_history.append(activity_metrics)

        # Keep only recent history; trim in place instead of re-slicing
        # into a fresh list on every call.
        if len(self.activity_history) > self.HISTORY_LIMIT:
            del self.activity_history[:-self.HISTORY_LIMIT]

        # Detect anomalies.
        alert = await self._detect_anomalies(activity_metrics, pose_data)
        if alert:
            self.alerts_generated.append(alert)
        return alert

    def _extract_activity_metrics(self, pose_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract activity metrics from pose data.

        Returns a dict with person_count, activity_level, posture,
        movement_speed, stability_score and confidence.
        """
        persons = pose_data.get("persons", [])

        if not persons:
            # Empty room: neutral metrics so the inactivity check can fire.
            # FIX: the "confidence" key was missing here, which made the
            # inactivity alert path raise KeyError when no person was seen.
            return {
                "person_count": 0,
                "activity_level": 0.0,
                "posture": "unknown",
                "movement_speed": 0.0,
                "stability_score": 1.0,
                "confidence": 0.0
            }

        # Analyze first person (primary patient).
        person = persons[0]

        # Extract posture from activity field or bounding box analysis.
        posture = person.get("activity", "standing")

        # If no explicit activity was supplied, infer a fall from the
        # bounding-box aspect ratio: clearly wider than tall => horizontal.
        if posture == "standing" and "bounding_box" in person:
            bbox = person["bounding_box"]
            width = bbox.get("width", 80)
            height = bbox.get("height", 180)
            if width > height * self.FALL_ASPECT_FACTOR:
                posture = "fallen"

        # Look up the canned metrics for this posture.
        activity_level, movement_speed, stability_score = self._POSTURE_METRICS.get(
            posture, self._DEFAULT_POSTURE_METRICS
        )

        return {
            "person_count": len(persons),
            "activity_level": activity_level,
            "posture": posture,
            "movement_speed": movement_speed,
            "stability_score": stability_score,
            "confidence": person.get("confidence", 0.0)
        }

    async def _detect_anomalies(self, current_metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> Optional[HealthcareAlert]:
        """Detect health-related anomalies; return the first matching alert.

        Priority: fall > prolonged inactivity > movement instability.
        """
        # Fall detection.
        if current_metrics["posture"] == "fallen":
            return await self._generate_fall_alert(current_metrics, pose_data)

        # Prolonged inactivity: mean activity over the recent window.
        if len(self.activity_history) >= self.INACTIVITY_WINDOW:
            recent = [m["activity_level"] for m in self.activity_history[-self.INACTIVITY_WINDOW:]]
            if sum(recent) / len(recent) < self.INACTIVITY_THRESHOLD:
                return await self._generate_inactivity_alert(current_metrics, pose_data)

        # Unusual movement patterns.
        if current_metrics["stability_score"] < self.INSTABILITY_THRESHOLD:
            return await self._generate_instability_alert(current_metrics, pose_data)

        return None

    def _build_alert(self, prefix: str, alert_type: str, severity: AlertSeverity,
                     description: str, metadata: Dict[str, Any],
                     metrics: Dict[str, Any]) -> HealthcareAlert:
        """Assemble a HealthcareAlert with the shared id/timestamp fields."""
        # NOTE(review): utcnow() is naive, so .timestamp() interprets it in
        # local time; the ids are still unique, which is all this mock needs.
        now = datetime.utcnow()
        return HealthcareAlert(
            alert_id=f"{prefix}_{self.patient_id}_{int(now.timestamp())}",
            timestamp=now,
            alert_type=alert_type,
            severity=severity,
            patient_id=self.patient_id,
            location=self.room_id,
            confidence=metrics.get("confidence", 0.0),
            description=description,
            metadata=metadata
        )

    async def _generate_fall_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> HealthcareAlert:
        """Generate fall detection alert (CRITICAL severity)."""
        return self._build_alert(
            "fall",
            "fall_detected",
            AlertSeverity.CRITICAL,
            f"Fall detected for patient {self.patient_id} in {self.room_id}",
            {
                "posture": metrics["posture"],
                "stability_score": metrics["stability_score"],
                "pose_data": pose_data
            },
            metrics
        )

    async def _generate_inactivity_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> HealthcareAlert:
        """Generate prolonged inactivity alert (MEDIUM severity)."""
        return self._build_alert(
            "inactivity",
            "prolonged_inactivity",
            AlertSeverity.MEDIUM,
            f"Prolonged inactivity detected for patient {self.patient_id}",
            {
                "activity_level": metrics["activity_level"],
                "duration_minutes": 10,  # fixed by the mock, not measured
                "pose_data": pose_data
            },
            metrics
        )

    async def _generate_instability_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> HealthcareAlert:
        """Generate movement instability alert (HIGH severity)."""
        return self._build_alert(
            "instability",
            "movement_instability",
            AlertSeverity.HIGH,
            f"Movement instability detected for patient {self.patient_id}",
            {
                "stability_score": metrics["stability_score"],
                "movement_speed": metrics["movement_speed"],
                "pose_data": pose_data
            },
            metrics
        )

    def get_monitoring_stats(self) -> Dict[str, Any]:
        """Get monitoring statistics as a plain dict."""
        # Count alert types in a single O(n) pass (was a quadratic
        # comprehension that re-scanned the list once per alert).
        alert_type_counts: Dict[str, int] = {}
        for alert in self.alerts_generated:
            alert_type_counts[alert.alert_type] = alert_type_counts.get(alert.alert_type, 0) + 1

        return {
            "patient_id": self.patient_id,
            "room_id": self.room_id,
            "is_monitoring": self.is_monitoring,
            "total_alerts": len(self.alerts_generated),
            "alert_types": alert_type_counts,
            "activity_samples": len(self.activity_history),
            "fall_detection_enabled": self.fall_detection_enabled
        }
|
||||
|
||||
|
||||
class MockHealthcareNotificationSystem:
    """Mock healthcare notification system.

    Routes alerts to channels according to severity-based escalation
    rules and records every successful delivery.
    """

    def __init__(self):
        # Log of successful deliveries: alert_id/channel/timestamp/severity.
        self.notifications_sent = []
        # Channel name -> whether the channel is currently enabled.
        self.notification_channels = {
            "nurse_station": True,
            "mobile_app": True,
            "email": True,
            "sms": False
        }
        # Severity -> ordered list of channels to attempt.
        self.escalation_rules = {
            AlertSeverity.CRITICAL: ["nurse_station", "mobile_app", "sms"],
            AlertSeverity.HIGH: ["nurse_station", "mobile_app"],
            AlertSeverity.MEDIUM: ["nurse_station"],
            AlertSeverity.LOW: ["mobile_app"]
        }

    async def send_alert_notification(self, alert: HealthcareAlert) -> Dict[str, bool]:
        """Send the alert through the channels its severity escalates to.

        Returns a per-channel success map; disabled channels are skipped
        and therefore absent from the result.
        """
        targets = self.escalation_rules.get(alert.severity, ["nurse_station"])
        delivery_results = {}

        for channel_name in targets:
            if not self.notification_channels.get(channel_name, False):
                continue  # channel disabled: skip and omit from results

            delivered = await self._send_to_channel(channel_name, alert)
            delivery_results[channel_name] = delivered

            if delivered:
                self.notifications_sent.append({
                    "alert_id": alert.alert_id,
                    "channel": channel_name,
                    "timestamp": datetime.utcnow(),
                    "severity": alert.severity.value
                })

        return delivery_results

    async def _send_to_channel(self, channel: str, alert: HealthcareAlert) -> bool:
        """Deliver to a single channel; True on (simulated) success."""
        # Simulate network delay.
        await asyncio.sleep(0.01)
        # Simulate occasional failures: ~5% of sends fail at random.
        return np.random.random() >= 0.05

    def get_notification_stats(self) -> Dict[str, Any]:
        """Get notification statistics as a plain dict."""
        by_channel = {}
        for channel in self.notification_channels:
            by_channel[channel] = sum(
                1 for record in self.notifications_sent if record["channel"] == channel
            )

        by_severity = {}
        for severity in AlertSeverity:
            by_severity[severity.value] = sum(
                1 for record in self.notifications_sent if record["severity"] == severity.value
            )

        return {
            "total_notifications": len(self.notifications_sent),
            "notifications_by_channel": by_channel,
            "notifications_by_severity": by_severity
        }
|
||||
|
||||
|
||||
class TestHealthcareFallDetection:
    """Test healthcare fall detection workflow.

    Exercises the single-patient path: start monitoring, feed pose data,
    verify alert generation and notification fan-out.

    NOTE(review): MockHealthcareNotificationSystem._send_to_channel fails
    ~5% of the time at random, so the notification assertions below hold
    probabilistically rather than deterministically.
    """

    @pytest.fixture
    def patient_monitor(self):
        """Create patient monitor."""
        return MockPatientMonitor("patient_001", "room_101")

    @pytest.fixture
    def notification_system(self):
        """Create notification system."""
        return MockHealthcareNotificationSystem()

    @pytest.fixture
    def fall_pose_data(self):
        """Create pose data indicating a fall (explicit activity + wide bbox)."""
        return {
            "persons": [
                {
                    "person_id": "patient_001",
                    "confidence": 0.92,
                    "bounding_box": {"x": 200, "y": 400, "width": 150, "height": 80},  # Horizontal position
                    "activity": "fallen",
                    "keypoints": [[x, y, 0.8] for x, y in zip(range(17), range(17))]
                }
            ],
            "zone_summary": {"room_101": 1},
            "timestamp": datetime.utcnow().isoformat()
        }

    @pytest.fixture
    def normal_pose_data(self):
        """Create normal (standing, upright bbox) pose data."""
        return {
            "persons": [
                {
                    "person_id": "patient_001",
                    "confidence": 0.88,
                    "bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
                    "activity": "standing",
                    "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
                }
            ],
            "zone_summary": {"room_101": 1},
            "timestamp": datetime.utcnow().isoformat()
        }

    @pytest.mark.asyncio
    async def test_fall_detection_workflow_should_fail_initially(self, patient_monitor, notification_system, fall_pose_data):
        """Test fall detection workflow - should fail initially."""
        # Start monitoring
        result = await patient_monitor.start_monitoring()

        # This will fail initially
        assert result is True
        assert patient_monitor.is_monitoring is True

        # Process fall pose data
        alert = await patient_monitor.process_pose_data(fall_pose_data)

        # Should generate fall alert
        assert alert is not None
        assert alert.alert_type == "fall_detected"
        assert alert.severity == AlertSeverity.CRITICAL
        assert alert.patient_id == "patient_001"

        # Send notification
        notification_results = await notification_system.send_alert_notification(alert)

        # Should notify appropriate channels
        assert len(notification_results) > 0
        assert any(notification_results.values())  # At least one channel should succeed

        # Check statistics
        monitor_stats = patient_monitor.get_monitoring_stats()
        assert monitor_stats["total_alerts"] == 1

        notification_stats = notification_system.get_notification_stats()
        assert notification_stats["total_notifications"] > 0

    @pytest.mark.asyncio
    async def test_normal_activity_monitoring_should_fail_initially(self, patient_monitor, normal_pose_data):
        """Test normal activity monitoring - should fail initially."""
        await patient_monitor.start_monitoring()

        # Process multiple normal pose data samples
        alerts_generated = []

        for i in range(10):
            alert = await patient_monitor.process_pose_data(normal_pose_data)
            if alert:
                alerts_generated.append(alert)

        # This will fail initially
        # Should not generate alerts for normal activity
        assert len(alerts_generated) == 0

        # Should have activity history
        stats = patient_monitor.get_monitoring_stats()
        assert stats["activity_samples"] == 10
        assert stats["is_monitoring"] is True

    @pytest.mark.asyncio
    async def test_prolonged_inactivity_detection_should_fail_initially(self, patient_monitor):
        """Test prolonged inactivity detection - should fail initially."""
        await patient_monitor.start_monitoring()

        # Simulate prolonged inactivity: empty-room samples yield
        # activity_level 0.0, below the monitor's inactivity threshold.
        inactive_pose_data = {
            "persons": [],  # No person detected
            "zone_summary": {"room_101": 0},
            "timestamp": datetime.utcnow().isoformat()
        }

        alerts_generated = []

        # Process multiple inactive samples (more than the 10-sample window).
        for i in range(15):
            alert = await patient_monitor.process_pose_data(inactive_pose_data)
            if alert:
                alerts_generated.append(alert)

        # This will fail initially
        # Should generate inactivity alert after sufficient samples
        inactivity_alerts = [a for a in alerts_generated if a.alert_type == "prolonged_inactivity"]
        assert len(inactivity_alerts) > 0

        # Check alert properties
        alert = inactivity_alerts[0]
        assert alert.severity == AlertSeverity.MEDIUM
        assert alert.patient_id == "patient_001"

    @pytest.mark.asyncio
    async def test_movement_instability_detection_should_fail_initially(self, patient_monitor):
        """Test movement instability detection - should fail initially.

        NOTE(review): "walking" maps to stability_score 0.7 in the monitor,
        which is above the 0.4 instability threshold, so no alert fires and
        the guarded assertions below never execute. The test as written can
        only pass vacuously.
        """
        await patient_monitor.start_monitoring()

        # Simulate unstable movement
        unstable_pose_data = {
            "persons": [
                {
                    "person_id": "patient_001",
                    "confidence": 0.65,  # Lower confidence indicates instability
                    "bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
                    "activity": "walking",
                    "keypoints": [[x, y, 0.5] for x, y in zip(range(17), range(17))]  # Low keypoint confidence
                }
            ],
            "zone_summary": {"room_101": 1},
            "timestamp": datetime.utcnow().isoformat()
        }

        # Process unstable pose data
        alert = await patient_monitor.process_pose_data(unstable_pose_data)

        # This will fail initially
        # May generate instability alert based on stability score
        if alert and alert.alert_type == "movement_instability":
            assert alert.severity == AlertSeverity.HIGH
            assert alert.patient_id == "patient_001"
            assert "stability_score" in alert.metadata
|
||||
|
||||
|
||||
class TestHealthcareMultiPatientMonitoring:
    """Test multi-patient monitoring scenarios.

    Runs three independent MockPatientMonitor instances (one per room)
    against a shared notification system.
    """

    @pytest.fixture
    def multi_patient_setup(self):
        """Create multi-patient monitoring setup."""
        patients = {
            "patient_001": MockPatientMonitor("patient_001", "room_101"),
            "patient_002": MockPatientMonitor("patient_002", "room_102"),
            "patient_003": MockPatientMonitor("patient_003", "room_103")
        }

        notification_system = MockHealthcareNotificationSystem()

        return patients, notification_system

    @pytest.mark.asyncio
    async def test_concurrent_patient_monitoring_should_fail_initially(self, multi_patient_setup):
        """Test concurrent patient monitoring - should fail initially."""
        patients, notification_system = multi_patient_setup

        # Start monitoring for all patients
        start_results = []
        for patient_id, monitor in patients.items():
            result = await monitor.start_monitoring()
            start_results.append(result)

        # This will fail initially
        assert all(start_results)
        assert all(monitor.is_monitoring for monitor in patients.values())

        # Simulate concurrent pose data processing: one "standing" sample
        # per patient, tagged with that patient's room.
        pose_data_samples = [
            {
                "persons": [
                    {
                        "person_id": patient_id,
                        "confidence": 0.85,
                        "bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
                        "activity": "standing"
                    }
                ],
                "zone_summary": {f"room_{101 + i}": 1},
                "timestamp": datetime.utcnow().isoformat()
            }
            for i, patient_id in enumerate(patients.keys())
        ]

        # Process data for all patients concurrently
        tasks = []
        for (patient_id, monitor), pose_data in zip(patients.items(), pose_data_samples):
            task = asyncio.create_task(monitor.process_pose_data(pose_data))
            tasks.append(task)

        alerts = await asyncio.gather(*tasks)

        # Check results: gather preserves order, one entry (possibly None)
        # per patient.
        assert len(alerts) == len(patients)

        # Get statistics for all patients
        all_stats = {}
        for patient_id, monitor in patients.items():
            all_stats[patient_id] = monitor.get_monitoring_stats()

        assert len(all_stats) == 3
        assert all(stats["is_monitoring"] for stats in all_stats.values())

    @pytest.mark.asyncio
    async def test_alert_prioritization_should_fail_initially(self, multi_patient_setup):
        """Test alert prioritization across patients - should fail initially."""
        patients, notification_system = multi_patient_setup

        # Start monitoring
        for monitor in patients.values():
            await monitor.start_monitoring()

        # Generate different severity alerts
        alert_scenarios = [
            ("patient_001", "fall_detected", AlertSeverity.CRITICAL),
            ("patient_002", "prolonged_inactivity", AlertSeverity.MEDIUM),
            ("patient_003", "movement_instability", AlertSeverity.HIGH)
        ]

        generated_alerts = []

        for patient_id, alert_type, expected_severity in alert_scenarios:
            # Create appropriate pose data for each scenario.
            # NOTE(review): room_id already contains the "room_" prefix, so
            # the zone_summary key below becomes e.g. "room_room_101";
            # harmless here because the monitor never reads zone_summary,
            # but probably unintended.
            if alert_type == "fall_detected":
                pose_data = {
                    "persons": [{"person_id": patient_id, "confidence": 0.9, "activity": "fallen"}],
                    "zone_summary": {f"room_{patients[patient_id].room_id}": 1}
                }
            else:
                pose_data = {
                    "persons": [{"person_id": patient_id, "confidence": 0.7, "activity": "standing"}],
                    "zone_summary": {f"room_{patients[patient_id].room_id}": 1}
                }

            alert = await patients[patient_id].process_pose_data(pose_data)
            if alert:
                generated_alerts.append(alert)

        # This will fail initially
        # Should have generated alerts
        assert len(generated_alerts) > 0

        # Send notifications for all alerts
        notification_tasks = [
            notification_system.send_alert_notification(alert)
            for alert in generated_alerts
        ]

        notification_results = await asyncio.gather(*notification_tasks)

        # Check notification prioritization
        notification_stats = notification_system.get_notification_stats()
        assert notification_stats["total_notifications"] > 0

        # Critical alerts should use more channels
        critical_notifications = [
            n for n in notification_system.notifications_sent
            if n["severity"] == "critical"
        ]

        if critical_notifications:
            # Critical alerts should be sent to multiple channels
            critical_channels = set(n["channel"] for n in critical_notifications)
            assert len(critical_channels) >= 1
|
||||
|
||||
|
||||
class TestHealthcareSystemIntegration:
    """Test healthcare system integration scenarios."""

    @pytest.mark.asyncio
    async def test_end_to_end_healthcare_workflow_should_fail_initially(self):
        """Test complete end-to-end healthcare workflow - should fail initially."""
        # Setup complete healthcare monitoring system: a small orchestrator
        # that owns one MockPatientMonitor per patient plus a shared
        # notification system.
        class HealthcareMonitoringSystem:
            def __init__(self):
                self.patient_monitors = {}      # patient_id -> MockPatientMonitor
                self.notification_system = MockHealthcareNotificationSystem()
                self.alert_history = []         # every alert any monitor raised
                self.system_status = "operational"

            async def add_patient(self, patient_id: str, room_id: str) -> bool:
                """Add patient to monitoring system; False on duplicate id."""
                if patient_id in self.patient_monitors:
                    return False

                monitor = MockPatientMonitor(patient_id, room_id)
                self.patient_monitors[patient_id] = monitor
                return await monitor.start_monitoring()

            async def process_pose_update(self, room_id: str, pose_data: Dict[str, Any]) -> List[HealthcareAlert]:
                """Process pose update for every patient assigned to room_id."""
                alerts = []

                # Find patients in this room
                room_patients = [
                    (patient_id, monitor) for patient_id, monitor in self.patient_monitors.items()
                    if monitor.room_id == room_id
                ]

                for patient_id, monitor in room_patients:
                    alert = await monitor.process_pose_data(pose_data)
                    if alert:
                        alerts.append(alert)
                        self.alert_history.append(alert)

                        # Send notification
                        await self.notification_system.send_alert_notification(alert)

                return alerts

            def get_system_status(self) -> Dict[str, Any]:
                """Get overall system status."""
                return {
                    "system_status": self.system_status,
                    "total_patients": len(self.patient_monitors),
                    "active_monitors": sum(1 for m in self.patient_monitors.values() if m.is_monitoring),
                    "total_alerts": len(self.alert_history),
                    "notification_stats": self.notification_system.get_notification_stats()
                }

        healthcare_system = HealthcareMonitoringSystem()

        # Add patients to system
        patients = [
            ("patient_001", "room_101"),
            ("patient_002", "room_102"),
            ("patient_003", "room_103")
        ]

        for patient_id, room_id in patients:
            result = await healthcare_system.add_patient(patient_id, room_id)
            assert result is True

        # Simulate pose data updates for different rooms
        pose_updates = [
            ("room_101", {
                "persons": [{"person_id": "patient_001", "confidence": 0.9, "activity": "fallen"}],
                "zone_summary": {"room_101": 1}
            }),
            ("room_102", {
                "persons": [{"person_id": "patient_002", "confidence": 0.8, "activity": "standing"}],
                "zone_summary": {"room_102": 1}
            }),
            ("room_103", {
                "persons": [],  # No person detected
                "zone_summary": {"room_103": 0}
            })
        ]

        all_alerts = []
        for room_id, pose_data in pose_updates:
            alerts = await healthcare_system.process_pose_update(room_id, pose_data)
            all_alerts.extend(alerts)

        # This will fail initially
        # Should have processed all updates
        # NOTE(review): this asserts the length of a constant literal, not
        # anything about processing — it can never fail.
        assert len(pose_updates) == 3

        # Check system status
        system_status = healthcare_system.get_system_status()
        assert system_status["total_patients"] == 3
        assert system_status["active_monitors"] == 3
        assert system_status["system_status"] == "operational"

        # Should have generated some alerts
        if all_alerts:
            assert len(all_alerts) > 0
            assert system_status["total_alerts"] > 0

    @pytest.mark.asyncio
    async def test_healthcare_system_resilience_should_fail_initially(self):
        """Test healthcare system resilience - should fail initially."""
        patient_monitor = MockPatientMonitor("patient_001", "room_101")
        notification_system = MockHealthcareNotificationSystem()

        await patient_monitor.start_monitoring()

        # Simulate system stress with rapid pose updates
        rapid_updates = 50
        alerts_generated = []

        for i in range(rapid_updates):
            # Alternate between normal and concerning pose data
            if i % 10 == 0:  # Every 10th update is concerning
                pose_data = {
                    "persons": [{"person_id": "patient_001", "confidence": 0.9, "activity": "fallen"}],
                    "zone_summary": {"room_101": 1}
                }
            else:
                pose_data = {
                    "persons": [{"person_id": "patient_001", "confidence": 0.85, "activity": "standing"}],
                    "zone_summary": {"room_101": 1}
                }

            alert = await patient_monitor.process_pose_data(pose_data)
            if alert:
                alerts_generated.append(alert)
                await notification_system.send_alert_notification(alert)

        # This will fail initially
        # System should handle rapid updates gracefully
        stats = patient_monitor.get_monitoring_stats()
        assert stats["activity_samples"] == rapid_updates
        assert stats["is_monitoring"] is True

        # Should have generated some alerts but not excessive
        assert len(alerts_generated) <= rapid_updates / 5  # At most 20% alert rate

        notification_stats = notification_system.get_notification_stats()
        assert notification_stats["total_notifications"] >= len(alerts_generated)
|
||||
661
v1/tests/fixtures/api_client.py
vendored
Normal file
661
v1/tests/fixtures/api_client.py
vendored
Normal file
@@ -0,0 +1,661 @@
|
||||
"""
|
||||
Test client utilities for API testing.
|
||||
|
||||
Provides mock and real API clients for comprehensive testing.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import random
import time
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Any, List, Optional, Union, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock

import aiohttp
import jwt
import websockets
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
    """Raised when authentication with the API fails."""
|
||||
|
||||
|
||||
class APIError(Exception):
    """Raised for general API failures not covered by a specific error."""
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Raised when the (simulated) request rate limit is exceeded."""
|
||||
|
||||
|
||||
@dataclass
class APIResponse:
    """API response wrapper returned by the client's request helpers."""

    status_code: int            # HTTP status of the (mock) response
    data: Dict[str, Any]        # response payload as a dict
    headers: Dict[str, str]     # response headers (may be empty for mocks)
    response_time_ms: float     # wall-clock duration of the call, incl. simulated delay
    timestamp: datetime         # when the response object was created (naive UTC)
|
||||
|
||||
|
||||
class MockAPIClient:
|
||||
"""Mock API client for testing."""
|
||||
|
||||
def __init__(self, base_url: str = "http://localhost:8000"):
    """Initialize the mock client.

    Args:
        base_url: Base URL recorded for the simulated API.
    """
    self.base_url = base_url
    self.session = None            # aiohttp.ClientSession, created in connect()
    self.auth_token = None         # set by auth flows (not shown in this chunk)
    self.refresh_token = None      # companion refresh token, if any
    self.token_expires_at = None   # expiry for auth_token, if tracked
    self.request_history = []      # record of every _make_request call
    self.response_delays = {}      # endpoint -> artificial latency in ms
    self.error_simulation = {}     # endpoint -> {"type", "probability"}
    # Sliding one-minute request window consumed by _check_rate_limit.
    self.rate_limit_config = {
        "enabled": False,
        "requests_per_minute": 60,
        "current_count": 0,
        "window_start": time.time()
    }
|
||||
|
||||
async def __aenter__(self):
    """Async context manager entry: open the session and return self."""
    await self.connect()
    return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context manager exit: close the session; exceptions propagate."""
    await self.disconnect()
|
||||
|
||||
async def connect(self):
    """Initialize the connection by creating the aiohttp session."""
    self.session = aiohttp.ClientSession()
|
||||
|
||||
async def disconnect(self):
    """Close the connection, if one was opened."""
    if self.session:
        await self.session.close()
|
||||
|
||||
def set_response_delay(self, endpoint: str, delay_ms: float):
    """Set an artificial delay (milliseconds) for requests to *endpoint*."""
    self.response_delays[endpoint] = delay_ms
|
||||
|
||||
def simulate_error(self, endpoint: str, error_type: str, probability: float = 1.0):
    """Configure a simulated failure for *endpoint*.

    Args:
        endpoint: Endpoint path the failure applies to.
        error_type: Failure kind understood by _check_error_simulation
            ("timeout", "connection" or "server_error").
        probability: Chance in [0, 1] that a given request actually fails.
    """
    failure_config = {
        "type": error_type,
        "probability": probability
    }
    self.error_simulation[endpoint] = failure_config
|
||||
|
||||
def enable_rate_limiting(self, requests_per_minute: int = 60):
    """Enable rate limiting simulation with a fresh one-minute window."""
    fresh_window = {
        "enabled": True,
        "requests_per_minute": requests_per_minute,
        "current_count": 0,
        "window_start": time.time()
    }
    # Mutate the existing dict so any external references stay valid.
    self.rate_limit_config.update(fresh_window)
|
||||
|
||||
async def _check_rate_limit(self):
|
||||
"""Check rate limiting."""
|
||||
if not self.rate_limit_config["enabled"]:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
window_duration = 60 # 1 minute
|
||||
|
||||
# Reset window if needed
|
||||
if current_time - self.rate_limit_config["window_start"] > window_duration:
|
||||
self.rate_limit_config["current_count"] = 0
|
||||
self.rate_limit_config["window_start"] = current_time
|
||||
|
||||
# Check limit
|
||||
if self.rate_limit_config["current_count"] >= self.rate_limit_config["requests_per_minute"]:
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
|
||||
self.rate_limit_config["current_count"] += 1
|
||||
|
||||
async def _simulate_network_delay(self, endpoint: str):
|
||||
"""Simulate network delay."""
|
||||
delay = self.response_delays.get(endpoint, 0)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay / 1000) # Convert ms to seconds
|
||||
|
||||
async def _check_error_simulation(self, endpoint: str):
|
||||
"""Check if error should be simulated."""
|
||||
if endpoint in self.error_simulation:
|
||||
config = self.error_simulation[endpoint]
|
||||
if random.random() < config["probability"]:
|
||||
error_type = config["type"]
|
||||
if error_type == "timeout":
|
||||
raise asyncio.TimeoutError("Simulated timeout")
|
||||
elif error_type == "connection":
|
||||
raise aiohttp.ClientConnectionError("Simulated connection error")
|
||||
elif error_type == "server_error":
|
||||
raise APIError("Simulated server error")
|
||||
|
||||
async def _make_request(self, method: str, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make HTTP request with simulation."""
|
||||
start_time = time.time()
|
||||
|
||||
# Check rate limiting
|
||||
await self._check_rate_limit()
|
||||
|
||||
# Simulate network delay
|
||||
await self._simulate_network_delay(endpoint)
|
||||
|
||||
# Check error simulation
|
||||
await self._check_error_simulation(endpoint)
|
||||
|
||||
# Record request
|
||||
request_record = {
|
||||
"method": method,
|
||||
"endpoint": endpoint,
|
||||
"timestamp": datetime.utcnow(),
|
||||
"kwargs": kwargs
|
||||
}
|
||||
self.request_history.append(request_record)
|
||||
|
||||
# Generate mock response
|
||||
response_data = await self._generate_mock_response(method, endpoint, kwargs)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = (end_time - start_time) * 1000
|
||||
|
||||
return APIResponse(
|
||||
status_code=response_data["status_code"],
|
||||
data=response_data["data"],
|
||||
headers=response_data.get("headers", {}),
|
||||
response_time_ms=response_time,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
|
||||
    async def _generate_mock_response(self, method: str, endpoint: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a canned response dict for a known (method, endpoint) pair.

        Returns a dict with at least ``status_code`` and ``data`` keys.
        Any unknown endpoint — or a known endpoint hit with the wrong HTTP
        method — falls through to the 404 default at the bottom.

        Side effects: the two auth endpoints store freshly minted tokens on
        ``self`` (``auth_token``, ``refresh_token``, ``token_expires_at``).
        ``kwargs`` is accepted for signature symmetry but not inspected here.
        """
        if endpoint == "/health":
            return {
                "status_code": 200,
                "data": {
                    "status": "healthy",
                    "timestamp": datetime.utcnow().isoformat(),
                    "version": "1.0.0"
                }
            }

        elif endpoint == "/auth/login":
            if method == "POST":
                # Generate mock JWT tokens (fixed "secret" key — test-only).
                payload = {
                    "user_id": "test_user",
                    "exp": datetime.utcnow() + timedelta(hours=1)
                }
                access_token = jwt.encode(payload, "secret", algorithm="HS256")
                refresh_token = jwt.encode({"user_id": "test_user"}, "secret", algorithm="HS256")

                # Persist tokens so is_authenticated()/refresh work afterwards.
                self.auth_token = access_token
                self.refresh_token = refresh_token
                self.token_expires_at = payload["exp"]

                return {
                    "status_code": 200,
                    "data": {
                        "access_token": access_token,
                        "refresh_token": refresh_token,
                        "token_type": "bearer",
                        "expires_in": 3600
                    }
                }

        elif endpoint == "/auth/refresh":
            # Only refresh when a refresh token is already held.
            if method == "POST" and self.refresh_token:
                # Generate new access token
                payload = {
                    "user_id": "test_user",
                    "exp": datetime.utcnow() + timedelta(hours=1)
                }
                access_token = jwt.encode(payload, "secret", algorithm="HS256")

                self.auth_token = access_token
                self.token_expires_at = payload["exp"]

                return {
                    "status_code": 200,
                    "data": {
                        "access_token": access_token,
                        "token_type": "bearer",
                        "expires_in": 3600
                    }
                }

        elif endpoint == "/pose/detect":
            if method == "POST":
                # Single fixed detection; keypoints are 17 (x, y, score)
                # triples with placeholder diagonal coordinates.
                return {
                    "status_code": 200,
                    "data": {
                        "persons": [
                            {
                                "person_id": "person_1",
                                "confidence": 0.85,
                                "bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
                                "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))],
                                "activity": "standing"
                            }
                        ],
                        "processing_time_ms": 45.2,
                        "model_version": "v1.0",
                        "timestamp": datetime.utcnow().isoformat()
                    }
                }

        elif endpoint == "/config":
            if method == "GET":
                return {
                    "status_code": 200,
                    "data": {
                        "model_config": {
                            "confidence_threshold": 0.7,
                            "nms_threshold": 0.5,
                            "max_persons": 10
                        },
                        "processing_config": {
                            "batch_size": 1,
                            "use_gpu": True,
                            "preprocessing": "standard"
                        }
                    }
                }

        # Default response (unknown endpoint or unsupported method)
        return {
            "status_code": 404,
            "data": {"error": "Endpoint not found"}
        }
    async def get(self, endpoint: str, **kwargs) -> APIResponse:
        """Make a GET request; thin wrapper around ``_make_request``."""
        return await self._make_request("GET", endpoint, **kwargs)
    async def post(self, endpoint: str, **kwargs) -> APIResponse:
        """Make a POST request; thin wrapper around ``_make_request``."""
        return await self._make_request("POST", endpoint, **kwargs)
    async def put(self, endpoint: str, **kwargs) -> APIResponse:
        """Make a PUT request; thin wrapper around ``_make_request``."""
        return await self._make_request("PUT", endpoint, **kwargs)
    async def delete(self, endpoint: str, **kwargs) -> APIResponse:
        """Make a DELETE request; thin wrapper around ``_make_request``."""
        return await self._make_request("DELETE", endpoint, **kwargs)
async def login(self, username: str, password: str) -> bool:
|
||||
"""Authenticate with API."""
|
||||
response = await self.post("/auth/login", json={
|
||||
"username": username,
|
||||
"password": password
|
||||
})
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise AuthenticationError("Login failed")
|
||||
|
||||
async def refresh_auth_token(self) -> bool:
|
||||
"""Refresh authentication token."""
|
||||
if not self.refresh_token:
|
||||
raise AuthenticationError("No refresh token available")
|
||||
|
||||
response = await self.post("/auth/refresh", json={
|
||||
"refresh_token": self.refresh_token
|
||||
})
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise AuthenticationError("Token refresh failed")
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if client is authenticated."""
|
||||
if not self.auth_token or not self.token_expires_at:
|
||||
return False
|
||||
|
||||
return datetime.utcnow() < self.token_expires_at
|
||||
|
||||
def get_request_history(self) -> List[Dict[str, Any]]:
|
||||
"""Get request history."""
|
||||
return self.request_history.copy()
|
||||
|
||||
def clear_request_history(self):
|
||||
"""Clear request history."""
|
||||
self.request_history.clear()
|
||||
|
||||
|
||||
class MockWebSocketClient:
    """Mock WebSocket client for testing.

    Simulates a WebSocket connection entirely in-process: sent messages are
    recorded, and (optionally) a canned auto-response is queued onto the
    receive buffer so tests can exercise request/response flows without a
    server.
    """

    def __init__(self, uri: str = "ws://localhost:8000/ws"):
        self.uri = uri
        self.websocket = None          # placeholder; no real socket is opened
        self.is_connected = False
        self.messages_received = []    # inbound queue: {"message", "timestamp"}
        self.messages_sent = []        # outbound log: {"message", "timestamp"}
        self.connection_errors = []    # stringified connect() failures
        self.auto_respond = True
        self.response_delay = 0.01  # 10ms default delay

    async def connect(self) -> bool:
        """Simulate connecting; returns True on success, False on failure."""
        try:
            # Simulate connection latency; no network I/O actually happens.
            await asyncio.sleep(0.01)
            self.is_connected = True
            return True
        except Exception as e:
            self.connection_errors.append(str(e))
            return False

    async def disconnect(self):
        """Mark the connection closed and drop the socket placeholder."""
        self.is_connected = False
        self.websocket = None

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Record *message* as sent; optionally queue an auto-response.

        Raises:
            ConnectionError: if called while disconnected.
        """
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        # Record sent message
        self.messages_sent.append({
            "message": message,
            "timestamp": datetime.utcnow()
        })

        # Auto-respond if enabled
        if self.auto_respond:
            await asyncio.sleep(self.response_delay)
            response = await self._generate_auto_response(message)
            if response:
                self.messages_received.append({
                    "message": response,
                    "timestamp": datetime.utcnow()
                })

        return True

    async def receive_message(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """Pop the oldest queued message, polling until *timeout* seconds.

        Returns None on timeout.

        Raises:
            ConnectionError: if called while disconnected.
        """
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        # Wait for message or timeout (10ms polling interval)
        start_time = time.time()
        while time.time() - start_time < timeout:
            if self.messages_received:
                return self.messages_received.pop(0)["message"]
            await asyncio.sleep(0.01)

        return None

    async def _generate_auto_response(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build the canned reply for *message*, or None for unknown types."""
        message_type = message.get("type")

        if message_type == "subscribe":
            return {
                "type": "subscription_confirmed",
                "channel": message.get("channel"),
                "timestamp": datetime.utcnow().isoformat()
            }

        elif message_type == "pose_request":
            # Fixed single-person payload; keypoints are 17 placeholder
            # (x, y, score) triples along the diagonal.
            return {
                "type": "pose_data",
                "data": {
                    "persons": [
                        {
                            "person_id": "person_1",
                            "confidence": 0.88,
                            "bounding_box": {"x": 150, "y": 200, "width": 80, "height": 180},
                            "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
                        }
                    ],
                    "timestamp": datetime.utcnow().isoformat()
                },
                "request_id": message.get("request_id")
            }

        elif message_type == "ping":
            return {
                "type": "pong",
                "timestamp": datetime.utcnow().isoformat()
            }

        return None

    def set_auto_respond(self, enabled: bool, delay_ms: float = 10):
        """Enable/disable auto-responses and set their delay (milliseconds)."""
        self.auto_respond = enabled
        self.response_delay = delay_ms / 1000

    def inject_message(self, message: Dict[str, Any]):
        """Inject *message* into the receive queue, as if sent by a server."""
        self.messages_received.append({
            "message": message,
            "timestamp": datetime.utcnow()
        })

    def get_sent_messages(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of all sent-message records."""
        return self.messages_sent.copy()

    def get_received_messages(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of all queued received-message records."""
        return self.messages_received.copy()

    def clear_message_history(self):
        """Clear both the sent log and the receive queue, in place."""
        self.messages_sent.clear()
        self.messages_received.clear()
class APITestClient:
    """High-level test client combining HTTP and WebSocket.

    Wraps a ``MockAPIClient`` and a ``MockWebSocketClient`` behind one async
    context manager, and adds scenario helpers (streaming, concurrency/load
    simulation, error/rate-limit configuration, performance metrics).
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        # Derive the WS endpoint from the HTTP base URL (http->ws, https->wss
        # only by string replacement of the first "http").
        self.ws_url = base_url.replace("http", "ws") + "/ws"
        self.http_client = MockAPIClient(base_url)
        self.ws_client = MockWebSocketClient(self.ws_url)
        self.test_session_id = None  # assigned in setup()

    async def __aenter__(self):
        """Async context manager entry: connect both clients."""
        await self.setup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect both clients."""
        await self.teardown()

    async def setup(self):
        """Connect HTTP and WS clients and mint a unique session id."""
        await self.http_client.connect()
        await self.ws_client.connect()
        self.test_session_id = f"test_session_{int(time.time())}"

    async def teardown(self):
        """Disconnect both clients (WS first, mirroring setup order)."""
        await self.ws_client.disconnect()
        await self.http_client.disconnect()

    async def authenticate(self, username: str = "test_user", password: str = "test_pass") -> bool:
        """Authenticate via the HTTP client; raises on failure."""
        return await self.http_client.login(username, password)

    async def test_health_endpoint(self) -> APIResponse:
        """GET /health and return the raw response."""
        return await self.http_client.get("/health")

    async def test_pose_detection(self, csi_data: Dict[str, Any]) -> APIResponse:
        """POST *csi_data* to /pose/detect and return the raw response."""
        return await self.http_client.post("/pose/detect", json=csi_data)

    async def test_websocket_streaming(self, duration_seconds: int = 5) -> List[Dict[str, Any]]:
        """Subscribe to the pose stream and collect messages for a while.

        Returns every message received within *duration_seconds* (includes
        the subscription confirmation when auto-respond is on).
        """
        # Subscribe to pose stream
        await self.ws_client.send_message({
            "type": "subscribe",
            "channel": "pose_stream",
            "session_id": self.test_session_id
        })

        # Collect messages for specified duration
        messages = []
        end_time = time.time() + duration_seconds

        while time.time() < end_time:
            message = await self.ws_client.receive_message(timeout=0.1)
            if message:
                messages.append(message)

        return messages

    async def simulate_concurrent_requests(self, num_requests: int = 10) -> List[APIResponse]:
        """Fire *num_requests* concurrent GET /health requests.

        NOTE(review): with ``return_exceptions=True`` the result list may
        contain raised exceptions (e.g. from error simulation) alongside
        APIResponse objects — callers should check each element.
        """
        tasks = []

        for i in range(num_requests):
            task = asyncio.create_task(self.http_client.get("/health"))
            tasks.append(task)

        responses = await asyncio.gather(*tasks, return_exceptions=True)
        return responses

    async def simulate_websocket_load(self, num_connections: int = 5, duration_seconds: int = 3) -> Dict[str, Any]:
        """Open several WS clients and have each ping for a fixed duration.

        Returns a summary dict with per-connection and total message counts.
        Clients are always disconnected, even if a sender task fails.
        """
        # Create multiple WebSocket clients
        ws_clients = []
        for i in range(num_connections):
            client = MockWebSocketClient(self.ws_url)
            await client.connect()
            ws_clients.append(client)

        # Send messages from all clients
        message_counts = []

        try:
            tasks = []
            for i, client in enumerate(ws_clients):
                task = asyncio.create_task(self._send_messages_for_duration(client, duration_seconds, i))
                tasks.append(task)

            results = await asyncio.gather(*tasks)
            message_counts = results

        finally:
            # Cleanup
            for client in ws_clients:
                await client.disconnect()

        return {
            "num_connections": num_connections,
            "duration_seconds": duration_seconds,
            "messages_per_connection": message_counts,
            "total_messages": sum(message_counts)
        }

    async def _send_messages_for_duration(self, client: MockWebSocketClient, duration: int, client_id: int) -> int:
        """Send ping messages from *client* for *duration* seconds.

        Returns the number of messages sent.
        """
        message_count = 0
        end_time = time.time() + duration

        while time.time() < end_time:
            await client.send_message({
                "type": "ping",
                "client_id": client_id,
                "message_id": message_count
            })
            message_count += 1
            await asyncio.sleep(0.1)  # 10 messages per second

        return message_count

    def configure_error_simulation(self, endpoint: str, error_type: str, probability: float = 0.1):
        """Forward error-injection config to the HTTP client."""
        self.http_client.simulate_error(endpoint, error_type, probability)

    def configure_rate_limiting(self, requests_per_minute: int = 60):
        """Forward rate-limit config to the HTTP client."""
        self.http_client.enable_rate_limiting(requests_per_minute)

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Summarize HTTP and WebSocket activity for this session.

        NOTE(review): the timing stats rely on a "response_time_ms" key in
        the request-history records; if the HTTP client does not store it
        there, every timing value here degrades to 0 — verify against
        ``MockAPIClient._make_request``.
        """
        http_history = self.http_client.get_request_history()
        ws_sent = self.ws_client.get_sent_messages()
        ws_received = self.ws_client.get_received_messages()

        # Calculate HTTP metrics
        if http_history:
            response_times = [r.get("response_time_ms", 0) for r in http_history]
            http_metrics = {
                "total_requests": len(http_history),
                "avg_response_time_ms": sum(response_times) / len(response_times),
                "min_response_time_ms": min(response_times),
                "max_response_time_ms": max(response_times)
            }
        else:
            http_metrics = {"total_requests": 0}

        # Calculate WebSocket metrics
        ws_metrics = {
            "messages_sent": len(ws_sent),
            "messages_received": len(ws_received),
            "connection_active": self.ws_client.is_connected
        }

        return {
            "session_id": self.test_session_id,
            "http_metrics": http_metrics,
            "websocket_metrics": ws_metrics,
            "timestamp": datetime.utcnow().isoformat()
        }
# Utility functions for test data generation
|
||||
def generate_test_csi_data() -> Dict[str, Any]:
    """Generate a single synthetic CSI payload for API testing.

    Returns:
        A JSON-serializable dict with a 4x64 amplitude matrix of uniform
        values in [0, 1], a matching phase matrix in [-pi, pi], and fixed
        radio metadata (5.8 GHz carrier, 80 MHz bandwidth).
    """
    # The redundant function-local `import numpy as np` was removed; `np`
    # is already imported at module level.
    return {
        "timestamp": datetime.utcnow().isoformat(),
        "router_id": "test_router_001",
        "amplitude": np.random.uniform(0, 1, (4, 64)).tolist(),
        "phase": np.random.uniform(-np.pi, np.pi, (4, 64)).tolist(),
        "frequency": 5.8e9,
        "bandwidth": 80e6,
        "num_antennas": 4,
        "num_subcarriers": 64
    }
def create_test_user_credentials() -> Dict[str, str]:
    """Return a fixed set of throwaway credentials for test logins."""
    username = "test_user"
    password = "test_password_123"
    email = "test@example.com"
    return {"username": username, "password": password, "email": email}
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1) -> bool:
    """Poll *condition_func* until it returns a truthy value or *timeout* expires.

    Accepts both sync and async callables. The result is awaited whenever it
    is awaitable, which (unlike the previous ``iscoroutinefunction`` check)
    also handles coroutine functions wrapped in e.g. ``functools.partial``.

    Args:
        condition_func: Zero-argument callable; may return a value directly
            or an awaitable producing the value.
        timeout: Maximum time to keep polling, in seconds.
        interval: Sleep between polls, in seconds.

    Returns:
        True as soon as the condition is truthy, False if the timeout
        elapsed without a truthy result.
    """
    import inspect  # local import to avoid touching module-level imports

    end_time = time.time() + timeout

    while time.time() < end_time:
        result = condition_func()
        if inspect.isawaitable(result):
            result = await result
        if result:
            return True
        await asyncio.sleep(interval)

    return False
487
v1/tests/fixtures/csi_data.py
vendored
Normal file
487
v1/tests/fixtures/csi_data.py
vendored
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
Test data generation utilities for CSI data.
|
||||
|
||||
Provides realistic CSI data samples for testing pose estimation pipeline.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import json
|
||||
import random
|
||||
|
||||
|
||||
class CSIDataGenerator:
    """Generate realistic CSI data for testing.

    Produces JSON-serializable samples with (num_antennas, num_subcarriers)
    amplitude matrices clipped to [0, 1] and phase matrices wrapped to
    [-pi, pi], for empty-room, single-person, and multi-person scenarios.
    Temporal correlation between consecutive samples of the same scenario is
    simulated by blending with the previous sample (stored on ``self``).
    """

    def __init__(self,
                 frequency: float = 5.8e9,
                 bandwidth: float = 80e6,
                 num_antennas: int = 4,
                 num_subcarriers: int = 64):
        self.frequency = frequency
        self.bandwidth = bandwidth
        self.num_antennas = num_antennas
        self.num_subcarriers = num_subcarriers
        self.sample_rate = 1000  # Hz
        self.noise_level = 0.1   # default std for add_noise()

        # Pre-computed patterns for different scenarios
        self._initialize_patterns()

    def _initialize_patterns(self):
        """Initialize CSI patterns for different scenarios.

        Each pattern holds the statistics used when drawing a sample:
        amplitude mean/std, phase variance, temporal stability (blend weight
        with the previous sample), and a characteristic movement frequency.
        """
        # Empty room pattern (baseline)
        self.empty_room_pattern = {
            "amplitude_mean": 0.3,
            "amplitude_std": 0.05,
            "phase_variance": 0.1,
            "temporal_stability": 0.95
        }

        # Single person patterns
        self.single_person_patterns = {
            "standing": {
                "amplitude_mean": 0.5,
                "amplitude_std": 0.08,
                "phase_variance": 0.2,
                "temporal_stability": 0.85,
                "movement_frequency": 0.1
            },
            "walking": {
                "amplitude_mean": 0.6,
                "amplitude_std": 0.15,
                "phase_variance": 0.4,
                "temporal_stability": 0.6,
                "movement_frequency": 2.0
            },
            "sitting": {
                "amplitude_mean": 0.4,
                "amplitude_std": 0.06,
                "phase_variance": 0.15,
                "temporal_stability": 0.9,
                "movement_frequency": 0.05
            },
            "fallen": {
                "amplitude_mean": 0.35,
                "amplitude_std": 0.04,
                "phase_variance": 0.08,
                "temporal_stability": 0.95,
                "movement_frequency": 0.02
            }
        }

        # Multi-person patterns (keyed by person count, 2-4)
        self.multi_person_patterns = {
            2: {"amplitude_multiplier": 1.4, "phase_complexity": 1.6},
            3: {"amplitude_multiplier": 1.7, "phase_complexity": 2.1},
            4: {"amplitude_multiplier": 2.0, "phase_complexity": 2.8}
        }

    def generate_empty_room_sample(self, timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Generate CSI sample for empty room.

        Args:
            timestamp: Sample time; defaults to now (UTC).

        Returns:
            Sample dict with amplitude/phase lists plus radio metadata.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        pattern = self.empty_room_pattern

        # Generate amplitude matrix
        amplitude = np.random.normal(
            pattern["amplitude_mean"],
            pattern["amplitude_std"],
            (self.num_antennas, self.num_subcarriers)
        )
        amplitude = np.clip(amplitude, 0, 1)

        # Generate phase matrix
        phase = np.random.uniform(
            -np.pi, np.pi,
            (self.num_antennas, self.num_subcarriers)
        )

        # Add temporal stability: blend with the previous empty-room sample
        # so consecutive samples are correlated.
        if hasattr(self, '_last_empty_sample'):
            stability = pattern["temporal_stability"]
            amplitude = stability * self._last_empty_sample["amplitude"] + (1 - stability) * amplitude
            phase = stability * self._last_empty_sample["phase"] + (1 - stability) * phase

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": "empty_room",
            "signal_quality": np.random.uniform(0.85, 0.95)
        }

        # Remember the raw matrices for the next call's temporal blend.
        self._last_empty_sample = {
            "amplitude": amplitude,
            "phase": phase
        }

        return sample

    def generate_single_person_sample(self,
                                      activity: str = "standing",
                                      timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Generate CSI sample for single person activity.

        Args:
            activity: One of the keys in ``single_person_patterns``
                ("standing", "walking", "sitting", "fallen").
            timestamp: Sample time; defaults to now (UTC).

        Raises:
            ValueError: for an unknown activity name.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        if activity not in self.single_person_patterns:
            raise ValueError(f"Unknown activity: {activity}")

        pattern = self.single_person_patterns[activity]

        # Generate base amplitude
        amplitude = np.random.normal(
            pattern["amplitude_mean"],
            pattern["amplitude_std"],
            (self.num_antennas, self.num_subcarriers)
        )

        # Add movement-induced variations: sinusoidal modulation at the
        # activity's characteristic frequency, driven by wall-clock time.
        movement_freq = pattern["movement_frequency"]
        time_factor = timestamp.timestamp()
        movement_modulation = 0.1 * np.sin(2 * np.pi * movement_freq * time_factor)
        amplitude += movement_modulation
        amplitude = np.clip(amplitude, 0, 1)

        # Generate phase with activity-specific variance
        phase_base = np.random.uniform(-np.pi, np.pi, (self.num_antennas, self.num_subcarriers))
        phase_variance = pattern["phase_variance"]
        phase_noise = np.random.normal(0, phase_variance, (self.num_antennas, self.num_subcarriers))
        phase = phase_base + phase_noise
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi  # Wrap to [-π, π]

        # Add temporal correlation (per-activity previous-sample blend)
        if hasattr(self, f'_last_{activity}_sample'):
            stability = pattern["temporal_stability"]
            last_sample = getattr(self, f'_last_{activity}_sample')
            amplitude = stability * last_sample["amplitude"] + (1 - stability) * amplitude
            phase = stability * last_sample["phase"] + (1 - stability) * phase

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": f"single_person_{activity}",
            "signal_quality": np.random.uniform(0.7, 0.9),
            "activity": activity
        }

        setattr(self, f'_last_{activity}_sample', {
            "amplitude": amplitude,
            "phase": phase
        })

        return sample

    def generate_multi_person_sample(self,
                                     num_persons: int = 2,
                                     activities: Optional[List[str]] = None,
                                     timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Generate CSI sample for multiple persons.

        Args:
            num_persons: Person count, must be 2-4 (keys of
                ``multi_person_patterns``).
            activities: Optional per-person activity names; randomly chosen
                when omitted. Length must equal *num_persons*.
            timestamp: Sample time; defaults to now (UTC).

        Raises:
            ValueError: for an out-of-range person count or a mismatched
                activities list.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        if num_persons < 2 or num_persons > 4:
            raise ValueError("Number of persons must be between 2 and 4")

        if activities is None:
            activities = random.choices(list(self.single_person_patterns.keys()), k=num_persons)

        if len(activities) != num_persons:
            raise ValueError("Number of activities must match number of persons")

        # Start with empty room baseline
        amplitude = np.random.normal(
            self.empty_room_pattern["amplitude_mean"],
            self.empty_room_pattern["amplitude_std"],
            (self.num_antennas, self.num_subcarriers)
        )

        phase = np.random.uniform(
            -np.pi, np.pi,
            (self.num_antennas, self.num_subcarriers)
        )

        # Add contribution from each person
        for i, activity in enumerate(activities):
            person_pattern = self.single_person_patterns[activity]

            # Generate person-specific contribution
            person_amplitude = np.random.normal(
                person_pattern["amplitude_mean"] * 0.7,  # Reduced for multi-person
                person_pattern["amplitude_std"],
                (self.num_antennas, self.num_subcarriers)
            )

            # Add spatial variation (different persons at different locations)
            # by rolling the subcarrier axis per person.
            spatial_offset = i * self.num_subcarriers // num_persons
            person_amplitude = np.roll(person_amplitude, spatial_offset, axis=1)

            # Add movement modulation
            movement_freq = person_pattern["movement_frequency"]
            time_factor = timestamp.timestamp() + i * 0.5  # Phase offset between persons
            movement_modulation = 0.05 * np.sin(2 * np.pi * movement_freq * time_factor)
            person_amplitude += movement_modulation

            amplitude += person_amplitude

            # Add phase contribution
            person_phase = np.random.normal(0, person_pattern["phase_variance"],
                                            (self.num_antennas, self.num_subcarriers))
            person_phase = np.roll(person_phase, spatial_offset, axis=1)
            phase += person_phase

        # Apply multi-person complexity
        pattern = self.multi_person_patterns[num_persons]
        amplitude *= pattern["amplitude_multiplier"]
        phase *= pattern["phase_complexity"]

        # Clip and normalize
        amplitude = np.clip(amplitude, 0, 1)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": f"multi_person_{num_persons}",
            "signal_quality": np.random.uniform(0.6, 0.8),
            "num_persons": num_persons,
            "activities": activities
        }

        return sample

    def generate_time_series(self,
                             duration_seconds: int = 10,
                             scenario: str = "single_person_walking",
                             **kwargs) -> List[Dict[str, Any]]:
        """Generate time series of CSI samples.

        NOTE: produces ``duration_seconds * sample_rate`` samples (1000/s by
        default), so long durations create large lists.

        Args:
            duration_seconds: Length of the series in seconds.
            scenario: "empty_room", "single_person_<activity>", or
                "multi_person_<n>".
            **kwargs: Forwarded to ``generate_multi_person_sample`` for
                multi-person scenarios (e.g. ``activities=...``).

        Raises:
            ValueError: for an unrecognized scenario string.
        """
        samples = []
        start_time = datetime.utcnow()

        for i in range(duration_seconds * self.sample_rate):
            timestamp = start_time + timedelta(seconds=i / self.sample_rate)

            if scenario == "empty_room":
                sample = self.generate_empty_room_sample(timestamp)
            elif scenario.startswith("single_person_"):
                activity = scenario.replace("single_person_", "")
                sample = self.generate_single_person_sample(activity, timestamp)
            elif scenario.startswith("multi_person_"):
                num_persons = int(scenario.split("_")[-1])
                sample = self.generate_multi_person_sample(num_persons, timestamp=timestamp, **kwargs)
            else:
                raise ValueError(f"Unknown scenario: {scenario}")

            samples.append(sample)

        return samples

    def add_noise(self, sample: Dict[str, Any], noise_level: Optional[float] = None) -> Dict[str, Any]:
        """Add noise to CSI sample.

        Returns a (shallow) copy of *sample* with Gaussian noise added to the
        amplitude (clipped to [0, 1]) and phase (scaled by pi, re-wrapped),
        and ``signal_quality`` reduced proportionally.
        """
        if noise_level is None:
            noise_level = self.noise_level

        noisy_sample = sample.copy()

        # Add amplitude noise
        amplitude = np.array(sample["amplitude"])
        amplitude_noise = np.random.normal(0, noise_level, amplitude.shape)
        noisy_amplitude = amplitude + amplitude_noise
        noisy_amplitude = np.clip(noisy_amplitude, 0, 1)
        noisy_sample["amplitude"] = noisy_amplitude.tolist()

        # Add phase noise
        phase = np.array(sample["phase"])
        phase_noise = np.random.normal(0, noise_level * np.pi, phase.shape)
        noisy_phase = phase + phase_noise
        noisy_phase = np.mod(noisy_phase + np.pi, 2 * np.pi) - np.pi
        noisy_sample["phase"] = noisy_phase.tolist()

        # Reduce signal quality
        noisy_sample["signal_quality"] *= (1 - noise_level)

        return noisy_sample

    def simulate_hardware_artifacts(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Simulate hardware-specific artifacts.

        Applies, in order: random inter-antenna coupling (matrix product on
        the antenna axis), a sinusoidal per-subcarrier gain ripple, and a
        linear per-subcarrier phase drift; then re-clips/re-wraps. Returns a
        (shallow) copy of *sample* with amplitude/phase replaced.
        """
        artifact_sample = sample.copy()

        amplitude = np.array(sample["amplitude"])
        phase = np.array(sample["phase"])

        # Simulate antenna coupling
        coupling_matrix = np.random.uniform(0.95, 1.05, (self.num_antennas, self.num_antennas))
        amplitude = coupling_matrix @ amplitude

        # Simulate frequency-dependent gain variations
        freq_response = 1 + 0.1 * np.sin(np.linspace(0, 2*np.pi, self.num_subcarriers))
        amplitude *= freq_response[np.newaxis, :]

        # Simulate phase drift (linear across subcarriers, random slope)
        phase_drift = np.random.uniform(-0.1, 0.1) * np.arange(self.num_subcarriers)
        phase += phase_drift[np.newaxis, :]

        # Clip and wrap
        amplitude = np.clip(amplitude, 0, 1)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi

        artifact_sample["amplitude"] = amplitude.tolist()
        artifact_sample["phase"] = phase.tolist()

        return artifact_sample
# Convenience functions for common test scenarios
|
||||
def generate_fall_detection_sequence() -> List[Dict[str, Any]]:
    """Generate a CSI sequence modelling a fall-detection scenario.

    Phases, in order: normal standing (5s), walking (3s), the fall
    transition (1s), then the person lying fallen (3s).
    """
    generator = CSIDataGenerator()

    phases = [
        (5, "single_person_standing"),  # normal standing
        (3, "single_person_walking"),   # walking
        (1, "single_person_fallen"),    # fall event (transition)
        (3, "single_person_fallen"),    # fallen state
    ]

    sequence: List[Dict[str, Any]] = []
    for seconds, scenario in phases:
        sequence.extend(generator.generate_time_series(seconds, scenario))

    return sequence
def generate_multi_person_scenario() -> List[Dict[str, Any]]:
    """Generate a CSI sequence where people progressively enter a room.

    Phases: empty room (2s), one walker (3s), two persons (5s), three
    persons (4s).
    """
    generator = CSIDataGenerator()

    sequence: List[Dict[str, Any]] = []
    # Start with empty room
    sequence += generator.generate_time_series(2, "empty_room")
    # One person enters
    sequence += generator.generate_time_series(3, "single_person_walking")
    # Second person enters
    sequence += generator.generate_time_series(
        5, "multi_person_2", activities=["standing", "walking"])
    # Third person enters
    sequence += generator.generate_time_series(
        4, "multi_person_3", activities=["standing", "walking", "sitting"])

    return sequence
def generate_noisy_environment_data() -> List[Dict[str, Any]]:
    """Produce walking-scenario CSI samples corrupted at several noise levels.

    The first 10 clean samples are duplicated once per noise level, so the
    result groups samples by increasing noise (0.05, 0.1, 0.2, 0.3).
    """
    gen = CSIDataGenerator()

    # Clean baseline: 5 seconds of a single person walking.
    clean_samples = gen.generate_time_series(5, "single_person_walking")

    noise_levels = [0.05, 0.1, 0.2, 0.3]
    # Corrupt the first 10 clean samples once for each noise level.
    return [
        gen.add_noise(sample, level)
        for level in noise_levels
        for sample in clean_samples[:10]
    ]
|
||||
|
||||
|
||||
def generate_hardware_test_data() -> List[Dict[str, Any]]:
    """Produce standing-scenario CSI samples with simulated hardware artifacts.

    Each of the 3 seconds of base samples is passed through the
    generator's hardware-artifact simulation.
    """
    gen = CSIDataGenerator()
    return [
        gen.simulate_hardware_artifacts(sample)
        for sample in gen.generate_time_series(3, "single_person_standing")
    ]
|
||||
|
||||
|
||||
# Test data validation utilities
|
||||
def validate_csi_sample(sample: Dict[str, Any]) -> bool:
    """Validate CSI sample structure and data ranges.

    Checks that:
      * all required metadata fields are present;
      * amplitude and phase match the advertised
        (num_antennas, num_subcarriers) shape;
      * amplitude values lie in [0, 1] and phase values in [-pi, pi].

    Returns:
        True if the sample is well-formed, False otherwise. Malformed
        payloads (ragged lists, non-numeric entries, zero-sized arrays)
        also return False instead of raising, so callers can treat any
        falsy result as "invalid sample".
    """
    required_fields = (
        "timestamp", "router_id", "amplitude", "phase",
        "frequency", "bandwidth", "num_antennas", "num_subcarriers",
    )

    # Check required fields.
    if any(field not in sample for field in required_fields):
        return False

    # Ragged or non-numeric lists make the array conversion raise; a
    # validator should report these as invalid rather than propagate.
    try:
        amplitude = np.asarray(sample["amplitude"], dtype=float)
        phase = np.asarray(sample["phase"], dtype=float)
    except (TypeError, ValueError):
        return False

    # Check shapes against the advertised antenna/subcarrier counts.
    expected_shape = (sample["num_antennas"], sample["num_subcarriers"])
    if amplitude.shape != expected_shape or phase.shape != expected_shape:
        return False

    # A zero-sized array has no min/max; treat it as invalid.
    if amplitude.size == 0:
        return False

    # Check value ranges.
    if not (0 <= amplitude.min() and amplitude.max() <= 1):
        return False

    if not (-np.pi <= phase.min() and phase.max() <= np.pi):
        return False

    return True
|
||||
|
||||
|
||||
def extract_features_from_csi(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Extract scalar statistics from a CSI sample for test assertions.

    Reads the sample's amplitude and phase matrices (antennas x
    subcarriers) and returns a flat dict of float-valued features.
    """
    amp = np.array(sample["amplitude"])
    ph = np.array(sample["phase"])

    # Per-subcarrier amplitude averaged across antennas; its spread is a
    # crude measure of frequency diversity.
    per_subcarrier_mean = np.mean(amp, axis=0)

    return {
        "amplitude_mean": float(amp.mean()),
        "amplitude_std": float(amp.std()),
        "amplitude_max": float(amp.max()),
        "amplitude_min": float(amp.min()),
        "phase_variance": float(ph.var()),
        "phase_range": float(ph.max() - ph.min()),
        "signal_energy": float(np.sum(amp ** 2)),
        # Magnitude of the mean unit phasor: 1.0 when phases are aligned.
        "phase_coherence": float(np.abs(np.mean(np.exp(1j * ph)))),
        # Mean of the antenna-by-antenna correlation matrix.
        "spatial_correlation": float(np.mean(np.corrcoef(amp))),
        "frequency_diversity": float(np.std(per_subcarrier_mean)),
    }
|
||||
338
v1/tests/integration/test_api_endpoints.py
Normal file
338
v1/tests/integration/test_api_endpoints.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Integration tests for WiFi-DensePose API endpoints.
|
||||
|
||||
Tests all REST API endpoints with real service dependencies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
import httpx
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_pose_service,
|
||||
get_stream_service,
|
||||
get_hardware_service,
|
||||
get_current_user
|
||||
)
|
||||
from src.api.routers.health import router as health_router
|
||||
from src.api.routers.pose import router as pose_router
|
||||
from src.api.routers.stream import router as stream_router
|
||||
|
||||
|
||||
class TestAPIEndpoints:
    """Integration tests for API endpoints.

    TDD red-phase suite: every test is written against the intended
    response contract before the endpoints exist, so each is expected to
    fail until the corresponding router is implemented. All backing
    services are AsyncMocks wired in via FastAPI dependency overrides.
    """

    @pytest.fixture
    def app(self):
        """Create FastAPI app with test dependencies."""
        # Mount the three routers under the same prefixes the real app uses.
        app = FastAPI()
        app.include_router(health_router, prefix="/health", tags=["health"])
        app.include_router(pose_router, prefix="/pose", tags=["pose"])
        app.include_router(stream_router, prefix="/stream", tags=["stream"])
        return app

    @pytest.fixture
    def mock_pose_service(self):
        """Mock pose service."""
        # AsyncMock so awaited service calls resolve to the canned values below.
        service = AsyncMock()
        service.health_check.return_value = {
            "status": "healthy",
            "message": "Service operational",
            "uptime_seconds": 3600.0,
            "metrics": {"processed_frames": 1000}
        }
        service.is_ready.return_value = True
        # Canned pose-estimation result with no detected persons.
        service.estimate_poses.return_value = {
            "timestamp": datetime.utcnow(),
            "frame_id": "test-frame-001",
            "persons": [],
            "zone_summary": {"zone1": 0},
            "processing_time_ms": 50.0,
            "metadata": {}
        }
        return service

    @pytest.fixture
    def mock_stream_service(self):
        """Mock stream service."""
        service = AsyncMock()
        service.health_check.return_value = {
            "status": "healthy",
            "message": "Stream service operational",
            "uptime_seconds": 1800.0
        }
        service.is_ready.return_value = True
        service.get_status.return_value = {
            "is_active": True,
            "active_streams": [],
            "uptime_seconds": 1800.0
        }
        service.is_active.return_value = True
        return service

    @pytest.fixture
    def mock_hardware_service(self):
        """Mock hardware service."""
        service = AsyncMock()
        service.health_check.return_value = {
            "status": "healthy",
            "message": "Hardware connected",
            "uptime_seconds": 7200.0,
            "metrics": {"connected_routers": 3}
        }
        service.is_ready.return_value = True
        return service

    @pytest.fixture
    def mock_user(self):
        """Mock authenticated user."""
        # Non-admin user with read/write permissions; endpoints that need
        # auth will receive this via the get_current_user override.
        return {
            "id": "test-user-001",
            "username": "testuser",
            "email": "test@example.com",
            "is_admin": False,
            "is_active": True,
            "permissions": ["read", "write"]
        }

    @pytest.fixture
    def client(self, app, mock_pose_service, mock_stream_service, mock_hardware_service, mock_user):
        """Create test client with mocked dependencies."""
        # Replace every service dependency with its mock before any request.
        app.dependency_overrides[get_pose_service] = lambda: mock_pose_service
        app.dependency_overrides[get_stream_service] = lambda: mock_stream_service
        app.dependency_overrides[get_hardware_service] = lambda: mock_hardware_service
        app.dependency_overrides[get_current_user] = lambda: mock_user

        # Context-managed client so startup/shutdown events run.
        with TestClient(app) as client:
            yield client

    def test_health_check_endpoint_should_fail_initially(self, client):
        """Test health check endpoint - should fail initially."""
        # This test should fail because we haven't implemented the endpoint properly
        response = client.get("/health/health")

        # This assertion will fail initially, driving us to implement the endpoint
        assert response.status_code == 200
        assert "status" in response.json()
        assert "components" in response.json()
        assert "system_metrics" in response.json()

    def test_readiness_check_endpoint_should_fail_initially(self, client):
        """Test readiness check endpoint - should fail initially."""
        response = client.get("/health/ready")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "ready" in data
        assert "checks" in data
        assert isinstance(data["checks"], dict)

    def test_liveness_check_endpoint_should_fail_initially(self, client):
        """Test liveness check endpoint - should fail initially."""
        response = client.get("/health/live")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert data["status"] == "alive"

    def test_version_info_endpoint_should_fail_initially(self, client):
        """Test version info endpoint - should fail initially."""
        response = client.get("/health/version")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "name" in data
        assert "version" in data
        assert "environment" in data

    def test_pose_current_endpoint_should_fail_initially(self, client):
        """Test current pose estimation endpoint - should fail initially."""
        response = client.get("/pose/current")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "timestamp" in data
        assert "frame_id" in data
        assert "persons" in data
        assert "zone_summary" in data

    def test_pose_analyze_endpoint_should_fail_initially(self, client):
        """Test pose analysis endpoint - should fail initially."""
        # Well-formed analysis request covering all optional knobs.
        request_data = {
            "zone_ids": ["zone1", "zone2"],
            "confidence_threshold": 0.7,
            "max_persons": 10,
            "include_keypoints": True,
            "include_segmentation": False
        }

        response = client.post("/pose/analyze", json=request_data)

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "timestamp" in data
        assert "persons" in data

    def test_zone_occupancy_endpoint_should_fail_initially(self, client):
        """Test zone occupancy endpoint - should fail initially."""
        response = client.get("/pose/zones/zone1/occupancy")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "zone_id" in data
        assert "current_occupancy" in data

    def test_zones_summary_endpoint_should_fail_initially(self, client):
        """Test zones summary endpoint - should fail initially."""
        response = client.get("/pose/zones/summary")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "total_persons" in data
        assert "zones" in data

    def test_stream_status_endpoint_should_fail_initially(self, client):
        """Test stream status endpoint - should fail initially."""
        response = client.get("/stream/status")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "is_active" in data
        assert "connected_clients" in data

    def test_stream_start_endpoint_should_fail_initially(self, client):
        """Test stream start endpoint - should fail initially."""
        response = client.post("/stream/start")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "message" in data

    def test_stream_stop_endpoint_should_fail_initially(self, client):
        """Test stream stop endpoint - should fail initially."""
        response = client.post("/stream/stop")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "message" in data
|
||||
|
||||
|
||||
class TestAPIErrorHandling:
    """Exercises how the API surfaces failures raised by backing services."""

    @pytest.fixture
    def app_with_failing_services(self):
        """Build an app whose pose service raises on every health check."""
        # The pose dependency blows up; /health must aggregate the failure
        # instead of propagating the exception.
        broken_pose = AsyncMock()
        broken_pose.health_check.side_effect = Exception("Service unavailable")

        app = FastAPI()
        app.include_router(health_router, prefix="/health", tags=["health"])
        app.include_router(pose_router, prefix="/pose", tags=["pose"])
        app.dependency_overrides[get_pose_service] = lambda: broken_pose

        return app

    def test_health_check_with_failing_service_should_fail_initially(self, app_with_failing_services):
        """Test health check with failing service - should fail initially."""
        with TestClient(app_with_failing_services) as client:
            reply = client.get("/health/health")

            # This will fail initially
            assert reply.status_code == 200
            body = reply.json()
            assert body["status"] == "unhealthy"
            assert "hardware" in body["components"]
            assert body["components"]["pose"]["status"] == "unhealthy"
|
||||
|
||||
|
||||
class TestAPIAuthentication:
    """Exercises endpoint access with an authenticated (admin) user."""

    @pytest.fixture
    def app_with_auth(self):
        """Build an app whose auth dependency always yields an admin user."""
        app = FastAPI()
        app.include_router(pose_router, prefix="/pose", tags=["pose"])

        # Returns a fresh user dict per resolution, mirroring a real
        # per-request auth dependency.
        def _admin_user():
            return {
                "id": "auth-user-001",
                "username": "authuser",
                "is_admin": True,
                "permissions": ["read", "write", "admin"]
            }

        app.dependency_overrides[get_current_user] = _admin_user

        return app

    def test_authenticated_endpoint_access_should_fail_initially(self, app_with_auth):
        """Test authenticated endpoint access - should fail initially."""
        request_body = {
            "confidence_threshold": 0.8,
            "include_keypoints": True
        }
        with TestClient(app_with_auth) as client:
            reply = client.post("/pose/analyze", json=request_body)

            # This will fail initially
            assert reply.status_code == 200
|
||||
|
||||
|
||||
class TestAPIValidation:
    """Exercises request-body validation on the pose endpoints."""

    @pytest.fixture
    def validation_app(self):
        """Build a minimal app with a stubbed pose service."""
        app = FastAPI()
        app.include_router(pose_router, prefix="/pose", tags=["pose"])

        # A single shared stub; validation should reject requests before
        # the service is ever touched.
        stub_service = AsyncMock()
        app.dependency_overrides[get_pose_service] = lambda: stub_service

        return app

    def test_invalid_confidence_threshold_should_fail_initially(self, validation_app):
        """Test invalid confidence threshold validation - should fail initially."""
        bad_body = {
            "confidence_threshold": 1.5,  # Invalid: > 1.0
            "include_keypoints": True
        }
        with TestClient(validation_app) as client:
            reply = client.post("/pose/analyze", json=bad_body)

            # This will fail initially
            assert reply.status_code == 422
            assert "validation error" in reply.json()["detail"][0]["msg"].lower()

    def test_invalid_max_persons_should_fail_initially(self, validation_app):
        """Test invalid max_persons validation - should fail initially."""
        bad_body = {
            "max_persons": 0,  # Invalid: < 1
            "include_keypoints": True
        }
        with TestClient(validation_app) as client:
            reply = client.post("/pose/analyze", json=bad_body)

            # This will fail initially
            assert reply.status_code == 422
|
||||
571
v1/tests/integration/test_authentication.py
Normal file
571
v1/tests/integration/test_authentication.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""
|
||||
Integration tests for authentication and authorization.
|
||||
|
||||
Tests JWT authentication flow, user permissions, and access control.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import jwt
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
|
||||
class MockJWTToken:
    """Lightweight HS256-signed test token wrapper.

    Signs *payload* at construction time so tests can read ``.token``
    directly; ``decode`` is a thin pass-through to ``jwt.decode``.
    """

    def __init__(self, payload: Dict[str, Any], secret: str = "test-secret"):
        self.secret = secret
        self.payload = payload
        # Sign immediately; the encoded form is what tests hand to APIs.
        self.token = jwt.encode(payload, secret, algorithm="HS256")

    def decode(self, token: str, secret: str) -> Dict[str, Any]:
        """Decode *token* with *secret* and return the claims dict."""
        return jwt.decode(token, secret, algorithms=["HS256"])
|
||||
|
||||
|
||||
class TestJWTAuthentication:
    """Test JWT authentication functionality.

    Uses a self-contained MockJWTService (HS256, fixed test secret) so
    the token lifecycle — create, verify, reject-expired, reject-invalid,
    refresh — can be exercised without the real auth stack.
    """

    @pytest.fixture
    def valid_user_payload(self):
        """Valid user payload for JWT token (expires in one hour)."""
        # NOTE(review): datetime.utcnow() is deprecated in 3.12+; the
        # project may want datetime.now(timezone.utc) here eventually.
        return {
            "sub": "user-001",
            "username": "testuser",
            "email": "test@example.com",
            "is_admin": False,
            "is_active": True,
            "permissions": ["read", "write"],
            "exp": datetime.utcnow() + timedelta(hours=1),
            "iat": datetime.utcnow()
        }

    @pytest.fixture
    def admin_user_payload(self):
        """Admin user payload for JWT token."""
        return {
            "sub": "admin-001",
            "username": "admin",
            "email": "admin@example.com",
            "is_admin": True,
            "is_active": True,
            "permissions": ["read", "write", "admin"],
            "exp": datetime.utcnow() + timedelta(hours=1),
            "iat": datetime.utcnow()
        }

    @pytest.fixture
    def expired_user_payload(self):
        """Expired user payload for JWT token (exp one hour in the past)."""
        return {
            "sub": "user-002",
            "username": "expireduser",
            "email": "expired@example.com",
            "is_admin": False,
            "is_active": True,
            "permissions": ["read"],
            "exp": datetime.utcnow() - timedelta(hours=1),  # Expired
            "iat": datetime.utcnow() - timedelta(hours=2)
        }

    @pytest.fixture
    def mock_jwt_service(self):
        """Mock JWT service implementing create/verify/refresh."""
        class MockJWTService:
            def __init__(self):
                self.secret = "test-secret-key"
                self.algorithm = "HS256"

            def create_token(self, user_data: Dict[str, Any]) -> str:
                """Create JWT token, stamping fresh exp/iat claims."""
                payload = {
                    **user_data,
                    "exp": datetime.utcnow() + timedelta(hours=1),
                    "iat": datetime.utcnow()
                }
                return jwt.encode(payload, self.secret, algorithm=self.algorithm)

            def verify_token(self, token: str) -> Dict[str, Any]:
                """Verify JWT token, mapping jwt errors to HTTP 401."""
                try:
                    payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
                    return payload
                except jwt.ExpiredSignatureError:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Token has expired"
                    )
                except jwt.InvalidTokenError:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Invalid token"
                    )

            def refresh_token(self, token: str) -> str:
                """Refresh JWT token: re-issue the claims with new exp/iat."""
                payload = self.verify_token(token)
                # Remove exp and iat for new token
                payload.pop("exp", None)
                payload.pop("iat", None)
                return self.create_token(payload)

        return MockJWTService()

    def test_jwt_token_creation_should_fail_initially(self, mock_jwt_service, valid_user_payload):
        """Test JWT token creation - should fail initially."""
        token = mock_jwt_service.create_token(valid_user_payload)

        # This will fail initially
        assert isinstance(token, str)
        assert len(token) > 0

        # Verify token can be decoded
        decoded = mock_jwt_service.verify_token(token)
        assert decoded["sub"] == valid_user_payload["sub"]
        assert decoded["username"] == valid_user_payload["username"]

    def test_jwt_token_verification_should_fail_initially(self, mock_jwt_service, valid_user_payload):
        """Test JWT token verification - should fail initially."""
        token = mock_jwt_service.create_token(valid_user_payload)
        decoded = mock_jwt_service.verify_token(token)

        # This will fail initially
        assert decoded["sub"] == valid_user_payload["sub"]
        assert decoded["is_admin"] == valid_user_payload["is_admin"]
        assert "exp" in decoded
        assert "iat" in decoded

    def test_expired_token_rejection_should_fail_initially(self, mock_jwt_service, expired_user_payload):
        """Test expired token rejection - should fail initially."""
        # Create token with expired payload, bypassing create_token so the
        # stale exp claim survives.
        token = jwt.encode(expired_user_payload, mock_jwt_service.secret, algorithm=mock_jwt_service.algorithm)

        # This should fail initially
        with pytest.raises(HTTPException) as exc_info:
            mock_jwt_service.verify_token(token)

        assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
        assert "expired" in exc_info.value.detail.lower()

    def test_invalid_token_rejection_should_fail_initially(self, mock_jwt_service):
        """Test invalid token rejection - should fail initially."""
        invalid_token = "invalid.jwt.token"

        # This should fail initially
        with pytest.raises(HTTPException) as exc_info:
            mock_jwt_service.verify_token(invalid_token)

        assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
        assert "invalid" in exc_info.value.detail.lower()

    def test_token_refresh_should_fail_initially(self, mock_jwt_service, valid_user_payload):
        """Test token refresh functionality - should fail initially."""
        original_token = mock_jwt_service.create_token(valid_user_payload)

        # JWT iat/exp are NumericDate claims encoded in whole Unix seconds,
        # so the refreshed token only differs once the clock crosses a
        # second boundary. A sub-second sleep made this test flaky.
        import time
        time.sleep(1.1)

        refreshed_token = mock_jwt_service.refresh_token(original_token)

        # This will fail initially
        assert refreshed_token != original_token

        # Verify both tokens are valid but have different timestamps
        original_payload = mock_jwt_service.verify_token(original_token)
        refreshed_payload = mock_jwt_service.verify_token(refreshed_token)

        assert original_payload["sub"] == refreshed_payload["sub"]
        assert original_payload["iat"] != refreshed_payload["iat"]
|
||||
|
||||
|
||||
class TestUserAuthentication:
    """Exercises username/password authentication against a mock user store."""

    @pytest.fixture
    def mock_user_service(self):
        """Mock user service backed by an in-memory username -> record dict."""
        class MockUserService:
            def __init__(self):
                self.users = {
                    "testuser": {
                        "id": "user-001",
                        "username": "testuser",
                        "email": "test@example.com",
                        "password_hash": "hashed_password",
                        "is_admin": False,
                        "is_active": True,
                        "permissions": ["read", "write"],
                        "zones": ["zone1", "zone2"],
                        "created_at": datetime.utcnow()
                    },
                    "admin": {
                        "id": "admin-001",
                        "username": "admin",
                        "email": "admin@example.com",
                        "password_hash": "admin_hashed_password",
                        "is_admin": True,
                        "is_active": True,
                        "permissions": ["read", "write", "admin"],
                        "zones": [],  # Admin has access to all zones
                        "created_at": datetime.utcnow()
                    }
                }

            async def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
                """Return the user record on a good username/password, else None."""
                record = self.users.get(username)
                # Stand-in for real password-hash verification.
                return record if record and password == "correct_password" else None

            async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
                """Return the first user whose id matches, or None."""
                return next(
                    (record for record in self.users.values() if record["id"] == user_id),
                    None,
                )

            async def update_user_activity(self, user_id: str):
                """Stamp last_activity on the user, if it exists."""
                record = await self.get_user_by_id(user_id)
                if record:
                    record["last_activity"] = datetime.utcnow()

        return MockUserService()

    @pytest.mark.asyncio
    async def test_user_authentication_success_should_fail_initially(self, mock_user_service):
        """Test successful user authentication - should fail initially."""
        account = await mock_user_service.authenticate_user("testuser", "correct_password")

        # This will fail initially
        assert account is not None
        assert account["username"] == "testuser"
        assert account["is_active"] is True
        assert "read" in account["permissions"]

    @pytest.mark.asyncio
    async def test_user_authentication_failure_should_fail_initially(self, mock_user_service):
        """Test failed user authentication - should fail initially."""
        account = await mock_user_service.authenticate_user("testuser", "wrong_password")

        # This will fail initially
        assert account is None

        # Test with non-existent user
        account = await mock_user_service.authenticate_user("nonexistent", "any_password")
        assert account is None

    @pytest.mark.asyncio
    async def test_admin_user_authentication_should_fail_initially(self, mock_user_service):
        """Test admin user authentication - should fail initially."""
        admin_account = await mock_user_service.authenticate_user("admin", "correct_password")

        # This will fail initially
        assert admin_account is not None
        assert admin_account["is_admin"] is True
        assert "admin" in admin_account["permissions"]
        assert admin_account["zones"] == []  # Admin has access to all zones
|
||||
|
||||
|
||||
class TestAuthorizationDependencies:
    """Test authorization dependency functions.

    Each test defines an in-line mock of the dependency it targets
    (get_current_user, active-user check, admin check, permission check)
    and asserts the intended contract: 401 for missing auth, 403 for
    insufficient rights, pass-through of the user dict on success.
    """

    @pytest.fixture
    def mock_request(self):
        """Mock FastAPI request."""
        # Only request.state is needed; MagicMock absorbs attribute access.
        class MockRequest:
            def __init__(self):
                self.state = MagicMock()
                self.state.user = None

        return MockRequest()

    @pytest.fixture
    def mock_credentials(self):
        """Mock HTTP authorization credentials."""
        # Factory so tests can mint Bearer credentials for arbitrary tokens.
        def create_credentials(token: str):
            return HTTPAuthorizationCredentials(
                scheme="Bearer",
                credentials=token
            )
        return create_credentials

    @pytest.mark.asyncio
    async def test_get_current_user_with_valid_token_should_fail_initially(self, mock_request, mock_credentials):
        """Test get_current_user with valid token - should fail initially."""
        # Mock the get_current_user dependency
        async def mock_get_current_user(request, credentials):
            if not credentials:
                return None

            # Mock token validation
            if credentials.credentials == "valid_token":
                return {
                    "id": "user-001",
                    "username": "testuser",
                    "is_admin": False,
                    "is_active": True,
                    "permissions": ["read", "write"]
                }
            return None

        credentials = mock_credentials("valid_token")
        user = await mock_get_current_user(mock_request, credentials)

        # This will fail initially
        assert user is not None
        assert user["username"] == "testuser"
        assert user["is_active"] is True

    @pytest.mark.asyncio
    async def test_get_current_user_without_credentials_should_fail_initially(self, mock_request):
        """Test get_current_user without credentials - should fail initially."""
        # Missing credentials must yield None (anonymous), not raise.
        async def mock_get_current_user(request, credentials):
            if not credentials:
                return None
            return {"id": "user-001"}

        user = await mock_get_current_user(mock_request, None)

        # This will fail initially
        assert user is None

    @pytest.mark.asyncio
    async def test_require_active_user_should_fail_initially(self):
        """Test require active user dependency - should fail initially."""
        # 401 when unauthenticated, 403 when authenticated but inactive.
        async def mock_get_current_active_user(current_user):
            if not current_user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required"
                )

            if not current_user.get("is_active", True):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Inactive user"
                )

            return current_user

        # Test with active user
        active_user = {"id": "user-001", "is_active": True}
        result = await mock_get_current_active_user(active_user)

        # This will fail initially
        assert result == active_user

        # Test with inactive user
        inactive_user = {"id": "user-002", "is_active": False}
        with pytest.raises(HTTPException) as exc_info:
            await mock_get_current_active_user(inactive_user)

        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

        # Test with no user
        with pytest.raises(HTTPException) as exc_info:
            await mock_get_current_active_user(None)

        assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED

    @pytest.mark.asyncio
    async def test_require_admin_user_should_fail_initially(self):
        """Test require admin user dependency - should fail initially."""
        async def mock_get_admin_user(current_user):
            if not current_user.get("is_admin", False):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Admin privileges required"
                )
            return current_user

        # Test with admin user
        admin_user = {"id": "admin-001", "is_admin": True}
        result = await mock_get_admin_user(admin_user)

        # This will fail initially
        assert result == admin_user

        # Test with regular user
        regular_user = {"id": "user-001", "is_admin": False}
        with pytest.raises(HTTPException) as exc_info:
            await mock_get_admin_user(regular_user)

        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

    @pytest.mark.asyncio
    async def test_permission_checking_should_fail_initially(self):
        """Test permission checking functionality - should fail initially."""
        # Dependency factory: returns an async checker bound to one permission.
        def require_permission(permission: str):
            async def check_permission(current_user):
                user_permissions = current_user.get("permissions", [])

                # Admin users have all permissions
                if current_user.get("is_admin", False):
                    return current_user

                # Check specific permission
                if permission not in user_permissions:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Permission '{permission}' required"
                    )

                return current_user

            return check_permission

        # Test with user having required permission
        user_with_permission = {
            "id": "user-001",
            "permissions": ["read", "write"],
            "is_admin": False
        }

        check_read = require_permission("read")
        result = await check_read(user_with_permission)

        # This will fail initially
        assert result == user_with_permission

        # Test with user missing permission
        check_admin = require_permission("admin")
        with pytest.raises(HTTPException) as exc_info:
            await check_admin(user_with_permission)

        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
        assert "admin" in exc_info.value.detail

        # Test with admin user (should have all permissions)
        admin_user = {"id": "admin-001", "is_admin": True, "permissions": ["read"]}
        result = await check_admin(admin_user)
        assert result == admin_user
|
||||
|
||||
|
||||
class TestZoneAndRouterAccess:
    """Access-control tests for zone and router resources."""

    @pytest.fixture
    def mock_domain_config(self):
        """Provide an in-memory stand-in for the domain configuration."""

        class MockDomainConfig:
            def __init__(self):
                # Two enabled zones plus one disabled zone / router each,
                # so both the happy path and the 403 path are exercised.
                self.zones = {
                    "zone1": {"id": "zone1", "name": "Zone 1", "enabled": True},
                    "zone2": {"id": "zone2", "name": "Zone 2", "enabled": True},
                    "zone3": {"id": "zone3", "name": "Zone 3", "enabled": False},
                }
                self.routers = {
                    "router1": {"id": "router1", "name": "Router 1", "enabled": True},
                    "router2": {"id": "router2", "name": "Router 2", "enabled": False},
                }

            def get_zone(self, zone_id: str):
                return self.zones.get(zone_id)

            def get_router(self, router_id: str):
                return self.routers.get(router_id)

        return MockDomainConfig()

    @pytest.mark.asyncio
    async def test_zone_access_validation_should_fail_initially(self, mock_domain_config):
        """Test zone access validation - should fail initially."""

        async def validate_zone_access(zone_id: str, current_user=None):
            # Unknown zones are a 404; disabled zones a 403.
            zone = mock_domain_config.get_zone(zone_id)
            if not zone:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Zone '{zone_id}' not found"
                )

            if not zone["enabled"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Zone '{zone_id}' is disabled"
                )

            if current_user:
                # Admins bypass per-zone restrictions entirely.
                if current_user.get("is_admin", False):
                    return zone_id

                # An empty zone list means "no restriction" for that user.
                user_zones = current_user.get("zones", [])
                if user_zones and zone_id not in user_zones:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Access denied to zone '{zone_id}'"
                    )

            return zone_id

        # Enabled zone with no user context is allowed.
        assert await validate_zone_access("zone1") == "zone1"

        # Unknown zone -> 404.
        with pytest.raises(HTTPException) as exc_info:
            await validate_zone_access("nonexistent")
        assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND

        # Disabled zone -> 403.
        with pytest.raises(HTTPException) as exc_info:
            await validate_zone_access("zone3")
        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

        # A user whose zone list contains the zone is allowed.
        user_with_access = {"id": "user-001", "zones": ["zone1", "zone2"]}
        assert await validate_zone_access("zone1", user_with_access) == "zone1"

        # A user whose zone list excludes the zone is rejected.
        with pytest.raises(HTTPException) as exc_info:
            await validate_zone_access("zone2", {"id": "user-002", "zones": ["zone1"]})
        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

    @pytest.mark.asyncio
    async def test_router_access_validation_should_fail_initially(self, mock_domain_config):
        """Test router access validation - should fail initially."""

        async def validate_router_access(router_id: str, current_user=None):
            # Unknown routers are a 404; disabled routers a 403.
            router = mock_domain_config.get_router(router_id)
            if not router:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Router '{router_id}' not found"
                )

            if not router["enabled"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Router '{router_id}' is disabled"
                )

            return router_id

        # Enabled router is accessible.
        assert await validate_router_access("router1") == "router1"

        # Disabled router -> 403.
        with pytest.raises(HTTPException) as exc_info:
            await validate_router_access("router2")
        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
# ===== new file: v1/tests/integration/test_csi_pipeline.py (+353 lines) =====
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.hardware.router_interface import RouterInterface
|
||||
from src.hardware.csi_extractor import CSIExtractor
|
||||
|
||||
|
||||
class TestCSIPipeline:
    """Integration tests for CSI processing pipeline following London School TDD principles"""

    @pytest.fixture
    def mock_router_config(self):
        """Configuration for router interface"""
        return {
            'router_ip': '192.168.1.1',
            'username': 'admin',
            'password': 'password',
            'ssh_port': 22,
            'timeout': 30,
            'max_retries': 3,
        }

    @pytest.fixture
    def mock_extractor_config(self):
        """Configuration for CSI extractor"""
        return {
            'interface': 'wlan0',
            'channel': 6,
            'bandwidth': 20,
            'antenna_count': 3,
            'subcarrier_count': 56,
            'sample_rate': 1000,
            'buffer_size': 1024,
        }

    @pytest.fixture
    def mock_processor_config(self):
        """Configuration for CSI processor"""
        return {
            'window_size': 100,
            'overlap': 0.5,
            'filter_type': 'butterworth',
            'filter_order': 4,
            'cutoff_frequency': 50,
            'normalization': 'minmax',
            'outlier_threshold': 3.0,
        }

    @pytest.fixture
    def mock_sanitizer_config(self):
        """Configuration for phase sanitizer"""
        return {
            'unwrap_method': 'numpy',
            'smoothing_window': 5,
            'outlier_threshold': 2.0,
            'interpolation_method': 'linear',
            'phase_correction': True,
        }

    @pytest.fixture
    def csi_pipeline_components(self, mock_router_config, mock_extractor_config,
                                mock_processor_config, mock_sanitizer_config):
        """Create CSI pipeline components for testing"""
        return {
            'router': RouterInterface(mock_router_config),
            'extractor': CSIExtractor(mock_extractor_config),
            'processor': CSIProcessor(mock_processor_config),
            'sanitizer': PhaseSanitizer(mock_sanitizer_config),
        }

    @pytest.fixture
    def mock_raw_csi_data(self):
        """Generate mock raw CSI data"""
        batch_size, antennas, subcarriers, time_samples = 10, 3, 56, 100

        # Complex-valued CSI: independent Gaussian real and imaginary parts.
        shape = (batch_size, antennas, subcarriers, time_samples)
        real_part = np.random.randn(*shape)
        imag_part = np.random.randn(*shape)

        return {
            'csi_data': real_part + 1j * imag_part,
            'timestamps': np.linspace(0, 1, time_samples),
            'metadata': {
                'channel': 6,
                'bandwidth': 20,
                'rssi': -45,
                'noise_floor': -90,
            },
        }

    def test_end_to_end_csi_pipeline_processes_data_correctly(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that end-to-end CSI pipeline processes data correctly"""
        router = csi_pipeline_components['router']
        extractor = csi_pipeline_components['extractor']
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        # Stub the hardware boundary so the test runs without a real router.
        with patch.object(extractor, 'extract_csi_data', return_value=mock_raw_csi_data), \
             patch.object(router, 'connect', return_value=True), \
             patch.object(router, 'configure_monitor_mode', return_value=True):

            # 1. Connect to the router and switch to monitor mode.
            router.connect()
            router.configure_monitor_mode('wlan0', 6)

            # 2-4. Extract, process, then sanitize the phase.
            raw_data = extractor.extract_csi_data()
            processed_data = processor.process_csi_batch(raw_data['csi_data'])
            sanitized_data = sanitizer.sanitize_phase_batch(processed_data)

            assert raw_data is not None
            assert processed_data is not None
            assert sanitized_data is not None

            # Data-flow integrity: tensors all the way through, shape preserved.
            assert isinstance(processed_data, torch.Tensor)
            assert isinstance(sanitized_data, torch.Tensor)
            assert processed_data.shape == sanitized_data.shape

    def test_pipeline_handles_hardware_connection_failure(self, csi_pipeline_components):
        """Test that pipeline handles hardware connection failures gracefully"""
        router = csi_pipeline_components['router']

        with patch.object(router, 'connect', return_value=False):
            assert router.connect() is False

            # Configuring monitor mode without a connection must raise.
            with pytest.raises(Exception):  # Should raise appropriate exception
                router.configure_monitor_mode('wlan0', 6)

    def test_pipeline_handles_csi_extraction_timeout(self, csi_pipeline_components):
        """Test that pipeline handles CSI extraction timeouts"""
        extractor = csi_pipeline_components['extractor']

        # The timeout must propagate rather than be silently swallowed.
        with patch.object(extractor, 'extract_csi_data', side_effect=TimeoutError("CSI extraction timeout")):
            with pytest.raises(TimeoutError):
                extractor.extract_csi_data()

    def test_pipeline_handles_invalid_csi_data_format(self, csi_pipeline_components):
        """Test that pipeline handles invalid CSI data formats"""
        processor = csi_pipeline_components['processor']

        # 3-D array lacks the time dimension the processor expects.
        invalid_data = np.random.randn(10, 2, 56)

        with pytest.raises(ValueError):
            processor.process_csi_batch(invalid_data)

    def test_pipeline_maintains_data_consistency_across_stages(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline maintains data consistency across processing stages"""
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        csi_data = mock_raw_csi_data['csi_data']
        processed_data = processor.process_csi_batch(csi_data)
        sanitized_data = sanitizer.sanitize_phase_batch(processed_data)

        # Batch, antenna, and subcarrier dimensions survive each stage.
        assert processed_data.shape[0] == sanitized_data.shape[0]
        assert processed_data.shape[1] == sanitized_data.shape[1]
        assert processed_data.shape[2] == sanitized_data.shape[2]

        # No NaN/Inf corruption at either stage.
        assert not torch.isnan(processed_data).any()
        assert not torch.isinf(processed_data).any()
        assert not torch.isnan(sanitized_data).any()
        assert not torch.isinf(sanitized_data).any()

    def test_pipeline_performance_meets_real_time_requirements(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline performance meets real-time processing requirements"""
        import time

        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']
        csi_data = mock_raw_csi_data['csi_data']

        start_time = time.time()
        processed_data = processor.process_csi_batch(csi_data)
        sanitized_data = sanitizer.sanitize_phase_batch(processed_data)
        processing_time = time.time() - start_time

        # Real-time budget: this data size must process in under 100 ms.
        assert processing_time < 0.1, f"Processing took {processing_time:.3f}s, expected < 0.1s"

    def test_pipeline_handles_different_data_sizes(self, csi_pipeline_components):
        """Test that pipeline handles different CSI data sizes"""
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        small_data = np.random.randn(1, 3, 56, 50) + 1j * np.random.randn(1, 3, 56, 50)
        large_data = np.random.randn(20, 3, 56, 200) + 1j * np.random.randn(20, 3, 56, 200)

        small_processed = processor.process_csi_batch(small_data)
        small_sanitized = sanitizer.sanitize_phase_batch(small_processed)

        large_processed = processor.process_csi_batch(large_data)
        large_sanitized = sanitizer.sanitize_phase_batch(large_processed)

        # Each size flows through shape-preserved, and the two sizes stay distinct.
        assert small_processed.shape == small_sanitized.shape
        assert large_processed.shape == large_sanitized.shape
        assert small_processed.shape != large_processed.shape

    def test_pipeline_configuration_validation(self, mock_router_config, mock_extractor_config,
                                               mock_processor_config, mock_sanitizer_config):
        """Test that pipeline components validate configurations properly"""
        # Each component should reject a config with a single invalid field.
        invalid_router_config = {**mock_router_config, 'router_ip': 'invalid_ip'}
        invalid_extractor_config = {**mock_extractor_config, 'antenna_count': 0}
        invalid_processor_config = {**mock_processor_config, 'window_size': -1}
        invalid_sanitizer_config = {**mock_sanitizer_config, 'smoothing_window': 0}

        with pytest.raises(ValueError):
            RouterInterface(invalid_router_config)

        with pytest.raises(ValueError):
            CSIExtractor(invalid_extractor_config)

        with pytest.raises(ValueError):
            CSIProcessor(invalid_processor_config)

        with pytest.raises(ValueError):
            PhaseSanitizer(invalid_sanitizer_config)

    def test_pipeline_error_recovery_and_logging(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline handles errors gracefully and logs appropriately"""
        processor = csi_pipeline_components['processor']

        # Inject infinite values to trigger the processor's corruption check.
        corrupted_data = mock_raw_csi_data['csi_data'].copy()
        corrupted_data[0, 0, 0, :] = np.inf

        with pytest.raises(ValueError):  # Should detect and handle corrupted data
            processor.process_csi_batch(corrupted_data)

    def test_pipeline_memory_usage_optimization(self, csi_pipeline_components):
        """Test that pipeline optimizes memory usage for large datasets"""
        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']

        # NOTE(review): this allocates ~270 MB of complex data; sized to force
        # chunked processing — confirm CI memory budget allows it.
        large_data = np.random.randn(100, 3, 56, 1000) + 1j * np.random.randn(100, 3, 56, 1000)

        chunk_size = 10
        results = []
        for start in range(0, large_data.shape[0], chunk_size):
            chunk = large_data[start:start + chunk_size]
            processed_chunk = processor.process_csi_batch(chunk)
            results.append(sanitizer.sanitize_phase_batch(processed_chunk))

        assert len(results) == 10  # 100 samples / 10 chunk_size
        for result in results:
            assert result.shape[0] <= chunk_size

    def test_pipeline_supports_concurrent_processing(self, csi_pipeline_components, mock_raw_csi_data):
        """Test that pipeline supports concurrent processing of multiple streams"""
        import threading
        import queue

        processor = csi_pipeline_components['processor']
        sanitizer = csi_pipeline_components['sanitizer']
        results_queue = queue.Queue()

        def process_stream(stream_id, data):
            # Report either the sanitized tensor or the raised exception.
            try:
                processed = processor.process_csi_batch(data)
                sanitized = sanitizer.sanitize_phase_batch(processed)
                results_queue.put((stream_id, sanitized))
            except Exception as e:
                results_queue.put((stream_id, e))

        workers = [
            threading.Thread(target=process_stream,
                             args=(i, mock_raw_csi_data['csi_data']))
            for i in range(3)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        results = []
        while not results_queue.empty():
            results.append(results_queue.get())

        # Every stream produced a tensor, not an exception.
        assert len(results) == 3
        for _stream_id, result in results:
            assert isinstance(result, torch.Tensor)
            assert not isinstance(result, Exception)
|
||||
# ===== new file: v1/tests/integration/test_full_system_integration.py (+447 lines) =====
|
||||
"""
|
||||
Full system integration tests for WiFi-DensePose API
|
||||
Tests the complete integration of all components working together.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import httpx
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.app import app
|
||||
from src.database.connection import get_database_manager
|
||||
from src.services.orchestrator import get_service_orchestrator
|
||||
from src.tasks.cleanup import get_cleanup_manager
|
||||
from src.tasks.monitoring import get_monitoring_manager
|
||||
from src.tasks.backup import get_backup_manager
|
||||
|
||||
|
||||
class TestFullSystemIntegration:
    """Test complete system integration."""

    # NOTE(review): these async-generator fixtures use @pytest.fixture; under
    # pytest-asyncio strict mode they may need @pytest_asyncio.fixture — verify
    # against the project's pytest configuration.

    @pytest.fixture
    async def settings(self):
        """Get test settings."""
        settings = get_settings()
        settings.environment = "test"
        settings.debug = True
        settings.database_url = "sqlite+aiosqlite:///test_integration.db"
        settings.redis_enabled = False
        return settings

    @pytest.fixture
    async def db_manager(self, settings):
        """Get database manager for testing."""
        manager = get_database_manager(settings)
        await manager.initialize()
        yield manager
        await manager.close_all_connections()

    @pytest.fixture
    async def client(self, settings):
        """Get test HTTP client."""
        async with httpx.AsyncClient(app=app, base_url="http://test") as client:
            yield client

    @pytest.fixture
    async def orchestrator(self, settings, db_manager):
        """Get service orchestrator for testing."""
        orchestrator = get_service_orchestrator(settings)
        await orchestrator.initialize()
        yield orchestrator
        await orchestrator.shutdown()

    async def test_application_startup_and_shutdown(self, settings, db_manager):
        """Test complete application startup and shutdown sequence."""
        # Database comes up and reports a live connection.
        await db_manager.test_connection()
        stats = await db_manager.get_connection_stats()
        assert stats["database"]["connected"] is True

        # Orchestrator initializes and reports at least a warning-level health.
        orchestrator = get_service_orchestrator(settings)
        await orchestrator.initialize()
        health_status = await orchestrator.get_health_status()
        assert health_status["status"] in ["healthy", "warning"]

        # Graceful shutdown leaves the DB manager still queryable.
        await orchestrator.shutdown()
        final_stats = await db_manager.get_connection_stats()
        assert final_stats is not None

    async def test_api_endpoints_integration(self, client, settings, db_manager):
        """Test API endpoints work with database integration."""
        # Health endpoint.
        response = await client.get("/health")
        assert response.status_code == 200
        health_data = response.json()
        assert "status" in health_data
        assert "timestamp" in health_data

        # Metrics endpoint.
        response = await client.get("/metrics")
        assert response.status_code == 200

        # Devices listing.
        response = await client.get("/api/v1/devices")
        assert response.status_code == 200
        devices_data = response.json()
        assert "devices" in devices_data
        assert isinstance(devices_data["devices"], list)

        # Sessions listing.
        response = await client.get("/api/v1/sessions")
        assert response.status_code == 200
        sessions_data = response.json()
        assert "sessions" in sessions_data
        assert isinstance(sessions_data["sessions"], list)

    @patch('src.core.router_interface.RouterInterface')
    @patch('src.core.csi_processor.CSIProcessor')
    @patch('src.core.pose_estimator.PoseEstimator')
    async def test_data_processing_pipeline(
        self,
        mock_pose_estimator,
        mock_csi_processor,
        mock_router_interface,
        client,
        settings,
        db_manager
    ):
        """Test complete data processing pipeline integration."""
        # --- Stub the hardware/ML layers so only the API+DB path is real. ---
        mock_router = MagicMock()
        mock_router_interface.return_value = mock_router
        mock_router.connect.return_value = True
        mock_router.start_capture.return_value = True
        mock_router.get_csi_data.return_value = {
            "timestamp": time.time(),
            "csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
            "rssi": -45,
            "noise_floor": -90
        }

        mock_processor = MagicMock()
        mock_csi_processor.return_value = mock_processor
        mock_processor.process_csi_data.return_value = {
            "processed_csi": [[1.1, 2.1], [3.1, 4.1]],
            "quality_score": 0.85,
            "phase_sanitized": True
        }

        mock_estimator = MagicMock()
        mock_pose_estimator.return_value = mock_estimator
        mock_estimator.estimate_pose.return_value = {
            "pose_data": {
                "keypoints": [[100, 200], [150, 250]],
                "confidence": 0.9
            },
            "processing_time": 0.05
        }

        # Register a device.
        device_data = {
            "name": "test_router",
            "ip_address": "192.168.1.1",
            "device_type": "router",
            "model": "test_model"
        }
        response = await client.post("/api/v1/devices", json=device_data)
        assert response.status_code == 201
        device_id = response.json()["device"]["id"]

        # Open a pose-detection session against that device.
        session_data = {
            "device_id": device_id,
            "session_type": "pose_detection",
            "configuration": {
                "sampling_rate": 1000,
                "duration": 60
            }
        }
        response = await client.post("/api/v1/sessions", json=session_data)
        assert response.status_code == 201
        session_id = response.json()["session"]["id"]

        # Submit one CSI sample.
        csi_data = {
            "session_id": session_id,
            "timestamp": time.time(),
            "csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
            "rssi": -45,
            "noise_floor": -90
        }
        response = await client.post("/api/v1/csi-data", json=csi_data)
        assert response.status_code == 201

        # Pose detections for the session are retrievable.
        response = await client.get(f"/api/v1/sessions/{session_id}/pose-detections")
        assert response.status_code == 200

        # Session can be closed out.
        response = await client.patch(
            f"/api/v1/sessions/{session_id}",
            json={"status": "completed"}
        )
        assert response.status_code == 200

    async def test_background_tasks_integration(self, settings, db_manager):
        """Test background tasks integration."""
        # Cleanup manager exposes stats and runs its tasks to completion.
        cleanup_manager = get_cleanup_manager(settings)
        assert "manager" in cleanup_manager.get_stats()
        cleanup_result = await cleanup_manager.run_all_tasks()
        assert cleanup_result["success"] is True

        # Monitoring manager likewise.
        monitoring_manager = get_monitoring_manager(settings)
        assert "manager" in monitoring_manager.get_stats()
        monitoring_result = await monitoring_manager.run_all_tasks()
        assert monitoring_result["success"] is True

        # Backup manager likewise.
        backup_manager = get_backup_manager(settings)
        assert "manager" in backup_manager.get_stats()
        backup_result = await backup_manager.run_all_tasks()
        assert backup_result["success"] is True

    async def test_error_handling_integration(self, client, settings, db_manager):
        """Test error handling across the system."""
        # Invalid device payload is rejected with a validation error.
        invalid_device_data = {
            "name": "",  # Invalid empty name
            "ip_address": "invalid_ip",
            "device_type": "unknown_type"
        }
        response = await client.post("/api/v1/devices", json=invalid_device_data)
        assert response.status_code == 422
        assert "detail" in response.json()

        # Unknown resource id yields 404.
        response = await client.get("/api/v1/devices/99999")
        assert response.status_code == 404

        # Invalid session payload is rejected with a validation error.
        invalid_session_data = {
            "device_id": "invalid_uuid",
            "session_type": "invalid_type"
        }
        response = await client.post("/api/v1/sessions", json=invalid_session_data)
        assert response.status_code == 422

    async def test_authentication_and_authorization(self, client, settings):
        """Test authentication and authorization integration."""
        # Protected endpoint rejects anonymous access.
        response = await client.get("/api/v1/admin/system-info")
        assert response.status_code in [401, 403]

        # And rejects a bogus bearer token.
        headers = {"Authorization": "Bearer invalid_token"}
        response = await client.get("/api/v1/admin/system-info", headers=headers)
        assert response.status_code in [401, 403]

    async def test_rate_limiting_integration(self, client, settings):
        """Test rate limiting integration."""
        # Fire a burst of requests; whether throttling kicks in depends on
        # the rate-limiting configuration, so only success presence is asserted.
        responses = [
            (await client.get("/health")).status_code
            for _ in range(10)
        ]
        assert 200 in responses

    async def test_monitoring_and_metrics_integration(self, client, settings, db_manager):
        """Test monitoring and metrics collection integration."""
        # Metrics come back in Prometheus exposition format.
        response = await client.get("/metrics")
        assert response.status_code == 200
        metrics_text = response.text
        assert "# HELP" in metrics_text
        assert "# TYPE" in metrics_text

        # Detailed health includes the component breakdown.
        response = await client.get("/health?detailed=true")
        assert response.status_code == 200
        health_data = response.json()
        assert "database" in health_data
        assert "services" in health_data
        assert "system" in health_data

    async def test_configuration_management_integration(self, settings):
        """Test configuration management integration."""
        # The fixture's overrides are actually in effect.
        assert settings.environment == "test"
        assert settings.debug is True
        assert "test_integration.db" in settings.database_url
        assert settings.redis_enabled is False
        assert settings.log_level in ["DEBUG", "INFO", "WARNING", "ERROR"]

    async def test_database_migration_integration(self, settings, db_manager):
        """Test database migration integration."""
        await db_manager.test_connection()

        async with db_manager.get_async_session() as session:
            from sqlalchemy import text

            # Enumerate user tables in the SQLite test database.
            tables_query = text("""
                SELECT name FROM sqlite_master
                WHERE type='table' AND name NOT LIKE 'sqlite_%'
            """)
            result = await session.execute(tables_query)
            tables = [row[0] for row in result.fetchall()]

            # All schema tables must have been created.
            expected_tables = ["devices", "sessions", "csi_data", "pose_detections"]
            for table in expected_tables:
                assert table in tables

    async def test_concurrent_operations_integration(self, client, settings, db_manager):
        """Test concurrent operations integration."""

        async def create_device(name: str):
            device_data = {
                "name": f"test_device_{name}",
                "ip_address": f"192.168.1.{name}",
                "device_type": "router",
                "model": "test_model"
            }
            response = await client.post("/api/v1/devices", json=device_data)
            return response.status_code

        # Register five devices concurrently; every creation must succeed.
        results = await asyncio.gather(*(create_device(str(i)) for i in range(5)))
        assert all(status == 201 for status in results)

        # All five are visible in the listing afterwards.
        response = await client.get("/api/v1/devices")
        assert response.status_code == 200
        devices_data = response.json()
        assert len(devices_data["devices"]) >= 5

    async def test_system_resource_management(self, settings, db_manager, orchestrator):
        """Test system resource management integration."""
        # Connection-pool statistics are exposed.
        stats = await db_manager.get_connection_stats()
        assert "database" in stats
        assert "pool_size" in stats["database"]

        # Orchestrator reports resource usage.
        health_status = await orchestrator.get_health_status()
        assert "memory_usage" in health_status
        assert "cpu_usage" in health_status

        # Cleanup leaves the DB manager still queryable.
        await orchestrator.cleanup_resources()
        final_stats = await db_manager.get_connection_stats()
        assert final_stats is not None
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestSystemPerformance:
    """Test system performance under load."""

    async def test_api_response_times(self, client):
        """Test API response times under normal load."""
        started = time.time()
        response = await client.get("/health")
        elapsed = time.time() - started

        assert response.status_code == 200
        assert elapsed < 1.0  # Should respond within 1 second

    async def test_database_query_performance(self, db_manager):
        """Test database query performance."""
        async with db_manager.get_async_session() as session:
            from sqlalchemy import text

            started = time.time()
            result = await session.execute(text("SELECT 1"))
            elapsed = time.time() - started

            assert result.scalar() == 1
            assert elapsed < 0.1  # Should complete within 100ms

    async def test_memory_usage_stability(self, orchestrator):
        """Test memory usage remains stable."""
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        # Exercise the orchestrator repeatedly and watch for RSS growth.
        for _ in range(10):
            health_status = await orchestrator.get_health_status()
            assert health_status is not None

        memory_increase = process.memory_info().rss - initial_memory

        # Memory increase should be reasonable (less than 50MB).
        assert memory_increase < 50 * 1024 * 1024
|
||||
|
||||
|
||||
# Allow running this test module directly (``python test_file.py``)
# instead of invoking pytest from the command line.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
663
v1/tests/integration/test_hardware_integration.py
Normal file
663
v1/tests/integration/test_hardware_integration.py
Normal file
@@ -0,0 +1,663 @@
|
||||
"""
|
||||
Integration tests for hardware integration and router communication.
|
||||
|
||||
Tests WiFi router communication, CSI data collection, and hardware management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import socket
|
||||
|
||||
|
||||
class MockRouterInterface:
    """In-memory stand-in for a WiFi router used by the hardware tests.

    Deliberate quirks the surrounding test suite relies on:
      * the very first ``connect()`` call always fails (exercises retries),
      * ``start_csi_streaming()`` always reports failure,
      * authentication succeeds only for ``admin`` / ``correct_password``.
    """

    def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
        self.router_id = router_id
        self.ip_address = ip_address
        self.is_connected = False
        self.is_authenticated = False
        self.csi_streaming = False
        self.connection_attempts = 0
        self.last_heartbeat = None
        self.firmware_version = "1.2.3"
        self.capabilities = ["csi", "beamforming", "mimo"]

    async def connect(self) -> bool:
        """Attempt a connection; the first attempt is forced to fail."""
        self.connection_attempts += 1
        if self.connection_attempts < 2:
            # Simulated flaky first handshake so callers must retry.
            return False
        await asyncio.sleep(0.1)  # pretend the handshake takes a moment
        self.is_connected = True
        return True

    async def authenticate(self, username: str, password: str) -> bool:
        """Authenticate; only the one hard-coded credential pair is accepted."""
        if not self.is_connected:
            return False
        accepted = (username, password) == ("admin", "correct_password")
        if accepted:
            self.is_authenticated = True
        return accepted

    async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
        """Request CSI streaming; always refuses (tests expect the failure)."""
        if not self.is_authenticated:
            return False
        # Designed to fail so error-handling paths get exercised.
        return False

    async def stop_csi_streaming(self) -> bool:
        """Stop CSI streaming; succeeds only when streaming was active."""
        if not self.csi_streaming:
            return False
        self.csi_streaming = False
        return True

    async def get_status(self) -> Dict[str, Any]:
        """Return a static status snapshot for this mock router."""
        snapshot = {
            "router_id": self.router_id,
            "ip_address": self.ip_address,
            "is_connected": self.is_connected,
            "is_authenticated": self.is_authenticated,
            "csi_streaming": self.csi_streaming,
            "firmware_version": self.firmware_version,
            "uptime_seconds": 3600,
            "signal_strength": -45.2,
            "temperature": 42.5,
            "cpu_usage": 15.3
        }
        return snapshot

    async def send_heartbeat(self) -> bool:
        """Record a heartbeat timestamp; fails while disconnected."""
        if not self.is_connected:
            return False
        self.last_heartbeat = datetime.utcnow()
        return True
|
||||
|
||||
|
||||
class TestRouterConnection:
    """Test router connection functionality.

    Written TDD "red phase" style: the mock is deliberately built to fail
    on first attempts, and the tests document the retry behaviour expected
    from the real implementation.
    """

    @pytest.fixture
    def router_interface(self):
        """Create router interface for testing."""
        return MockRouterInterface("router_001", "192.168.1.100")

    @pytest.mark.asyncio
    async def test_router_connection_should_fail_initially(self, router_interface):
        """Test router connection - should fail initially."""
        # First connection attempt should fail
        result = await router_interface.connect()

        # This will fail initially because we designed the mock to fail first attempt
        assert result is False
        assert router_interface.is_connected is False
        assert router_interface.connection_attempts == 1

        # Second attempt should succeed
        result = await router_interface.connect()
        assert result is True
        assert router_interface.is_connected is True

    @pytest.mark.asyncio
    async def test_router_authentication_should_fail_initially(self, router_interface):
        """Test router authentication - should fail initially."""
        # Connect first
        await router_interface.connect()
        await router_interface.connect()  # Second attempt succeeds

        # Test wrong credentials
        result = await router_interface.authenticate("admin", "wrong_password")

        # This will fail initially
        assert result is False
        assert router_interface.is_authenticated is False

        # Test correct credentials
        result = await router_interface.authenticate("admin", "correct_password")
        assert result is True
        assert router_interface.is_authenticated is True

    @pytest.mark.asyncio
    async def test_csi_streaming_start_should_fail_initially(self, router_interface):
        """Test CSI streaming start - should fail initially."""
        # Setup connection and authentication
        await router_interface.connect()
        await router_interface.connect()  # Second attempt succeeds
        await router_interface.authenticate("admin", "correct_password")

        # Try to start CSI streaming
        config = {
            "frequency": 5.8e9,
            "bandwidth": 80e6,
            "sample_rate": 1000,
            "antenna_config": "4x4_mimo"
        }

        result = await router_interface.start_csi_streaming(config)

        # This will fail initially because the mock is designed to return False
        assert result is False
        assert router_interface.csi_streaming is False

    @pytest.mark.asyncio
    async def test_router_status_retrieval_should_fail_initially(self, router_interface):
        """Test router status retrieval - should fail initially."""
        # Status is available even without connecting first.
        status = await router_interface.get_status()

        # This will fail initially
        assert isinstance(status, dict)
        assert status["router_id"] == "router_001"
        assert status["ip_address"] == "192.168.1.100"
        assert "firmware_version" in status
        assert "uptime_seconds" in status
        assert "signal_strength" in status
        assert "temperature" in status
        assert "cpu_usage" in status

    @pytest.mark.asyncio
    async def test_heartbeat_mechanism_should_fail_initially(self, router_interface):
        """Test heartbeat mechanism - should fail initially."""
        # Heartbeat without connection should fail
        result = await router_interface.send_heartbeat()

        # This will fail initially
        assert result is False
        assert router_interface.last_heartbeat is None

        # Connect and try heartbeat
        await router_interface.connect()
        await router_interface.connect()  # Second attempt succeeds

        result = await router_interface.send_heartbeat()
        assert result is True
        assert router_interface.last_heartbeat is not None
|
||||
|
||||
|
||||
class TestMultiRouterManagement:
    """Test management of multiple routers."""

    @pytest.fixture
    def router_manager(self):
        """Create router manager for testing.

        The manager is defined inline so the fixture is self-contained; it
        wraps a dict of ``MockRouterInterface`` instances keyed by router id.
        """
        class RouterManager:
            def __init__(self):
                self.routers = {}
                self.active_connections = 0

            async def add_router(self, router_id: str, ip_address: str) -> bool:
                """Add a router to management. Duplicate ids are rejected."""
                if router_id in self.routers:
                    return False

                router = MockRouterInterface(router_id, ip_address)
                self.routers[router_id] = router
                return True

            async def connect_router(self, router_id: str) -> bool:
                """Connect to a specific router."""
                if router_id not in self.routers:
                    return False

                router = self.routers[router_id]

                # Try connecting twice (mock fails first time)
                success = await router.connect()
                if not success:
                    success = await router.connect()

                if success:
                    self.active_connections += 1

                return success

            async def authenticate_router(self, router_id: str, username: str, password: str) -> bool:
                """Authenticate with a router."""
                if router_id not in self.routers:
                    return False

                router = self.routers[router_id]
                return await router.authenticate(username, password)

            async def get_all_status(self) -> Dict[str, Dict[str, Any]]:
                """Get status of all routers (queried sequentially, not gathered)."""
                status = {}
                for router_id, router in self.routers.items():
                    status[router_id] = await router.get_status()
                return status

            async def start_all_csi_streaming(self, config: Dict[str, Any]) -> Dict[str, bool]:
                """Start CSI streaming on all authenticated routers."""
                results = {}
                for router_id, router in self.routers.items():
                    if router.is_authenticated:
                        results[router_id] = await router.start_csi_streaming(config)
                    else:
                        # Unauthenticated routers are reported as failed, not skipped.
                        results[router_id] = False
                return results

        return RouterManager()

    @pytest.mark.asyncio
    async def test_multiple_router_addition_should_fail_initially(self, router_manager):
        """Test adding multiple routers - should fail initially."""
        # Add first router
        result1 = await router_manager.add_router("router_001", "192.168.1.100")

        # This will fail initially
        assert result1 is True
        assert "router_001" in router_manager.routers

        # Add second router
        result2 = await router_manager.add_router("router_002", "192.168.1.101")
        assert result2 is True
        assert "router_002" in router_manager.routers

        # Try to add duplicate router
        result3 = await router_manager.add_router("router_001", "192.168.1.102")
        assert result3 is False
        assert len(router_manager.routers) == 2

    @pytest.mark.asyncio
    async def test_concurrent_router_connections_should_fail_initially(self, router_manager):
        """Test concurrent router connections - should fail initially."""
        # Add multiple routers
        await router_manager.add_router("router_001", "192.168.1.100")
        await router_manager.add_router("router_002", "192.168.1.101")
        await router_manager.add_router("router_003", "192.168.1.102")

        # Connect to all routers concurrently
        connection_tasks = [
            router_manager.connect_router("router_001"),
            router_manager.connect_router("router_002"),
            router_manager.connect_router("router_003")
        ]

        results = await asyncio.gather(*connection_tasks)

        # This will fail initially
        assert len(results) == 3
        assert all(results)  # All connections should succeed
        assert router_manager.active_connections == 3

    @pytest.mark.asyncio
    async def test_router_status_aggregation_should_fail_initially(self, router_manager):
        """Test router status aggregation - should fail initially."""
        # Add and connect routers
        await router_manager.add_router("router_001", "192.168.1.100")
        await router_manager.add_router("router_002", "192.168.1.101")

        await router_manager.connect_router("router_001")
        await router_manager.connect_router("router_002")

        # Get all status
        all_status = await router_manager.get_all_status()

        # This will fail initially
        assert isinstance(all_status, dict)
        assert len(all_status) == 2
        assert "router_001" in all_status
        assert "router_002" in all_status

        # Verify status structure
        for router_id, status in all_status.items():
            assert "router_id" in status
            assert "ip_address" in status
            assert "is_connected" in status
            assert status["is_connected"] is True
|
||||
|
||||
|
||||
class TestCSIDataCollection:
    """Test CSI data collection from routers."""

    @pytest.fixture
    def csi_collector(self):
        """Create CSI data collector.

        The collector is defined inline; ``start_collection`` is deliberately
        a stub that always fails, documenting the behaviour the real
        implementation must replace.
        """
        class CSICollector:
            def __init__(self):
                self.collected_data = []
                self.is_collecting = False
                self.collection_rate = 0

            async def start_collection(self, router_interfaces: List[MockRouterInterface]) -> bool:
                """Start CSI data collection."""
                # This should fail initially
                return False

            async def stop_collection(self) -> bool:
                """Stop CSI data collection."""
                if self.is_collecting:
                    self.is_collecting = False
                    return True
                return False

            async def collect_frame(self, router_interface: MockRouterInterface) -> Optional[Dict[str, Any]]:
                """Collect a single CSI frame; ``None`` when the router isn't streaming."""
                if not router_interface.csi_streaming:
                    return None

                # Simulate CSI data
                # NOTE(review): 64x32 here is (subcarriers x time?), while the
                # frame reports subcarrier_count=64 and antenna_count=4 —
                # confirm the intended matrix layout.
                return {
                    "timestamp": datetime.utcnow().isoformat(),
                    "router_id": router_interface.router_id,
                    "amplitude": np.random.rand(64, 32).tolist(),
                    "phase": np.random.rand(64, 32).tolist(),
                    "frequency": 5.8e9,
                    "bandwidth": 80e6,
                    "antenna_count": 4,
                    "subcarrier_count": 64,
                    "signal_quality": np.random.uniform(0.7, 0.95)
                }

            def get_collection_stats(self) -> Dict[str, Any]:
                """Get collection statistics."""
                return {
                    "total_frames": len(self.collected_data),
                    "collection_rate": self.collection_rate,
                    "is_collecting": self.is_collecting,
                    "last_collection": self.collected_data[-1]["timestamp"] if self.collected_data else None
                }

        return CSICollector()

    @pytest.mark.asyncio
    async def test_csi_collection_start_should_fail_initially(self, csi_collector):
        """Test CSI collection start - should fail initially."""
        router_interfaces = [
            MockRouterInterface("router_001", "192.168.1.100"),
            MockRouterInterface("router_002", "192.168.1.101")
        ]

        result = await csi_collector.start_collection(router_interfaces)

        # This will fail initially because the collector is designed to return False
        assert result is False
        assert csi_collector.is_collecting is False

    @pytest.mark.asyncio
    async def test_single_frame_collection_should_fail_initially(self, csi_collector):
        """Test single frame collection - should fail initially."""
        router = MockRouterInterface("router_001", "192.168.1.100")

        # Without CSI streaming enabled
        frame = await csi_collector.collect_frame(router)

        # This will fail initially
        assert frame is None

        # Enable CSI streaming (manually for testing)
        router.csi_streaming = True
        frame = await csi_collector.collect_frame(router)

        assert frame is not None
        assert "timestamp" in frame
        assert "router_id" in frame
        assert "amplitude" in frame
        assert "phase" in frame
        assert frame["router_id"] == "router_001"

    @pytest.mark.asyncio
    async def test_collection_statistics_should_fail_initially(self, csi_collector):
        """Test collection statistics - should fail initially."""
        stats = csi_collector.get_collection_stats()

        # This will fail initially
        assert isinstance(stats, dict)
        assert "total_frames" in stats
        assert "collection_rate" in stats
        assert "is_collecting" in stats
        assert "last_collection" in stats

        assert stats["total_frames"] == 0
        assert stats["is_collecting"] is False
        assert stats["last_collection"] is None
|
||||
|
||||
|
||||
class TestHardwareErrorHandling:
    """Test hardware error handling scenarios.

    NOTE(review): several tests below depend on unseeded ``np.random`` and
    are therefore intentionally flaky; they document the need for retry /
    drop-detection logic rather than asserting a deterministic outcome.
    """

    @pytest.fixture
    def unreliable_router(self):
        """Create unreliable router for error testing."""
        class UnreliableRouter(MockRouterInterface):
            def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
                super().__init__(router_id, ip_address)
                self.failure_rate = 0.3  # 30% failure rate
                self.connection_drops = 0

            async def connect(self) -> bool:
                """Unreliable connection."""
                if np.random.random() < self.failure_rate:
                    return False
                return await super().connect()

            async def send_heartbeat(self) -> bool:
                """Unreliable heartbeat: a random failure also drops the connection."""
                if np.random.random() < self.failure_rate:
                    self.is_connected = False
                    self.connection_drops += 1
                    return False
                return await super().send_heartbeat()

            async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
                """Unreliable CSI streaming."""
                if np.random.random() < self.failure_rate:
                    return False

                # Still return False for initial test failure
                return False

        return UnreliableRouter("unreliable_router", "192.168.1.200")

    @pytest.mark.asyncio
    async def test_connection_retry_mechanism_should_fail_initially(self, unreliable_router):
        """Test connection retry mechanism - should fail initially."""
        max_retries = 5
        success = False

        for attempt in range(max_retries):
            result = await unreliable_router.connect()
            if result:
                success = True
                break

            # Wait before retry
            await asyncio.sleep(0.1)

        # This will fail initially due to randomness, but should eventually pass
        # The test demonstrates the need for retry logic
        assert success or unreliable_router.connection_attempts >= max_retries

    @pytest.mark.asyncio
    async def test_connection_drop_detection_should_fail_initially(self, unreliable_router):
        """Test connection drop detection - should fail initially."""
        # Establish connection
        await unreliable_router.connect()
        await unreliable_router.connect()  # Ensure connection

        initial_drops = unreliable_router.connection_drops

        # Send multiple heartbeats to trigger potential drops
        for _ in range(10):
            await unreliable_router.send_heartbeat()
            await asyncio.sleep(0.01)

        # This will fail initially
        # Should detect connection drops
        final_drops = unreliable_router.connection_drops
        assert final_drops >= initial_drops  # May have detected drops

    @pytest.mark.asyncio
    async def test_hardware_timeout_handling_should_fail_initially(self):
        """Test hardware timeout handling - should fail initially."""
        async def slow_operation():
            """Simulate slow hardware operation."""
            await asyncio.sleep(2.0)  # 2 second delay
            return "success"

        # Test with timeout
        try:
            result = await asyncio.wait_for(slow_operation(), timeout=1.0)
            # This should not be reached
            assert False, "Operation should have timed out"
        except asyncio.TimeoutError:
            # This will fail initially because we expect timeout handling
            assert True  # Timeout was properly handled

    @pytest.mark.asyncio
    async def test_network_error_simulation_should_fail_initially(self):
        """Test network error simulation - should fail initially."""
        class NetworkErrorRouter(MockRouterInterface):
            async def connect(self) -> bool:
                """Simulate network error."""
                raise ConnectionError("Network unreachable")

        # NOTE(review): ".999" is not a valid IPv4 octet, but the mock treats
        # the address as an opaque string, so this is harmless here.
        router = NetworkErrorRouter("error_router", "192.168.1.999")

        # This will fail initially
        with pytest.raises(ConnectionError, match="Network unreachable"):
            await router.connect()
|
||||
|
||||
|
||||
class TestHardwareConfiguration:
    """Test hardware configuration management."""

    @pytest.fixture
    def config_manager(self):
        """Create configuration manager."""
        class ConfigManager:
            def __init__(self):
                # Baseline config cloned for routers without an explicit one.
                self.default_config = {
                    "frequency": 5.8e9,
                    "bandwidth": 80e6,
                    "sample_rate": 1000,
                    "antenna_config": "4x4_mimo",
                    "power_level": 20,
                    "channel": 36
                }
                self.router_configs = {}

            def get_router_config(self, router_id: str) -> Dict[str, Any]:
                """Get configuration for a specific router.

                NOTE(review): the default is copied, but a stored per-router
                config is returned by reference — callers mutating it will
                mutate the stored config.
                """
                return self.router_configs.get(router_id, self.default_config.copy())

            def set_router_config(self, router_id: str, config: Dict[str, Any]) -> bool:
                """Set configuration for a specific router.

                Only checks that required keys are present; value ranges are
                the caller's job via ``validate_config``.
                """
                # Validate configuration
                required_fields = ["frequency", "bandwidth", "sample_rate"]
                if not all(field in config for field in required_fields):
                    return False

                self.router_configs[router_id] = config
                return True

            def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
                """Validate router configuration.

                Returns ``{"valid": bool, "errors": [str, ...]}``; keys that
                are absent from ``config`` are simply not validated.
                """
                errors = []

                # Frequency validation
                if "frequency" in config:
                    freq = config["frequency"]
                    if not (2.4e9 <= freq <= 6e9):
                        errors.append("Frequency must be between 2.4GHz and 6GHz")

                # Bandwidth validation
                if "bandwidth" in config:
                    bw = config["bandwidth"]
                    if bw not in [20e6, 40e6, 80e6, 160e6]:
                        errors.append("Bandwidth must be 20, 40, 80, or 160 MHz")

                # Sample rate validation
                if "sample_rate" in config:
                    sr = config["sample_rate"]
                    if not (100 <= sr <= 10000):
                        errors.append("Sample rate must be between 100 and 10000 Hz")

                return {
                    "valid": len(errors) == 0,
                    "errors": errors
                }

        return ConfigManager()

    def test_default_configuration_should_fail_initially(self, config_manager):
        """Test default configuration retrieval - should fail initially."""
        config = config_manager.get_router_config("new_router")

        # This will fail initially
        assert isinstance(config, dict)
        assert "frequency" in config
        assert "bandwidth" in config
        assert "sample_rate" in config
        assert "antenna_config" in config
        assert config["frequency"] == 5.8e9
        assert config["bandwidth"] == 80e6

    def test_configuration_validation_should_fail_initially(self, config_manager):
        """Test configuration validation - should fail initially."""
        # Valid configuration
        valid_config = {
            "frequency": 5.8e9,
            "bandwidth": 80e6,
            "sample_rate": 1000
        }

        result = config_manager.validate_config(valid_config)

        # This will fail initially
        assert result["valid"] is True
        assert len(result["errors"]) == 0

        # Invalid configuration
        invalid_config = {
            "frequency": 10e9,  # Too high
            "bandwidth": 100e6,  # Invalid
            "sample_rate": 50  # Too low
        }

        result = config_manager.validate_config(invalid_config)
        assert result["valid"] is False
        assert len(result["errors"]) == 3

    def test_router_specific_configuration_should_fail_initially(self, config_manager):
        """Test router-specific configuration - should fail initially."""
        router_id = "router_001"
        custom_config = {
            "frequency": 2.4e9,
            "bandwidth": 40e6,
            "sample_rate": 500,
            "antenna_config": "2x2_mimo"
        }

        # Set custom configuration
        result = config_manager.set_router_config(router_id, custom_config)

        # This will fail initially
        assert result is True

        # Retrieve custom configuration
        retrieved_config = config_manager.get_router_config(router_id)
        assert retrieved_config["frequency"] == 2.4e9
        assert retrieved_config["bandwidth"] == 40e6
        assert retrieved_config["antenna_config"] == "2x2_mimo"

        # Test invalid configuration
        invalid_config = {"frequency": 5.8e9}  # Missing required fields
        result = config_manager.set_router_config(router_id, invalid_config)
        assert result is False
|
||||
459
v1/tests/integration/test_inference_pipeline.py
Normal file
459
v1/tests/integration/test_inference_pipeline.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
|
||||
|
||||
class TestInferencePipeline:
|
||||
"""Integration tests for inference pipeline following London School TDD principles"""
|
||||
|
||||
    @pytest.fixture
    def mock_csi_processor_config(self):
        """Configuration for CSI processor.

        Keys must match what ``CSIProcessor.__init__`` expects; values are
        arbitrary but internally consistent test settings.
        """
        return {
            'window_size': 100,
            'overlap': 0.5,
            'filter_type': 'butterworth',
            'filter_order': 4,
            'cutoff_frequency': 50,
            'normalization': 'minmax',
            'outlier_threshold': 3.0
        }
|
||||
|
||||
    @pytest.fixture
    def mock_sanitizer_config(self):
        """Configuration for phase sanitizer (keys per ``PhaseSanitizer``)."""
        return {
            'unwrap_method': 'numpy',
            'smoothing_window': 5,
            'outlier_threshold': 2.0,
            'interpolation_method': 'linear',
            'phase_correction': True
        }
|
||||
|
||||
    @pytest.fixture
    def mock_translation_config(self):
        """Configuration for modality translation network.

        ``output_channels`` (256) must equal the DensePose head fixture's
        ``input_channels`` so the two stages compose.
        """
        return {
            'input_channels': 6,
            'output_channels': 256,
            'hidden_channels': [64, 128, 256],
            'kernel_sizes': [7, 5, 3],
            'strides': [2, 2, 1],
            'dropout_rate': 0.1,
            'use_attention': True,
            'attention_heads': 8,
            'use_residual': True,
            'activation': 'relu',
            'normalization': 'batch'
        }
|
||||
|
||||
    @pytest.fixture
    def mock_densepose_config(self):
        """Configuration for DensePose head.

        24 body parts plus background gives the 25 segmentation channels
        asserted by the end-to-end test.
        """
        return {
            'input_channels': 256,
            'num_body_parts': 24,
            'num_uv_coordinates': 2,
            'hidden_channels': [128, 64],
            'kernel_size': 3,
            'padding': 1,
            'dropout_rate': 0.1,
            'use_deformable_conv': False,
            'use_fpn': True,
            'fpn_levels': [2, 3, 4, 5],
            'output_stride': 4
        }
|
||||
|
||||
    @pytest.fixture
    def inference_pipeline_components(self, mock_csi_processor_config, mock_sanitizer_config,
                                      mock_translation_config, mock_densepose_config):
        """Create inference pipeline components for testing.

        Builds the four real pipeline stages (no mocks) wired together via
        the config fixtures above.
        """
        csi_processor = CSIProcessor(mock_csi_processor_config)
        phase_sanitizer = PhaseSanitizer(mock_sanitizer_config)
        translation_network = ModalityTranslationNetwork(mock_translation_config)
        densepose_head = DensePoseHead(mock_densepose_config)

        return {
            'csi_processor': csi_processor,
            'phase_sanitizer': phase_sanitizer,
            'translation_network': translation_network,
            'densepose_head': densepose_head
        }
|
||||
|
||||
    @pytest.fixture
    def mock_raw_csi_input(self):
        """Generate mock raw CSI input data.

        Returns a complex ndarray of shape
        (batch, antennas, subcarriers, time_samples) = (4, 3, 56, 100).
        Unseeded randomness: values differ between runs by design.
        """
        batch_size = 4
        antennas = 3
        subcarriers = 56
        time_samples = 100

        # Generate complex CSI data
        real_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)
        imag_part = np.random.randn(batch_size, antennas, subcarriers, time_samples)

        return real_part + 1j * imag_part
|
||||
|
||||
    @pytest.fixture
    def mock_ground_truth_densepose(self):
        """Generate mock ground truth DensePose annotations.

        Segmentation labels are in [0, 24] (0 = background); UV maps are
        unconstrained gaussian noise, suitable only for gradient-flow tests.
        """
        batch_size = 4
        height = 224
        width = 224
        num_parts = 24

        # Segmentation masks
        seg_masks = torch.randint(0, num_parts + 1, (batch_size, height, width))

        # UV coordinates
        uv_coords = torch.randn(batch_size, 2, height, width)

        return {
            'segmentation': seg_masks,
            'uv_coordinates': uv_coords
        }
|
||||
|
||||
    def test_end_to_end_inference_pipeline_produces_valid_output(self, inference_pipeline_components, mock_raw_csi_input):
        """Test that end-to-end inference pipeline produces valid DensePose output"""
        # Arrange
        csi_processor = inference_pipeline_components['csi_processor']
        phase_sanitizer = inference_pipeline_components['phase_sanitizer']
        translation_network = inference_pipeline_components['translation_network']
        densepose_head = inference_pipeline_components['densepose_head']

        # Set models to evaluation mode
        translation_network.eval()
        densepose_head.eval()

        # Act - Run the complete inference pipeline
        with torch.no_grad():
            # 1. Process CSI data
            processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)

            # 2. Sanitize phase information
            sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)

            # 3. Translate CSI to visual features
            visual_features = translation_network(sanitized_csi)

            # 4. Generate DensePose predictions
            densepose_output = densepose_head(visual_features)

        # Assert
        assert densepose_output is not None
        assert isinstance(densepose_output, dict)
        assert 'segmentation' in densepose_output
        assert 'uv_coordinates' in densepose_output

        seg_output = densepose_output['segmentation']
        uv_output = densepose_output['uv_coordinates']

        # Check output shapes
        assert seg_output.shape[0] == mock_raw_csi_input.shape[0]  # Batch size preserved
        assert seg_output.shape[1] == 25  # 24 body parts + 1 background
        assert uv_output.shape[0] == mock_raw_csi_input.shape[0]  # Batch size preserved
        assert uv_output.shape[1] == 2  # U and V coordinates

        # Check output ranges
        # NOTE(review): implies the head squashes UV (e.g. sigmoid) — confirm.
        assert torch.all(uv_output >= 0) and torch.all(uv_output <= 1)  # UV in [0, 1]
|
||||
|
||||
def test_inference_pipeline_handles_different_batch_sizes(self, inference_pipeline_components):
    """The pipeline must accept different batch sizes and preserve the batch dim."""
    comps = inference_pipeline_components
    csi_proc = comps['csi_processor']
    sanitizer = comps['phase_sanitizer']
    translator = comps['translation_network']
    head = comps['densepose_head']

    translator.eval()
    head.eval()

    # Exercise both a minimal batch and a larger one.
    for batch_size in (1, 8):
        csi = (np.random.randn(batch_size, 3, 56, 100)
               + 1j * np.random.randn(batch_size, 3, 56, 100))

        with torch.no_grad():
            processed = csi_proc.process_csi_batch(csi)
            sanitized = sanitizer.sanitize_phase_batch(processed)
            output = head(translator(sanitized))

        assert output['segmentation'].shape[0] == batch_size
        assert output['uv_coordinates'].shape[0] == batch_size
def test_inference_pipeline_maintains_gradient_flow_during_training(self, inference_pipeline_components,
                                                                    mock_raw_csi_input, mock_ground_truth_densepose):
    """Test that inference pipeline maintains gradient flow during training.

    Runs one full training step (forward, loss, backward) and verifies that
    every trainable parameter in both networks receives a non-zero gradient.
    """
    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    # Set models to training mode (enables dropout/batch-norm updates)
    translation_network.train()
    densepose_head.train()

    # Create optimizer over the parameters of both networks
    optimizer = torch.optim.Adam(
        list(translation_network.parameters()) + list(densepose_head.parameters()),
        lr=0.001
    )

    # Act - clear any stale gradients before the forward pass
    optimizer.zero_grad()

    # Forward pass (NOT under no_grad: the autograd graph is the point here)
    processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
    sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
    visual_features = translation_network(sanitized_csi)
    densepose_output = densepose_head(visual_features)

    # Resize ground truth to match output resolution; 'nearest' keeps the
    # segmentation labels as valid class indices (no interpolation blending)
    seg_target = torch.nn.functional.interpolate(
        mock_ground_truth_densepose['segmentation'].float().unsqueeze(1),
        size=densepose_output['segmentation'].shape[2:],
        mode='nearest'
    ).squeeze(1).long()

    # UV targets are continuous, so bilinear interpolation is appropriate
    uv_target = torch.nn.functional.interpolate(
        mock_ground_truth_densepose['uv_coordinates'],
        size=densepose_output['uv_coordinates'].shape[2:],
        mode='bilinear',
        align_corners=False
    )

    # Compute loss (combined segmentation + UV loss, per the head's API)
    loss = densepose_head.compute_total_loss(densepose_output, seg_target, uv_target)

    # Backward pass populates .grad on every parameter in the graph
    loss.backward()

    # Assert - every trainable parameter got a gradient, and not an all-zero one
    for param in translation_network.parameters():
        if param.requires_grad:
            assert param.grad is not None
            assert not torch.allclose(param.grad, torch.zeros_like(param.grad))

    for param in densepose_head.parameters():
        if param.requires_grad:
            assert param.grad is not None
            assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
def test_inference_pipeline_performance_benchmarking(self, inference_pipeline_components, mock_raw_csi_input):
    """Test inference pipeline performance for real-time requirements.

    Fix: use time.perf_counter() instead of time.time(). time.time() is a
    wall clock — it can jump under NTP adjustment and has coarse resolution
    on some platforms — which makes a sub-50 ms timing assertion flaky.
    perf_counter() is the monotonic, highest-resolution clock documented for
    benchmarking short intervals.
    """
    import time

    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    # Set models to evaluation mode for inference
    translation_network.eval()
    densepose_head.eval()

    # Warm up (first inference is often slower: lazy allocations, caching)
    with torch.no_grad():
        processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
        sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
        visual_features = translation_network(sanitized_csi)
        _ = densepose_head(visual_features)

    # Act - Measure inference time with the monotonic benchmark clock
    start_time = time.perf_counter()

    with torch.no_grad():
        processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
        sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
        visual_features = translation_network(sanitized_csi)
        densepose_output = densepose_head(visual_features)

    end_time = time.perf_counter()
    inference_time = end_time - start_time

    # Assert - Should meet real-time requirements (< 50ms for batch of 4)
    assert inference_time < 0.05, f"Inference took {inference_time:.3f}s, expected < 0.05s"
def test_inference_pipeline_handles_edge_cases(self, inference_pipeline_components):
    """Test that inference pipeline handles edge cases gracefully.

    Degenerate inputs (all-zero CSI, extreme-amplitude noise) must still
    produce NaN-free outputs at every stage.
    """
    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    # Edge cases: all-zero CSI (no signal) and 100x-scaled complex noise
    zero_input = np.zeros((1, 3, 56, 100), dtype=complex)
    noisy_input = np.random.randn(1, 3, 56, 100) * 100 + 1j * np.random.randn(1, 3, 56, 100) * 100

    translation_network.eval()
    densepose_head.eval()

    # Act & Assert
    with torch.no_grad():
        # Zero input: phase of a zero complex signal is degenerate, so this
        # specifically exercises the sanitizer's robustness
        zero_processed = csi_processor.process_csi_batch(zero_input)
        zero_sanitized = phase_sanitizer.sanitize_phase_batch(zero_processed)
        zero_features = translation_network(zero_sanitized)
        zero_output = densepose_head(zero_features)

        assert not torch.isnan(zero_output['segmentation']).any()
        assert not torch.isnan(zero_output['uv_coordinates']).any()

        # Noisy input: checks numerical stability under large magnitudes
        noisy_processed = csi_processor.process_csi_batch(noisy_input)
        noisy_sanitized = phase_sanitizer.sanitize_phase_batch(noisy_processed)
        noisy_features = translation_network(noisy_sanitized)
        noisy_output = densepose_head(noisy_features)

        assert not torch.isnan(noisy_output['segmentation']).any()
        assert not torch.isnan(noisy_output['uv_coordinates']).any()
def test_inference_pipeline_memory_efficiency(self, inference_pipeline_components):
    """Test that inference pipeline is memory efficient.

    Processes a 16-sample batch in chunks of 4, explicitly releasing
    intermediate tensors between chunks so peak memory stays bounded.
    """
    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    # Large batch to test memory usage
    large_input = np.random.randn(16, 3, 56, 100) + 1j * np.random.randn(16, 3, 56, 100)

    translation_network.eval()
    densepose_head.eval()

    # Act - Process in chunks to manage memory
    chunk_size = 4
    outputs = []

    with torch.no_grad():
        for i in range(0, large_input.shape[0], chunk_size):
            chunk = large_input[i:i+chunk_size]

            processed_chunk = csi_processor.process_csi_batch(chunk)
            sanitized_chunk = phase_sanitizer.sanitize_phase_batch(processed_chunk)
            feature_chunk = translation_network(sanitized_chunk)
            output_chunk = densepose_head(feature_chunk)

            outputs.append(output_chunk)

            # Clear intermediate tensors to free memory before the next chunk;
            # only the final output_chunk is retained
            del processed_chunk, sanitized_chunk, feature_chunk

    # Assert
    assert len(outputs) == 4  # 16 samples / 4 chunk_size
    for output in outputs:
        assert output['segmentation'].shape[0] <= chunk_size
def test_inference_pipeline_deterministic_output(self, inference_pipeline_components, mock_raw_csi_input):
    """Two identical eval-mode runs over the same input must produce identical tensors."""
    comps = inference_pipeline_components
    csi_proc = comps['csi_processor']
    sanitizer = comps['phase_sanitizer']
    translator = comps['translation_network']
    head = comps['densepose_head']

    # Eval mode disables stochastic layers (e.g. dropout), so repeated
    # inference on the same input should be bit-for-bit reproducible.
    translator.eval()
    head.eval()

    def run_pipeline():
        with torch.no_grad():
            processed = csi_proc.process_csi_batch(mock_raw_csi_input)
            sanitized = sanitizer.sanitize_phase_batch(processed)
            return head(translator(sanitized))

    first = run_pipeline()
    second = run_pipeline()

    # Outputs should be identical (up to tight tolerance) in eval mode
    assert torch.allclose(first['segmentation'], second['segmentation'], atol=1e-6)
    assert torch.allclose(first['uv_coordinates'], second['uv_coordinates'], atol=1e-6)
def test_inference_pipeline_confidence_estimation(self, inference_pipeline_components, mock_raw_csi_input):
    """Test that inference pipeline provides confidence estimates.

    After a normal forward pass, the head's get_prediction_confidence()
    must yield per-sample confidence tensors for both output branches.
    """
    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    translation_network.eval()
    densepose_head.eval()

    # Act
    with torch.no_grad():
        processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
        sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
        visual_features = translation_network(sanitized_csi)
        densepose_output = densepose_head(visual_features)

        # Get confidence estimates derived from the raw predictions
        confidence = densepose_head.get_prediction_confidence(densepose_output)

    # Assert
    assert 'segmentation_confidence' in confidence
    assert 'uv_confidence' in confidence

    seg_conf = confidence['segmentation_confidence']
    uv_conf = confidence['uv_confidence']

    # One confidence entry per input sample
    assert seg_conf.shape[0] == mock_raw_csi_input.shape[0]
    assert uv_conf.shape[0] == mock_raw_csi_input.shape[0]
    # Segmentation confidence is probability-like ([0, 1]); UV confidence is
    # only required to be non-negative (upper bound unspecified by the head)
    assert torch.all(seg_conf >= 0) and torch.all(seg_conf <= 1)
    assert torch.all(uv_conf >= 0)
def test_inference_pipeline_post_processing(self, inference_pipeline_components, mock_raw_csi_input):
    """Test that inference pipeline post-processes predictions correctly.

    post_process_predictions() must convert raw head outputs into discrete
    body-part labels plus UV maps and confidence scores.
    """
    # Arrange
    csi_processor = inference_pipeline_components['csi_processor']
    phase_sanitizer = inference_pipeline_components['phase_sanitizer']
    translation_network = inference_pipeline_components['translation_network']
    densepose_head = inference_pipeline_components['densepose_head']

    translation_network.eval()
    densepose_head.eval()

    # Act
    with torch.no_grad():
        processed_csi = csi_processor.process_csi_batch(mock_raw_csi_input)
        sanitized_csi = phase_sanitizer.sanitize_phase_batch(processed_csi)
        visual_features = translation_network(sanitized_csi)
        raw_output = densepose_head(visual_features)

        # Post-process predictions (logits -> labels/scores)
        processed_output = densepose_head.post_process_predictions(raw_output)

    # Assert
    assert 'body_parts' in processed_output
    assert 'uv_coordinates' in processed_output
    assert 'confidence_scores' in processed_output

    body_parts = processed_output['body_parts']
    assert body_parts.dtype == torch.long  # Class indices, not probabilities
    # Valid class range: 0 = background, 1..24 = body parts
    assert torch.all(body_parts >= 0) and torch.all(body_parts <= 24)
||||
577
v1/tests/integration/test_pose_pipeline.py
Normal file
577
v1/tests/integration/test_pose_pipeline.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Integration tests for end-to-end pose estimation pipeline.
|
||||
|
||||
Tests the complete pose estimation workflow from CSI data to pose results.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class CSIData:
    """CSI data structure for testing."""
    timestamp: datetime          # capture time of this CSI frame
    router_id: str               # identifier of the router that produced the frame
    amplitude: np.ndarray        # amplitude matrix (tests use shape (64, 32))
    phase: np.ndarray            # phase matrix, same shape as amplitude
    frequency: float             # carrier frequency in Hz (e.g. 5.8e9)
    bandwidth: float             # channel bandwidth in Hz (e.g. 80e6)
    antenna_count: int           # number of antennas on the capture device
    subcarrier_count: int        # number of OFDM subcarriers
|
||||
|
||||
@dataclass
class PoseResult:
    """Pose estimation result structure."""
    timestamp: datetime              # when processing of the frame started
    frame_id: str                    # unique frame identifier ("frame_<ms>")
    persons: List[Dict[str, Any]]    # detected persons with bbox/keypoints/zone
    zone_summary: Dict[str, int]     # person count per zone id
    processing_time_ms: float        # end-to-end pipeline latency for the frame
    confidence_scores: List[float]   # one confidence value per detected person
    metadata: Dict[str, Any]         # csi_quality, signal_strength, model_version, router_id
||||
|
||||
class MockCSIProcessor:
    """Mock CSI data processor."""

    def __init__(self):
        # Both flags gate process_csi_data(); see the guards below.
        self.is_initialized = False
        self.processing_enabled = True

    async def initialize(self):
        """Initialize the processor."""
        self.is_initialized = True

    async def process_csi_data(self, csi_data: "CSIData") -> Dict[str, Any]:
        """Process CSI data into features.

        Raises RuntimeError when the processor has not been initialized or
        when processing has been disabled via set_processing_enabled(False).
        """
        # Guard clauses: refuse work unless both flags allow it.
        if not self.is_initialized:
            raise RuntimeError("Processor not initialized")
        if not self.processing_enabled:
            raise RuntimeError("Processing disabled")

        # Simulate processing latency.
        await asyncio.sleep(0.01)

        feature_matrix = np.random.rand(64, 32).tolist()
        return {
            "features": feature_matrix,  # mock 64x32 feature matrix
            "quality_score": 0.85,
            "signal_strength": -45.2,
            "noise_level": -78.1,
            "processed_at": datetime.utcnow().isoformat(),
        }

    def set_processing_enabled(self, enabled: bool):
        """Enable/disable processing."""
        self.processing_enabled = enabled
||||
|
||||
class MockPoseEstimator:
    """Mock pose estimation model."""

    def __init__(self):
        self.is_loaded = False
        self.model_version = "1.0.0"
        # Detections with confidence below this value are dropped.
        self.confidence_threshold = 0.5

    async def load_model(self):
        """Load the pose estimation model."""
        await asyncio.sleep(0.1)  # simulate model loading
        self.is_loaded = True

    async def estimate_poses(self, features: np.ndarray) -> Dict[str, Any]:
        """Estimate poses from features.

        Raises RuntimeError if load_model() has not completed. Returns a dict
        with detected persons (confidence-filtered), timing and model info.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        await asyncio.sleep(0.05)  # simulate inference time

        persons = []
        for idx in range(np.random.randint(0, 4)):  # 0-3 candidate persons
            confidence = np.random.uniform(0.3, 0.95)
            if confidence < self.confidence_threshold:
                continue  # filtered out by the configured threshold
            persons.append({
                "person_id": f"person_{idx}",
                "confidence": confidence,
                "bounding_box": {
                    "x": np.random.uniform(0, 800),
                    "y": np.random.uniform(0, 600),
                    "width": np.random.uniform(50, 200),
                    "height": np.random.uniform(100, 400),
                },
                "keypoints": [
                    {
                        "name": "head",
                        "x": np.random.uniform(0, 800),
                        "y": np.random.uniform(0, 200),
                        "confidence": np.random.uniform(0.5, 0.95),
                    },
                    {
                        "name": "torso",
                        "x": np.random.uniform(0, 800),
                        "y": np.random.uniform(200, 400),
                        "confidence": np.random.uniform(0.5, 0.95),
                    },
                ],
                # 80% standing, 20% sitting.
                "activity": "standing" if np.random.random() > 0.2 else "sitting",
            })

        return {
            "persons": persons,
            "processing_time_ms": np.random.uniform(20, 80),
            "model_version": self.model_version,
            "confidence_threshold": self.confidence_threshold,
        }

    def set_confidence_threshold(self, threshold: float):
        """Set confidence threshold."""
        self.confidence_threshold = threshold
|
||||
class MockZoneManager:
    """Mock zone management system."""

    def __init__(self):
        # Zone bounds are [x1, y1, x2, y2]. NOTE(review): zone3 overlaps the
        # union of zone1/zone2, and assignment stops at the first match in
        # dict order, so zone3 can never win — confirm this is intended.
        self.zones = {
            "zone1": {"id": "zone1", "name": "Zone 1", "bounds": [0, 0, 400, 600]},
            "zone2": {"id": "zone2", "name": "Zone 2", "bounds": [400, 0, 800, 600]},
            "zone3": {"id": "zone3", "name": "Zone 3", "bounds": [0, 300, 800, 600]},
        }

    def _zone_for(self, cx, cy):
        """Return the id of the first zone containing point (cx, cy), else None."""
        for zone_id, zone in self.zones.items():
            x1, y1, x2, y2 = zone["bounds"]
            if x1 <= cx <= x2 and y1 <= cy <= y2:
                return zone_id
        return None

    def assign_persons_to_zones(self, persons: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Assign detected persons to zones.

        Mutates each person dict in place (sets "zone_id", possibly None) and
        returns a {zone_id: count} summary.
        """
        zone_summary = dict.fromkeys(self.zones, 0)

        for person in persons:
            bbox = person["bounding_box"]
            # Classify by the center of the bounding box.
            zone_id = self._zone_for(bbox["x"] + bbox["width"] / 2,
                                     bbox["y"] + bbox["height"] / 2)
            person["zone_id"] = zone_id
            if zone_id is not None:
                zone_summary[zone_id] += 1

        return zone_summary
||||
|
||||
class TestPosePipelineIntegration:
    """Integration tests for the complete pose estimation pipeline."""

    @pytest.fixture
    def csi_processor(self):
        """Create CSI processor."""
        return MockCSIProcessor()

    @pytest.fixture
    def pose_estimator(self):
        """Create pose estimator."""
        return MockPoseEstimator()

    @pytest.fixture
    def zone_manager(self):
        """Create zone manager."""
        return MockZoneManager()

    @pytest.fixture
    def sample_csi_data(self):
        """Create sample CSI data with random amplitude/phase matrices."""
        # NOTE(review): datetime.utcnow() returns a naive datetime and is
        # deprecated in 3.12 — consider datetime.now(timezone.utc) module-wide.
        return CSIData(
            timestamp=datetime.utcnow(),
            router_id="router_001",
            amplitude=np.random.rand(64, 32),
            phase=np.random.rand(64, 32),
            frequency=5.8e9,  # 5.8 GHz
            bandwidth=80e6,  # 80 MHz
            antenna_count=4,
            subcarrier_count=64
        )

    @pytest.fixture
    async def pose_pipeline(self, csi_processor, pose_estimator, zone_manager):
        """Create complete, already-initialized pose pipeline.

        NOTE(review): an async fixture declared with plain @pytest.fixture may
        not be awaited under pytest-asyncio strict mode (it would need
        @pytest_asyncio.fixture) — confirm the plugin configuration.
        """
        class PosePipeline:
            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                """Initialize the pipeline (processor + model load)."""
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

            async def process_frame(self, csi_data: CSIData) -> PoseResult:
                """Process a single frame through the pipeline."""
                if not self.is_initialized:
                    raise RuntimeError("Pipeline not initialized")

                start_time = datetime.utcnow()

                # Step 1: Process CSI data
                processed_data = await self.csi_processor.process_csi_data(csi_data)

                # Step 2: Extract features
                features = np.array(processed_data["features"])

                # Step 3: Estimate poses
                pose_data = await self.pose_estimator.estimate_poses(features)

                # Step 4: Assign to zones (mutates the person dicts in place)
                zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])

                # Calculate wall-clock processing time in milliseconds
                end_time = datetime.utcnow()
                processing_time = (end_time - start_time).total_seconds() * 1000

                return PoseResult(
                    timestamp=start_time,
                    frame_id=f"frame_{int(start_time.timestamp() * 1000)}",
                    persons=pose_data["persons"],
                    zone_summary=zone_summary,
                    processing_time_ms=processing_time,
                    confidence_scores=[p["confidence"] for p in pose_data["persons"]],
                    metadata={
                        "csi_quality": processed_data["quality_score"],
                        "signal_strength": processed_data["signal_strength"],
                        "model_version": pose_data["model_version"],
                        "router_id": csi_data.router_id
                    }
                )

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
        await pipeline.initialize()
        return pipeline

    @pytest.mark.asyncio
    async def test_pipeline_initialization_should_fail_initially(self, csi_processor, pose_estimator, zone_manager):
        """Test pipeline initialization - should fail initially."""
        # Local pipeline (not the fixture) so we can observe pre-init state.
        class PosePipeline:
            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)

        # Initially not initialized
        assert not pipeline.is_initialized
        assert not csi_processor.is_initialized
        assert not pose_estimator.is_loaded

        # Initialize pipeline
        await pipeline.initialize()

        # This will fail initially
        assert pipeline.is_initialized
        assert csi_processor.is_initialized
        assert pose_estimator.is_loaded

    @pytest.mark.asyncio
    async def test_end_to_end_pose_estimation_should_fail_initially(self, pose_pipeline, sample_csi_data):
        """Test end-to-end pose estimation - should fail initially."""
        result = await pose_pipeline.process_frame(sample_csi_data)

        # This will fail initially
        assert isinstance(result, PoseResult)
        assert result.timestamp is not None
        assert result.frame_id.startswith("frame_")
        assert isinstance(result.persons, list)
        assert isinstance(result.zone_summary, dict)
        assert result.processing_time_ms > 0
        assert isinstance(result.confidence_scores, list)
        assert isinstance(result.metadata, dict)

        # Verify zone summary covers every configured zone with a sane count
        expected_zones = ["zone1", "zone2", "zone3"]
        for zone_id in expected_zones:
            assert zone_id in result.zone_summary
            assert isinstance(result.zone_summary[zone_id], int)
            assert result.zone_summary[zone_id] >= 0

        # Verify metadata
        assert "csi_quality" in result.metadata
        assert "signal_strength" in result.metadata
        assert "model_version" in result.metadata
        assert "router_id" in result.metadata
        assert result.metadata["router_id"] == sample_csi_data.router_id

    @pytest.mark.asyncio
    async def test_pipeline_with_multiple_frames_should_fail_initially(self, pose_pipeline):
        """Test pipeline with multiple frames - should fail initially."""
        results = []

        # Process multiple frames
        for i in range(5):
            csi_data = CSIData(
                timestamp=datetime.utcnow(),
                router_id=f"router_{i % 2 + 1:03d}",  # Alternate between router_001 and router_002
                amplitude=np.random.rand(64, 32),
                phase=np.random.rand(64, 32),
                frequency=5.8e9,
                bandwidth=80e6,
                antenna_count=4,
                subcarrier_count=64
            )

            result = await pose_pipeline.process_frame(csi_data)
            results.append(result)

        # This will fail initially
        assert len(results) == 5

        # Verify each result
        for i, result in enumerate(results):
            # NOTE(review): this parses as a conditional expression — for i == 0
            # it asserts the literal True, i.e. it is a no-op for the first frame.
            assert result.frame_id != results[0].frame_id if i > 0 else True
            assert result.metadata["router_id"] in ["router_001", "router_002"]
            assert result.processing_time_ms > 0

    @pytest.mark.asyncio
    async def test_pipeline_error_handling_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
        """Test pipeline error handling - should fail initially."""
        # Minimal local pipeline without zone assignment; enough to surface
        # the two failure modes tested below.
        class PosePipeline:
            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

            async def process_frame(self, csi_data):
                if not self.is_initialized:
                    raise RuntimeError("Pipeline not initialized")

                processed_data = await self.csi_processor.process_csi_data(csi_data)
                features = np.array(processed_data["features"])
                pose_data = await self.pose_estimator.estimate_poses(features)

                return pose_data

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)

        # Test uninitialized pipeline
        with pytest.raises(RuntimeError, match="Pipeline not initialized"):
            await pipeline.process_frame(sample_csi_data)

        # Initialize pipeline
        await pipeline.initialize()

        # Test with disabled CSI processor
        csi_processor.set_processing_enabled(False)

        with pytest.raises(RuntimeError, match="Processing disabled"):
            await pipeline.process_frame(sample_csi_data)

        # This assertion will fail initially
        assert True  # Test completed successfully

    @pytest.mark.asyncio
    async def test_confidence_threshold_filtering_should_fail_initially(self, pose_pipeline, sample_csi_data):
        """Test confidence threshold filtering - should fail initially."""
        # Set high confidence threshold
        pose_pipeline.pose_estimator.set_confidence_threshold(0.9)

        result = await pose_pipeline.process_frame(sample_csi_data)

        # This will fail initially
        # With high threshold, fewer persons should be detected
        high_confidence_count = len(result.persons)

        # Set low confidence threshold
        pose_pipeline.pose_estimator.set_confidence_threshold(0.1)

        result = await pose_pipeline.process_frame(sample_csi_data)
        low_confidence_count = len(result.persons)

        # Low threshold should detect same or more persons
        # NOTE(review): the estimator redraws a random person count per call,
        # so this comparison is probabilistic rather than guaranteed — confirm.
        assert low_confidence_count >= high_confidence_count

        # All detected persons should meet the threshold
        for person in result.persons:
            assert person["confidence"] >= 0.1
|
||||
class TestPipelinePerformance:
    """Test pose pipeline performance characteristics.

    NOTE(review): these tests reference a `pose_pipeline` fixture defined on
    TestPosePipelineIntegration; pytest fixtures defined inside a class are
    not visible to sibling classes — confirm a conftest-level fixture exists.
    """

    @pytest.mark.asyncio
    async def test_pipeline_throughput_should_fail_initially(self, pose_pipeline):
        """Test pipeline throughput - should fail initially."""
        frame_count = 10
        start_time = datetime.utcnow()

        # Process multiple frames sequentially and measure aggregate rate
        for i in range(frame_count):
            csi_data = CSIData(
                timestamp=datetime.utcnow(),
                router_id="router_001",
                amplitude=np.random.rand(64, 32),
                phase=np.random.rand(64, 32),
                frequency=5.8e9,
                bandwidth=80e6,
                antenna_count=4,
                subcarrier_count=64
            )

            await pose_pipeline.process_frame(csi_data)

        end_time = datetime.utcnow()
        total_time = (end_time - start_time).total_seconds()
        fps = frame_count / total_time

        # This will fail initially
        assert fps > 5.0  # Should process at least 5 FPS
        assert total_time < 5.0  # Should complete 10 frames in under 5 seconds

    @pytest.mark.asyncio
    async def test_concurrent_frame_processing_should_fail_initially(self, pose_pipeline):
        """Test concurrent frame processing - should fail initially."""
        async def process_single_frame(frame_id: int):
            # Each task gets its own frame, spread across three routers
            csi_data = CSIData(
                timestamp=datetime.utcnow(),
                router_id=f"router_{frame_id % 3 + 1:03d}",
                amplitude=np.random.rand(64, 32),
                phase=np.random.rand(64, 32),
                frequency=5.8e9,
                bandwidth=80e6,
                antenna_count=4,
                subcarrier_count=64
            )

            result = await pose_pipeline.process_frame(csi_data)
            return result.frame_id

        # Process frames concurrently
        tasks = [process_single_frame(i) for i in range(5)]
        results = await asyncio.gather(*tasks)

        # This will fail initially
        assert len(results) == 5
        # NOTE(review): frame_id is derived from a millisecond timestamp, so
        # concurrent frames started within the same ms would collide — confirm.
        assert len(set(results)) == 5  # All frame IDs should be unique

    @pytest.mark.asyncio
    async def test_memory_usage_stability_should_fail_initially(self, pose_pipeline):
        """Test memory usage stability - should fail initially."""
        # psutil is a third-party dependency; imported lazily so the module
        # stays importable without it
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        # Process many frames
        for i in range(50):
            csi_data = CSIData(
                timestamp=datetime.utcnow(),
                router_id="router_001",
                amplitude=np.random.rand(64, 32),
                phase=np.random.rand(64, 32),
                frequency=5.8e9,
                bandwidth=80e6,
                antenna_count=4,
                subcarrier_count=64
            )

            await pose_pipeline.process_frame(csi_data)

            # Periodic memory check every 10 frames
            if i % 10 == 0:
                current_memory = process.memory_info().rss
                memory_increase = current_memory - initial_memory

                # This will fail initially
                # Memory increase should be reasonable (less than 100MB)
                assert memory_increase < 100 * 1024 * 1024

        final_memory = process.memory_info().rss
        total_increase = final_memory - initial_memory

        # Total memory increase should be reasonable
        assert total_increase < 200 * 1024 * 1024  # Less than 200MB increase
|
||||
class TestPipelineDataFlow:
|
||||
"""Test data flow through the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
async def test_data_transformation_chain_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
    """Test data transformation through the pipeline - should fail initially.

    Drives each stage manually (CSI processing -> pose estimation -> zone
    assignment) and checks the data contract between stages.
    """
    # Step 1: CSI processing
    await csi_processor.initialize()
    processed_data = await csi_processor.process_csi_data(sample_csi_data)

    # This will fail initially
    assert "features" in processed_data
    assert "quality_score" in processed_data
    assert isinstance(processed_data["features"], list)
    assert 0 <= processed_data["quality_score"] <= 1

    # Step 2: Pose estimation on the extracted feature matrix
    await pose_estimator.load_model()
    features = np.array(processed_data["features"])
    pose_data = await pose_estimator.estimate_poses(features)

    assert "persons" in pose_data
    assert "processing_time_ms" in pose_data
    assert isinstance(pose_data["persons"], list)

    # Step 3: Zone assignment (mutates person dicts, returns counts)
    zone_summary = zone_manager.assign_persons_to_zones(pose_data["persons"])

    assert isinstance(zone_summary, dict)
    assert all(isinstance(count, int) for count in zone_summary.values())

    # Verify person zone assignments are consistent with the summary keys
    for person in pose_data["persons"]:
        if "zone_id" in person and person["zone_id"]:
            assert person["zone_id"] in zone_summary
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_state_consistency_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test pipeline state consistency - should fail initially."""
|
||||
# Process the same frame multiple times
|
||||
results = []
|
||||
for _ in range(3):
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
results.append(result)
|
||||
|
||||
# This will fail initially
|
||||
# Results should be consistent (same input should produce similar output)
|
||||
assert len(results) == 3
|
||||
|
||||
# All results should have the same router_id
|
||||
router_ids = [r.metadata["router_id"] for r in results]
|
||||
assert all(rid == router_ids[0] for rid in router_ids)
|
||||
|
||||
# Processing times should be reasonable and similar
|
||||
processing_times = [r.processing_time_ms for r in results]
|
||||
assert all(10 <= pt <= 200 for pt in processing_times) # Between 10ms and 200ms
|
||||
565
v1/tests/integration/test_rate_limiting.py
Normal file
565
v1/tests/integration/test_rate_limiting.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""
|
||||
Integration tests for rate limiting functionality.
|
||||
|
||||
Tests rate limit behavior, throttling, and quota management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import time
|
||||
|
||||
from fastapi import HTTPException, status, Request, Response
|
||||
|
||||
|
||||
class MockRateLimiter:
    """Mock rate limiter for testing.

    Tracks per-client request timestamps and enforces both a per-minute
    and a per-hour quota.  Quotas are keyed by ``client_id`` (optionally
    combined with an endpoint path so each endpoint gets its own bucket),
    while explicit blocking applies to the bare ``client_id`` regardless
    of endpoint.
    """

    def __init__(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        # client_key -> list of request datetimes (newest appended last)
        self.request_history = {}
        # client_ids that are refused outright
        self.blocked_clients = set()

    def _get_client_key(self, client_id: str, endpoint: str = None) -> str:
        """Get client key for rate limiting (per-endpoint when given)."""
        return f"{client_id}:{endpoint}" if endpoint else client_id

    def _cleanup_old_requests(self, client_key: str):
        """Drop request records older than one hour for *client_key*."""
        if client_key not in self.request_history:
            return

        hour_ago = datetime.utcnow() - timedelta(hours=1)

        # Keep only requests from the last hour
        self.request_history[client_key] = [
            req_time for req_time in self.request_history[client_key]
            if req_time > hour_ago
        ]

    def check_rate_limit(self, client_id: str, endpoint: str = None) -> Dict[str, Any]:
        """Check whether *client_id* may make a request right now.

        Returns a dict with ``allowed`` plus either quota metadata
        (``remaining_minute`` / ``remaining_hour`` / ``reset_time``) or a
        denial ``reason`` and ``retry_after`` in seconds.  An allowed
        call is recorded against the client's history.
        """
        client_key = self._get_client_key(client_id, endpoint)

        # A hard block takes precedence over any quota bookkeeping.
        if client_id in self.blocked_clients:
            return {
                "allowed": False,
                "reason": "Client blocked",
                "retry_after": 3600  # 1 hour
            }

        self._cleanup_old_requests(client_key)

        if client_key not in self.request_history:
            self.request_history[client_key] = []

        now = datetime.utcnow()
        minute_ago = now - timedelta(minutes=1)

        # Count requests in the last minute
        recent_requests = [
            req_time for req_time in self.request_history[client_key]
            if req_time > minute_ago
        ]

        # History was already pruned to one hour, so its length is the hour count.
        hour_requests = len(self.request_history[client_key])

        if len(recent_requests) >= self.requests_per_minute:
            return {
                "allowed": False,
                "reason": "Rate limit exceeded (per minute)",
                "retry_after": 60,
                "current_requests": len(recent_requests),
                "limit": self.requests_per_minute
            }

        if hour_requests >= self.requests_per_hour:
            return {
                "allowed": False,
                "reason": "Rate limit exceeded (per hour)",
                "retry_after": 3600,
                "current_requests": hour_requests,
                "limit": self.requests_per_hour
            }

        # Record this request
        self.request_history[client_key].append(now)

        return {
            "allowed": True,
            "remaining_minute": self.requests_per_minute - len(recent_requests) - 1,
            "remaining_hour": self.requests_per_hour - hour_requests - 1,
            "reset_time": minute_ago + timedelta(minutes=1)
        }

    def block_client(self, client_id: str):
        """Block a client outright (applies to all endpoints)."""
        self.blocked_clients.add(client_id)

    def unblock_client(self, client_id: str):
        """Unblock a previously blocked client (no-op if not blocked)."""
        self.blocked_clients.discard(client_id)
|
||||
|
||||
|
||||
class TestRateLimitingBasic:
    """Test basic rate limiting functionality."""

    @pytest.fixture
    def rate_limiter(self):
        """Limiter with small quotas so the tests stay fast."""
        return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)

    def test_rate_limit_within_bounds_should_fail_initially(self, rate_limiter):
        """Test rate limiting within bounds - should fail initially."""
        client_id = "test-client-001"

        # Stay comfortably below the 5/minute quota.
        for _ in range(3):
            verdict = rate_limiter.check_rate_limit(client_id)

            # This will fail initially
            assert verdict["allowed"] is True
            assert "remaining_minute" in verdict
            assert "remaining_hour" in verdict

    def test_rate_limit_per_minute_exceeded_should_fail_initially(self, rate_limiter):
        """Test per-minute rate limit exceeded - should fail initially."""
        client_id = "test-client-002"

        # Consume the full per-minute quota.
        for _ in range(5):
            assert rate_limiter.check_rate_limit(client_id)["allowed"] is True

        # One more request must be rejected.
        verdict = rate_limiter.check_rate_limit(client_id)

        # This will fail initially
        assert verdict["allowed"] is False
        assert "per minute" in verdict["reason"]
        assert verdict["retry_after"] == 60
        assert verdict["current_requests"] == 5
        assert verdict["limit"] == 5

    def test_rate_limit_per_hour_exceeded_should_fail_initially(self, rate_limiter):
        """Test per-hour rate limit exceeded - should fail initially."""
        # Dedicated limiter whose hourly quota is tiny.
        hourly_limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=3)
        client_id = "test-client-003"

        # Consume the full hourly quota.
        for _ in range(3):
            assert hourly_limiter.check_rate_limit(client_id)["allowed"] is True

        verdict = hourly_limiter.check_rate_limit(client_id)

        # This will fail initially
        assert verdict["allowed"] is False
        assert "per hour" in verdict["reason"]
        assert verdict["retry_after"] == 3600

    def test_blocked_client_should_fail_initially(self, rate_limiter):
        """Test blocked client handling - should fail initially."""
        client_id = "blocked-client"

        rate_limiter.block_client(client_id)
        verdict = rate_limiter.check_rate_limit(client_id)

        # This will fail initially
        assert verdict["allowed"] is False
        assert verdict["reason"] == "Client blocked"
        assert verdict["retry_after"] == 3600

        # Unblocking restores normal service.
        rate_limiter.unblock_client(client_id)
        assert rate_limiter.check_rate_limit(client_id)["allowed"] is True

    def test_endpoint_specific_rate_limiting_should_fail_initially(self, rate_limiter):
        """Test endpoint-specific rate limiting - should fail initially."""
        client_id = "test-client-004"

        first = rate_limiter.check_rate_limit(client_id, "/api/pose/current")
        second = rate_limiter.check_rate_limit(client_id, "/api/stream/status")

        # This will fail initially
        assert first["allowed"] is True
        assert second["allowed"] is True

        # Exhaust only the pose endpoint's quota.
        for _ in range(4):
            rate_limiter.check_rate_limit(client_id, "/api/pose/current")

        # Pose endpoint is at its limit; the stream endpoint is unaffected.
        pose_verdict = rate_limiter.check_rate_limit(client_id, "/api/pose/current")
        stream_verdict = rate_limiter.check_rate_limit(client_id, "/api/stream/status")

        assert pose_verdict["allowed"] is False
        assert stream_verdict["allowed"] is True
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
    """Test rate limiting middleware functionality."""

    @pytest.fixture
    def rate_limiter(self):
        """Create rate limiter backing the middleware.

        NOTE(review): defined locally because pytest fixtures declared on
        another test class (TestRateLimitingBasic) are not visible here;
        without this fixture, ``rate_limit_middleware`` fails with
        "fixture 'rate_limiter' not found".  5 requests/minute matches
        the limits these tests assume.
        """
        return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)

    @pytest.fixture
    def mock_request(self):
        """Mock FastAPI request factory (call the returned class to build one)."""
        class MockRequest:
            def __init__(self, client_ip="127.0.0.1", path="/api/test", method="GET"):
                self.client = MagicMock()
                self.client.host = client_ip
                self.url = MagicMock()
                self.url.path = path
                self.method = method
                self.headers = {}
                self.state = MagicMock()

        return MockRequest

    @pytest.fixture
    def mock_response(self):
        """Mock FastAPI response with mutable status code and headers."""
        class MockResponse:
            def __init__(self):
                self.status_code = 200
                self.headers = {}

        return MockResponse()

    @pytest.fixture
    def rate_limit_middleware(self, rate_limiter):
        """Create rate limiting middleware wired to the local rate limiter."""
        class RateLimitMiddleware:
            def __init__(self, rate_limiter):
                self.rate_limiter = rate_limiter

            async def __call__(self, request, call_next):
                # Get client identifier
                client_id = self._get_client_id(request)
                endpoint = request.url.path

                # Check rate limit
                limit_result = self.rate_limiter.check_rate_limit(client_id, endpoint)

                if not limit_result["allowed"]:
                    # Short-circuit: return 429 without calling the app.
                    response = Response(
                        content=f"Rate limit exceeded: {limit_result['reason']}",
                        status_code=status.HTTP_429_TOO_MANY_REQUESTS
                    )
                    response.headers["Retry-After"] = str(limit_result["retry_after"])
                    response.headers["X-RateLimit-Limit"] = str(limit_result.get("limit", "unknown"))
                    response.headers["X-RateLimit-Remaining"] = "0"
                    return response

                # Process request
                response = await call_next(request)

                # Add rate limit headers so clients can self-throttle.
                response.headers["X-RateLimit-Limit"] = str(self.rate_limiter.requests_per_minute)
                response.headers["X-RateLimit-Remaining"] = str(limit_result.get("remaining_minute", 0))
                response.headers["X-RateLimit-Reset"] = str(int(limit_result.get("reset_time", datetime.utcnow()).timestamp()))

                return response

            def _get_client_id(self, request):
                """Get client identifier: API key, then auth'd user id, then IP."""
                # Check for API key in headers
                api_key = request.headers.get("X-API-Key")
                if api_key:
                    return f"api:{api_key}"

                # Check for user ID in request state (from auth)
                if hasattr(request.state, "user") and request.state.user:
                    return f"user:{request.state.user.get('id', 'unknown')}"

                # Fall back to IP address
                return f"ip:{request.client.host}"

        return RateLimitMiddleware(rate_limiter)

    @pytest.mark.asyncio
    async def test_middleware_allows_normal_requests_should_fail_initially(
        self, rate_limit_middleware, mock_request, mock_response
    ):
        """Test middleware allows normal requests - should fail initially."""
        request = mock_request()

        async def mock_call_next(req):
            return mock_response

        response = await rate_limit_middleware(request, mock_call_next)

        # This will fail initially
        assert response.status_code == 200
        assert "X-RateLimit-Limit" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert "X-RateLimit-Reset" in response.headers

    @pytest.mark.asyncio
    async def test_middleware_blocks_excessive_requests_should_fail_initially(
        self, rate_limit_middleware, mock_request
    ):
        """Test middleware blocks excessive requests - should fail initially."""
        request = mock_request()

        async def mock_call_next(req):
            response = Response(content="OK", status_code=200)
            return response

        # Make requests up to the limit
        for i in range(5):
            response = await rate_limit_middleware(request, mock_call_next)
            assert response.status_code == 200

        # Next request should be blocked
        response = await rate_limit_middleware(request, mock_call_next)

        # This will fail initially
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        assert "Retry-After" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert response.headers["X-RateLimit-Remaining"] == "0"

    @pytest.mark.asyncio
    async def test_middleware_client_identification_should_fail_initially(
        self, rate_limit_middleware, mock_request
    ):
        """Test middleware client identification - should fail initially."""
        # Test API key identification
        request_with_api_key = mock_request()
        request_with_api_key.headers["X-API-Key"] = "test-api-key-123"

        # Test user identification
        request_with_user = mock_request()
        request_with_user.state.user = {"id": "user-123"}

        # Test IP identification
        request_with_ip = mock_request(client_ip="192.168.1.100")

        async def mock_call_next(req):
            return Response(content="OK", status_code=200)

        # Each should be treated as different clients
        response1 = await rate_limit_middleware(request_with_api_key, mock_call_next)
        response2 = await rate_limit_middleware(request_with_user, mock_call_next)
        response3 = await rate_limit_middleware(request_with_ip, mock_call_next)

        # This will fail initially
        assert response1.status_code == 200
        assert response2.status_code == 200
        assert response3.status_code == 200
|
||||
|
||||
|
||||
class TestRateLimitingStrategies:
    """Test different rate limiting strategies."""

    @pytest.fixture
    def sliding_window_limiter(self):
        """Create sliding window rate limiter."""
        class SlidingWindowLimiter:
            def __init__(self, window_size_seconds: int = 60, max_requests: int = 10):
                self.window_size = window_size_seconds
                self.max_requests = max_requests
                self.request_times = {}

            def check_limit(self, client_id: str) -> Dict[str, Any]:
                now = time.time()
                timestamps = self.request_times.setdefault(client_id, [])

                # Forget requests that have slid out of the window.
                horizon = now - self.window_size
                timestamps = [t for t in timestamps if t > horizon]
                self.request_times[client_id] = timestamps

                # At capacity: caller must wait for the oldest entry to expire.
                if len(timestamps) >= self.max_requests:
                    oldest = min(timestamps)
                    wait = int(oldest + self.window_size - now)
                    return {
                        "allowed": False,
                        "retry_after": max(wait, 1),
                        "current_requests": len(timestamps),
                        "limit": self.max_requests,
                    }

                # Record this request
                timestamps.append(now)
                return {
                    "allowed": True,
                    "remaining": self.max_requests - len(timestamps),
                    "window_reset": int(now + self.window_size),
                }

        return SlidingWindowLimiter(window_size_seconds=10, max_requests=3)

    @pytest.fixture
    def token_bucket_limiter(self):
        """Create token bucket rate limiter."""
        class TokenBucketLimiter:
            def __init__(self, capacity: int = 10, refill_rate: float = 1.0):
                self.capacity = capacity
                self.refill_rate = refill_rate  # tokens per second
                self.buckets = {}

            def check_limit(self, client_id: str) -> Dict[str, Any]:
                now = time.time()
                bucket = self.buckets.setdefault(
                    client_id, {"tokens": self.capacity, "last_refill": now}
                )

                # Top the bucket up for the time elapsed since the last refill.
                elapsed = now - bucket["last_refill"]
                bucket["tokens"] = min(self.capacity, bucket["tokens"] + elapsed * self.refill_rate)
                bucket["last_refill"] = now

                # No whole token available: report how long until one refills.
                if bucket["tokens"] < 1:
                    deficit = 1 - bucket["tokens"]
                    return {
                        "allowed": False,
                        "retry_after": int(deficit / self.refill_rate),
                        "tokens_remaining": bucket["tokens"],
                    }

                # Consume a token
                bucket["tokens"] -= 1
                return {
                    "allowed": True,
                    "tokens_remaining": bucket["tokens"],
                }

        return TokenBucketLimiter(capacity=5, refill_rate=0.5)  # 0.5 tokens per second

    def test_sliding_window_limiter_should_fail_initially(self, sliding_window_limiter):
        """Test sliding window rate limiter - should fail initially."""
        client_id = "sliding-test-client"

        # Use up the whole window allowance.
        for _ in range(3):
            verdict = sliding_window_limiter.check_limit(client_id)

            # This will fail initially
            assert verdict["allowed"] is True
            assert "remaining" in verdict

        # The next request overflows the window.
        verdict = sliding_window_limiter.check_limit(client_id)
        assert verdict["allowed"] is False
        assert verdict["current_requests"] == 3
        assert verdict["limit"] == 3

    def test_token_bucket_limiter_should_fail_initially(self, token_bucket_limiter):
        """Test token bucket rate limiter - should fail initially."""
        client_id = "bucket-test-client"

        # Drain the bucket one token at a time.
        for _ in range(5):
            verdict = token_bucket_limiter.check_limit(client_id)

            # This will fail initially
            assert verdict["allowed"] is True
            assert "tokens_remaining" in verdict

        # Bucket is now (essentially) empty.
        verdict = token_bucket_limiter.check_limit(client_id)
        assert verdict["allowed"] is False
        assert verdict["tokens_remaining"] < 1

    @pytest.mark.asyncio
    async def test_token_bucket_refill_should_fail_initially(self, token_bucket_limiter):
        """Test token bucket refill mechanism - should fail initially."""
        client_id = "refill-test-client"

        # Drain every token.
        for _ in range(5):
            token_bucket_limiter.check_limit(client_id)

        # Confirm the bucket is exhausted.
        assert token_bucket_limiter.check_limit(client_id)["allowed"] is False

        # 0.5 tokens/sec * 2.1 sec > 1 token refilled.
        await asyncio.sleep(2.1)

        verdict = token_bucket_limiter.check_limit(client_id)

        # This will fail initially
        assert verdict["allowed"] is True
|
||||
|
||||
|
||||
class TestRateLimitingPerformance:
    """Test rate limiting performance characteristics."""

    @pytest.mark.asyncio
    async def test_concurrent_rate_limit_checks_should_fail_initially(self):
        """Test concurrent rate limit checks - should fail initially."""
        limiter = MockRateLimiter(requests_per_minute=100, requests_per_hour=1000)

        async def make_request(client_id: str, request_id: int):
            verdict = limiter.check_rate_limit(f"{client_id}-{request_id}")
            return verdict["allowed"]

        # Fan out 50 checks, each for a distinct client id.
        outcomes = await asyncio.gather(
            *(make_request("concurrent-client", i) for i in range(50))
        )

        # This will fail initially
        assert len(outcomes) == 50
        assert all(outcomes)  # All should be allowed since they're different clients

    @pytest.mark.asyncio
    async def test_rate_limiter_memory_cleanup_should_fail_initially(self):
        """Test rate limiter memory cleanup - should fail initially."""
        limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=100)

        # One request each for 100 distinct clients.
        for i in range(100):
            limiter.check_rate_limit(f"client-{i}")

        initial_memory_size = len(limiter.request_history)

        # Run the pruning pass over every known client.
        for client_key in list(limiter.request_history.keys()):
            limiter._cleanup_old_requests(client_key)

        # This will fail initially
        assert initial_memory_size == 100

        # After cleanup, old entries should be removed
        # (In a real implementation, this would clean up old timestamps)
        final_memory_size = sum(
            1 for history in limiter.request_history.values() if history
        )

        assert final_memory_size <= initial_memory_size
|
||||
729
v1/tests/integration/test_streaming_pipeline.py
Normal file
729
v1/tests/integration/test_streaming_pipeline.py
Normal file
@@ -0,0 +1,729 @@
|
||||
"""
|
||||
Integration tests for real-time streaming pipeline.
|
||||
|
||||
Tests the complete real-time data flow from CSI collection to client delivery.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class StreamFrame:
    """Streaming frame data structure.

    One processed unit of the real-time pipeline: the pose-estimation
    payload for a single CSI frame plus its provenance and quality data.
    """
    frame_id: str                 # unique identifier for this frame
    timestamp: datetime           # when the frame was captured
    router_id: str                # source router that produced the CSI data
    pose_data: Dict[str, Any]     # pose-estimation payload (persons, zones, ...)
    processing_time_ms: float     # pipeline latency for this frame
    quality_score: float          # signal/estimation quality metric
|
||||
|
||||
|
||||
class MockStreamBuffer:
    """Bounded asyncio-backed frame buffer used in streaming tests.

    Frames that arrive while the queue is full are counted as dropped
    rather than blocking the producer.
    """

    def __init__(self, max_size: int = 100):
        self.max_size = max_size
        self.buffer = asyncio.Queue(maxsize=max_size)
        self.dropped_frames = 0
        self.total_frames = 0

    async def put_frame(self, frame: "StreamFrame") -> bool:
        """Enqueue *frame*; return False (and count a drop) when full."""
        self.total_frames += 1
        try:
            self.buffer.put_nowait(frame)
        except asyncio.QueueFull:
            self.dropped_frames += 1
            return False
        return True

    async def get_frame(self, timeout: float = 1.0) -> Optional["StreamFrame"]:
        """Dequeue the next frame, or None after *timeout* seconds."""
        try:
            return await asyncio.wait_for(self.buffer.get(), timeout=timeout)
        except asyncio.TimeoutError:
            return None

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of buffer occupancy and drop statistics."""
        frames_seen = max(self.total_frames, 1)  # avoid div-by-zero before first put
        return {
            "buffer_size": self.buffer.qsize(),
            "max_size": self.max_size,
            "total_frames": self.total_frames,
            "dropped_frames": self.dropped_frames,
            "drop_rate": self.dropped_frames / frames_seen,
        }
|
||||
|
||||
|
||||
class MockStreamProcessor:
    """Pulls frames from an input buffer, processes them, pushes to output."""

    def __init__(self):
        self.is_running = False
        self.processing_rate = 30  # FPS
        self.frame_counter = 0
        self.error_rate = 0.0

    async def start_processing(self, input_buffer: "MockStreamBuffer", output_buffer: "MockStreamBuffer"):
        """Run the processing loop until stop_processing() is called."""
        self.is_running = True

        while self.is_running:
            try:
                incoming = await input_buffer.get_frame(timeout=0.1)
                if incoming is None:
                    continue

                # Randomly skip frames to simulate processing errors.
                if np.random.random() < self.error_rate:
                    continue

                outgoing = await self._process_frame(incoming)
                await output_buffer.put_frame(outgoing)

                # Pace the loop to the configured frame rate.
                await asyncio.sleep(1.0 / self.processing_rate)

            except Exception:
                # Swallow processing errors and keep the loop alive.
                continue

    async def _process_frame(self, frame: "StreamFrame") -> "StreamFrame":
        """Return a processed copy of *frame* with added metadata."""
        await asyncio.sleep(0.01)  # simulated processing latency

        enriched = frame.pose_data.copy()
        enriched["processed_at"] = datetime.utcnow().isoformat()
        enriched["processor_id"] = "stream_processor_001"

        return StreamFrame(
            frame_id=f"processed_{frame.frame_id}",
            timestamp=frame.timestamp,
            router_id=frame.router_id,
            pose_data=enriched,
            processing_time_ms=frame.processing_time_ms + 10,  # processing overhead
            quality_score=frame.quality_score * 0.95,  # slight quality degradation
        )

    def stop_processing(self):
        """Signal the processing loop to exit."""
        self.is_running = False

    def set_error_rate(self, error_rate: float):
        """Set the simulated per-frame error probability."""
        self.error_rate = error_rate
|
||||
|
||||
|
||||
class MockWebSocketManager:
    """Tracks mock WebSocket clients and broadcasts pose frames to them."""

    def __init__(self):
        self.connected_clients = {}
        self.message_queue = asyncio.Queue()
        self.total_messages_sent = 0
        self.failed_sends = 0

    async def add_client(self, client_id: str, websocket_mock) -> bool:
        """Register a client; False if the id is already connected."""
        if client_id in self.connected_clients:
            return False

        self.connected_clients[client_id] = {
            "websocket": websocket_mock,
            "connected_at": datetime.utcnow(),
            "messages_sent": 0,
            "last_ping": datetime.utcnow(),
        }
        return True

    async def remove_client(self, client_id: str) -> bool:
        """Drop a client; False if it was not connected."""
        if client_id not in self.connected_clients:
            return False
        del self.connected_clients[client_id]
        return True

    async def broadcast_frame(self, frame: "StreamFrame") -> Dict[str, bool]:
        """Send *frame* to every client; returns per-client success flags."""
        message = {
            "type": "pose_update",
            "frame_id": frame.frame_id,
            "timestamp": frame.timestamp.isoformat(),
            "router_id": frame.router_id,
            "pose_data": frame.pose_data,
            "processing_time_ms": frame.processing_time_ms,
            "quality_score": frame.quality_score,
        }

        delivery = {}
        for client_id, client_info in self.connected_clients.items():
            try:
                # Simulate WebSocket send
                sent = await self._send_to_client(client_id, message)
            except Exception:
                delivery[client_id] = False
                self.failed_sends += 1
                continue

            delivery[client_id] = sent
            if sent:
                client_info["messages_sent"] += 1
                self.total_messages_sent += 1
            else:
                self.failed_sends += 1

        return delivery

    async def _send_to_client(self, client_id: str, message: Dict[str, Any]) -> bool:
        """Simulate a send with a 5% random failure rate and small delay."""
        if np.random.random() < 0.05:  # 5% failure rate
            return False

        await asyncio.sleep(0.001)  # simulated network latency
        return True

    def get_client_stats(self) -> Dict[str, Any]:
        """Aggregate delivery counters plus per-client details."""
        per_client = {
            client_id: {
                "messages_sent": info["messages_sent"],
                "connected_duration": (datetime.utcnow() - info["connected_at"]).total_seconds(),
            }
            for client_id, info in self.connected_clients.items()
        }
        return {
            "connected_clients": len(self.connected_clients),
            "total_messages_sent": self.total_messages_sent,
            "failed_sends": self.failed_sends,
            "clients": per_client,
        }
|
||||
|
||||
|
||||
class TestStreamingPipelineBasic:
    """Test basic streaming pipeline functionality.

    Exercises the three building blocks in isolation: the frame buffer
    (put/get + overflow accounting), the processor (input buffer ->
    output buffer), and the WebSocket manager (client registry +
    broadcast).  The "should_fail_initially" names follow the project's
    TDD convention: the tests were written before the implementations.
    """

    @pytest.fixture
    def stream_buffer(self):
        """Create stream buffer."""
        return MockStreamBuffer(max_size=50)

    @pytest.fixture
    def stream_processor(self):
        """Create stream processor."""
        return MockStreamProcessor()

    @pytest.fixture
    def websocket_manager(self):
        """Create WebSocket manager."""
        return MockWebSocketManager()

    @pytest.fixture
    def sample_frame(self):
        """Create sample stream frame with one detected person.

        NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        consider datetime.now(timezone.utc) project-wide.
        """
        return StreamFrame(
            frame_id="frame_001",
            timestamp=datetime.utcnow(),
            router_id="router_001",
            pose_data={
                "persons": [
                    {
                        "person_id": "person_1",
                        "confidence": 0.85,
                        "bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
                        "activity": "standing"
                    }
                ],
                "zone_summary": {"zone1": 1, "zone2": 0}
            },
            processing_time_ms=45.2,
            quality_score=0.92
        )

    @pytest.mark.asyncio
    async def test_buffer_frame_operations_should_fail_initially(self, stream_buffer, sample_frame):
        """Test buffer frame operations - should fail initially."""
        # Put frame in buffer
        result = await stream_buffer.put_frame(sample_frame)

        # This will fail initially
        assert result is True

        # Get frame from buffer; FIFO order is implied by the later
        # overflow test, where new frames are dropped rather than evicting.
        retrieved_frame = await stream_buffer.get_frame()
        assert retrieved_frame is not None
        assert retrieved_frame.frame_id == sample_frame.frame_id
        assert retrieved_frame.router_id == sample_frame.router_id

        # Buffer should be empty now; get_frame with a timeout returns
        # None on expiry rather than raising.
        empty_frame = await stream_buffer.get_frame(timeout=0.1)
        assert empty_frame is None

    @pytest.mark.asyncio
    async def test_buffer_overflow_handling_should_fail_initially(self, sample_frame):
        """Test buffer overflow handling - should fail initially."""
        small_buffer = MockStreamBuffer(max_size=2)

        # Fill buffer to capacity
        result1 = await small_buffer.put_frame(sample_frame)
        result2 = await small_buffer.put_frame(sample_frame)

        # This will fail initially
        assert result1 is True
        assert result2 is True

        # Next frame should be dropped (put returns False instead of
        # blocking or evicting old frames).
        result3 = await small_buffer.put_frame(sample_frame)
        assert result3 is False

        # Check statistics: total_frames counts attempts, including drops.
        stats = small_buffer.get_stats()
        assert stats["total_frames"] == 3
        assert stats["dropped_frames"] == 1
        assert stats["drop_rate"] > 0

    @pytest.mark.asyncio
    async def test_stream_processing_should_fail_initially(self, stream_processor, sample_frame):
        """Test stream processing - should fail initially."""
        input_buffer = MockStreamBuffer()
        output_buffer = MockStreamBuffer()

        # Add frame to input buffer
        await input_buffer.put_frame(sample_frame)

        # Start processing task
        processing_task = asyncio.create_task(
            stream_processor.start_processing(input_buffer, output_buffer)
        )

        # Wait for processing (0.2s is a soft deadline; may be flaky on a
        # heavily loaded CI machine).
        await asyncio.sleep(0.2)

        # Stop processing; stop_processing() must make start_processing
        # return so this await terminates.
        stream_processor.stop_processing()
        await processing_task

        # Check output
        processed_frame = await output_buffer.get_frame(timeout=0.1)

        # This will fail initially
        assert processed_frame is not None
        assert processed_frame.frame_id.startswith("processed_")
        assert "processed_at" in processed_frame.pose_data
        assert processed_frame.processing_time_ms > sample_frame.processing_time_ms

    @pytest.mark.asyncio
    async def test_websocket_client_management_should_fail_initially(self, websocket_manager):
        """Test WebSocket client management - should fail initially."""
        mock_websocket = MagicMock()

        # Add client
        result = await websocket_manager.add_client("client_001", mock_websocket)

        # This will fail initially
        assert result is True
        assert "client_001" in websocket_manager.connected_clients

        # Try to add duplicate client: rejected, not overwritten.
        result = await websocket_manager.add_client("client_001", mock_websocket)
        assert result is False

        # Remove client
        result = await websocket_manager.remove_client("client_001")
        assert result is True
        assert "client_001" not in websocket_manager.connected_clients

    @pytest.mark.asyncio
    async def test_frame_broadcasting_should_fail_initially(self, websocket_manager, sample_frame):
        """Test frame broadcasting - should fail initially."""
        # Add multiple clients
        for i in range(3):
            await websocket_manager.add_client(f"client_{i:03d}", MagicMock())

        # Broadcast frame; returns a per-client {client_id: success} map.
        results = await websocket_manager.broadcast_frame(sample_frame)

        # This will fail initially
        assert len(results) == 3
        assert all(isinstance(success, bool) for success in results.values())

        # Check statistics
        stats = websocket_manager.get_client_stats()
        assert stats["connected_clients"] == 3
        assert stats["total_messages_sent"] >= 0
||||
class TestStreamingPipelineIntegration:
    """Test complete streaming pipeline integration.

    Wires input buffer -> processor -> output buffer -> WebSocket
    broadcast through a small in-test ``StreamingPipeline`` facade and
    drives it end to end (happy path, throughput, error injection).
    """

    @pytest.fixture
    async def streaming_pipeline(self):
        """Create complete streaming pipeline.

        NOTE(review): an async fixture under plain ``@pytest.fixture``
        requires pytest-asyncio's fixture handling (e.g. auto mode) —
        confirm the project's pytest-asyncio configuration.
        """
        class StreamingPipeline:
            """Minimal pipeline: buffers + processor + broadcaster tasks."""

            def __init__(self):
                self.input_buffer = MockStreamBuffer(max_size=100)
                self.output_buffer = MockStreamBuffer(max_size=100)
                self.processor = MockStreamProcessor()
                self.websocket_manager = MockWebSocketManager()
                self.is_running = False
                self.processing_task = None
                self.broadcasting_task = None

            async def start(self):
                """Start the streaming pipeline.

                Returns False if it is already running, True otherwise.
                """
                if self.is_running:
                    return False

                self.is_running = True

                # Start processing task (drains input_buffer into output_buffer)
                self.processing_task = asyncio.create_task(
                    self.processor.start_processing(self.input_buffer, self.output_buffer)
                )

                # Start broadcasting task (drains output_buffer to WebSocket clients)
                self.broadcasting_task = asyncio.create_task(
                    self._broadcast_loop()
                )

                return True

            async def stop(self):
                """Stop the streaming pipeline and wait for its tasks.

                Returns False if the pipeline was not running.
                """
                if not self.is_running:
                    return False

                self.is_running = False
                self.processor.stop_processing()

                # BUGFIX: previously the tasks were cancelled but never
                # awaited, leaking pending tasks into the next test and
                # triggering "Task was destroyed but it is pending"
                # warnings. Await each cancelled task and swallow the
                # expected CancelledError.
                for task in (self.processing_task, self.broadcasting_task):
                    if task:
                        task.cancel()
                        try:
                            await task
                        except asyncio.CancelledError:
                            pass

                return True

            async def add_frame(self, frame: StreamFrame) -> bool:
                """Add frame to pipeline; False if the input buffer is full."""
                return await self.input_buffer.put_frame(frame)

            async def add_client(self, client_id: str, websocket_mock) -> bool:
                """Add WebSocket client."""
                return await self.websocket_manager.add_client(client_id, websocket_mock)

            async def _broadcast_loop(self):
                """Broadcasting loop: forward processed frames to all clients.

                get_frame returns None on timeout, so the TimeoutError
                branch is defensive; the broad Exception catch keeps the
                loop alive through per-frame failures.
                """
                while self.is_running:
                    try:
                        frame = await self.output_buffer.get_frame(timeout=0.1)
                        if frame:
                            await self.websocket_manager.broadcast_frame(frame)
                    except asyncio.TimeoutError:
                        continue
                    except Exception:
                        continue

            def get_pipeline_stats(self) -> Dict[str, Any]:
                """Get pipeline statistics (buffers + client registry)."""
                return {
                    "is_running": self.is_running,
                    "input_buffer": self.input_buffer.get_stats(),
                    "output_buffer": self.output_buffer.get_stats(),
                    "websocket_clients": self.websocket_manager.get_client_stats()
                }

        return StreamingPipeline()

    @pytest.mark.asyncio
    async def test_end_to_end_streaming_should_fail_initially(self, streaming_pipeline):
        """Test end-to-end streaming - should fail initially."""
        # Start pipeline
        result = await streaming_pipeline.start()

        # This will fail initially
        assert result is True
        assert streaming_pipeline.is_running is True

        # Add clients
        for i in range(2):
            await streaming_pipeline.add_client(f"client_{i}", MagicMock())

        # Add frames
        for i in range(5):
            frame = StreamFrame(
                frame_id=f"frame_{i:03d}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={"persons": [], "zone_summary": {}},
                processing_time_ms=30.0,
                quality_score=0.9
            )
            await streaming_pipeline.add_frame(frame)

        # Wait for processing
        await asyncio.sleep(0.5)

        # Stop pipeline
        await streaming_pipeline.stop()

        # Check statistics
        stats = streaming_pipeline.get_pipeline_stats()
        assert stats["input_buffer"]["total_frames"] == 5
        assert stats["websocket_clients"]["connected_clients"] == 2

    @pytest.mark.asyncio
    async def test_pipeline_performance_should_fail_initially(self, streaming_pipeline):
        """Test pipeline performance - should fail initially."""
        await streaming_pipeline.start()

        # Add multiple clients
        for i in range(10):
            await streaming_pipeline.add_client(f"client_{i:03d}", MagicMock())

        # Measure throughput (wall-clock based; includes the 2s drain sleep)
        start_time = datetime.utcnow()
        frame_count = 50

        for i in range(frame_count):
            frame = StreamFrame(
                frame_id=f"perf_frame_{i:03d}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={"persons": [], "zone_summary": {}},
                processing_time_ms=25.0,
                quality_score=0.88
            )
            await streaming_pipeline.add_frame(frame)

        # Wait for processing
        await asyncio.sleep(2.0)

        end_time = datetime.utcnow()
        duration = (end_time - start_time).total_seconds()

        await streaming_pipeline.stop()

        # This will fail initially
        # Check performance metrics
        stats = streaming_pipeline.get_pipeline_stats()
        throughput = frame_count / duration

        assert throughput > 10  # Should process at least 10 FPS
        assert stats["input_buffer"]["drop_rate"] < 0.1  # Less than 10% drop rate

    @pytest.mark.asyncio
    async def test_pipeline_error_recovery_should_fail_initially(self, streaming_pipeline):
        """Test pipeline error recovery - should fail initially."""
        await streaming_pipeline.start()

        # Set high error rate so roughly half the frames fail in the processor
        streaming_pipeline.processor.set_error_rate(0.5)  # 50% error rate

        # Add frames
        for i in range(20):
            frame = StreamFrame(
                frame_id=f"error_frame_{i:03d}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={"persons": [], "zone_summary": {}},
                processing_time_ms=30.0,
                quality_score=0.9
            )
            await streaming_pipeline.add_frame(frame)

        # Wait for processing
        await asyncio.sleep(1.0)

        await streaming_pipeline.stop()

        # This will fail initially
        # Pipeline should continue running despite errors
        stats = streaming_pipeline.get_pipeline_stats()
        assert stats["input_buffer"]["total_frames"] == 20
        # Some frames should be processed despite errors
        assert stats["output_buffer"]["total_frames"] > 0
||||
class TestStreamingLatency:
    """Test streaming latency characteristics.

    Measures simulated per-frame latency and verifies that independent
    streams can be driven concurrently without interfering.
    """

    @pytest.mark.asyncio
    async def test_end_to_end_latency_should_fail_initially(self):
        """Test end-to-end latency - should fail initially."""
        class LatencyTracker:
            # Collects one latency sample (ms) per measured frame.
            def __init__(self):
                self.latencies = []

            async def measure_latency(self, frame: StreamFrame) -> float:
                """Measure processing latency."""
                start_time = datetime.utcnow()

                # Simulate processing pipeline
                await asyncio.sleep(0.05)  # 50ms processing time

                end_time = datetime.utcnow()
                latency = (end_time - start_time).total_seconds() * 1000  # Convert to ms

                self.latencies.append(latency)
                return latency

        tracker = LatencyTracker()

        # Measure latency for multiple frames
        for i in range(10):
            frame = StreamFrame(
                frame_id=f"latency_frame_{i}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={},
                processing_time_ms=0,
                quality_score=1.0
            )

            latency = await tracker.measure_latency(frame)

            # This will fail initially
            assert latency > 0
            assert latency < 200  # Should be less than 200ms

        # Check average latency (the 100ms bound leaves 2x headroom over
        # the simulated 50ms sleep for scheduler jitter).
        avg_latency = sum(tracker.latencies) / len(tracker.latencies)
        assert avg_latency < 100  # Average should be less than 100ms

    @pytest.mark.asyncio
    async def test_concurrent_stream_handling_should_fail_initially(self):
        """Test concurrent stream handling - should fail initially."""
        async def process_stream(stream_id: str, frame_count: int) -> Dict[str, Any]:
            """Process a single stream.

            Each stream gets its own buffer, so with the default buffer
            size no frames should be dropped.
            """
            buffer = MockStreamBuffer()
            processed_frames = 0

            for i in range(frame_count):
                frame = StreamFrame(
                    frame_id=f"{stream_id}_frame_{i}",
                    timestamp=datetime.utcnow(),
                    router_id=stream_id,
                    pose_data={},
                    processing_time_ms=20.0,
                    quality_score=0.9
                )

                success = await buffer.put_frame(frame)
                if success:
                    processed_frames += 1

                await asyncio.sleep(0.01)  # Simulate frame rate

            return {
                "stream_id": stream_id,
                "processed_frames": processed_frames,
                "total_frames": frame_count
            }

        # Process multiple streams concurrently
        streams = ["router_001", "router_002", "router_003"]
        tasks = [process_stream(stream_id, 20) for stream_id in streams]

        results = await asyncio.gather(*tasks)

        # This will fail initially
        assert len(results) == 3

        for result in results:
            assert result["processed_frames"] == result["total_frames"]
            assert result["stream_id"] in streams
||||
class TestStreamingResilience:
    """Test streaming pipeline resilience.

    Covers graceful handling of client churn (disconnects mid-stream)
    and of buffer back-pressure under memory pressure.
    """

    @pytest.mark.asyncio
    async def test_client_disconnection_handling_should_fail_initially(self):
        """Test client disconnection handling - should fail initially."""
        websocket_manager = MockWebSocketManager()

        # Add clients
        client_ids = [f"client_{i:03d}" for i in range(5)]
        for client_id in client_ids:
            await websocket_manager.add_client(client_id, MagicMock())

        # Simulate frame broadcasting
        frame = StreamFrame(
            frame_id="disconnect_test_frame",
            timestamp=datetime.utcnow(),
            router_id="router_001",
            pose_data={},
            processing_time_ms=30.0,
            quality_score=0.9
        )

        # Broadcast to all clients
        results = await websocket_manager.broadcast_frame(frame)

        # This will fail initially
        assert len(results) == 5

        # Simulate client disconnections
        await websocket_manager.remove_client("client_001")
        await websocket_manager.remove_client("client_003")

        # Broadcast again: disconnected clients must not appear in results.
        results = await websocket_manager.broadcast_frame(frame)
        assert len(results) == 3  # Only remaining clients

        # Check statistics
        stats = websocket_manager.get_client_stats()
        assert stats["connected_clients"] == 3

    @pytest.mark.asyncio
    async def test_memory_pressure_handling_should_fail_initially(self):
        """Test memory pressure handling - should fail initially."""
        # Create small buffers to simulate memory pressure
        small_buffer = MockStreamBuffer(max_size=5)

        # Generate many frames quickly (20 frames into a 5-slot buffer,
        # with no consumer draining it).
        frames_generated = 0
        frames_accepted = 0

        for i in range(20):
            frame = StreamFrame(
                frame_id=f"memory_pressure_frame_{i}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={},
                processing_time_ms=25.0,
                quality_score=0.85
            )

            frames_generated += 1
            success = await small_buffer.put_frame(frame)
            if success:
                frames_accepted += 1

        # This will fail initially
        # Buffer should handle memory pressure gracefully: count every
        # attempt, drop the overflow, never exceed capacity.
        stats = small_buffer.get_stats()
        assert stats["total_frames"] == frames_generated
        assert stats["dropped_frames"] > 0  # Some frames should be dropped
        assert frames_accepted <= small_buffer.max_size

        # Drop rate should be reasonable (15 of 20 dropped = 0.75)
        assert stats["drop_rate"] > 0.5  # More than 50% dropped due to small buffer
419
v1/tests/integration/test_websocket_streaming.py
Normal file
419
v1/tests/integration/test_websocket_streaming.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Integration tests for WebSocket streaming functionality.
|
||||
|
||||
Tests WebSocket connections, message handling, and real-time data streaming.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import websockets
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class MockWebSocket:
    """In-memory stand-in for a FastAPI WebSocket connection.

    Outbound traffic (``send_json``/``send_text``) is appended to
    ``messages_sent`` in call order; inbound traffic is served FIFO from
    ``messages_received``, which tests pre-load via
    ``add_received_message``. Draining the inbound queue raises
    ``WebSocketDisconnect`` to model the peer going away.
    """

    def __init__(self):
        # Captured outbound payloads (dicts from send_json, str from send_text).
        self.messages_sent = []
        # Queued inbound text frames awaiting receive_text().
        self.messages_received = []
        self.closed = False
        self.accept_called = False

    async def accept(self):
        """Record that the handshake was accepted."""
        self.accept_called = True

    async def send_json(self, data: Dict[str, Any]):
        """Capture a JSON payload instead of sending it over the wire."""
        self.messages_sent.append(data)

    async def send_text(self, text: str):
        """Capture a text frame instead of sending it over the wire."""
        self.messages_sent.append(text)

    async def receive_text(self) -> str:
        """Return the oldest queued message; raise disconnect when empty."""
        try:
            return self.messages_received.pop(0)
        except IndexError:
            # Empty queue simulates a WebSocket disconnect.
            from fastapi import WebSocketDisconnect
            raise WebSocketDisconnect()

    async def close(self):
        """Mark the connection as closed."""
        self.closed = True

    def add_received_message(self, message: str):
        """Queue *message* for a later receive_text() call."""
        self.messages_received.append(message)
||||
class TestWebSocketStreaming:
    """Integration tests for WebSocket streaming.

    Drives in-test handler coroutines (pose stream, message dispatch,
    events stream, disconnect) against MockWebSocket and an AsyncMock
    connection manager, asserting the handshake/response protocol.
    """

    @pytest.fixture
    def mock_websocket(self):
        """Create mock WebSocket."""
        return MockWebSocket()

    @pytest.fixture
    def mock_connection_manager(self):
        """Mock connection manager with canned async return values."""
        manager = AsyncMock()
        manager.connect.return_value = "client-001"
        manager.disconnect.return_value = True
        manager.get_connection_stats.return_value = {
            "total_clients": 1,
            "active_streams": ["pose"]
        }
        manager.broadcast.return_value = 1
        return manager

    @pytest.fixture
    def mock_stream_service(self):
        """Mock stream service.

        NOTE(review): not referenced by any test in this class as
        visible here — presumably kept for future tests; confirm.
        """
        service = AsyncMock()
        service.get_status.return_value = {
            "is_active": True,
            "active_streams": [],
            "uptime_seconds": 3600.0
        }
        service.is_active.return_value = True
        service.start.return_value = None
        service.stop.return_value = None
        return service

    @pytest.mark.asyncio
    async def test_websocket_pose_connection_should_fail_initially(self, mock_websocket, mock_connection_manager):
        """Test WebSocket pose connection establishment - should fail initially."""
        # This test should fail because we haven't implemented the WebSocket handler properly

        # Simulate WebSocket connection
        zone_ids = "zone1,zone2"
        min_confidence = 0.7
        max_fps = 30

        # Mock the websocket_pose_stream function
        async def mock_websocket_handler(websocket, zone_ids, min_confidence, max_fps):
            await websocket.accept()

            # Parse zone IDs from the comma-separated query-style string,
            # dropping empty segments.
            zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]

            # Register client
            client_id = await mock_connection_manager.connect(
                websocket=websocket,
                stream_type="pose",
                zone_ids=zone_list,
                min_confidence=min_confidence,
                max_fps=max_fps
            )

            # Send confirmation
            await websocket.send_json({
                "type": "connection_established",
                "client_id": client_id,
                "timestamp": datetime.utcnow().isoformat(),
                "config": {
                    "zone_ids": zone_list,
                    "min_confidence": min_confidence,
                    "max_fps": max_fps
                }
            })

            return client_id

        # Execute the handler
        client_id = await mock_websocket_handler(mock_websocket, zone_ids, min_confidence, max_fps)

        # This assertion will fail initially, driving us to implement the WebSocket handler
        assert mock_websocket.accept_called
        assert len(mock_websocket.messages_sent) == 1
        assert mock_websocket.messages_sent[0]["type"] == "connection_established"
        assert mock_websocket.messages_sent[0]["client_id"] == "client-001"
        assert "config" in mock_websocket.messages_sent[0]

    @pytest.mark.asyncio
    async def test_websocket_message_handling_should_fail_initially(self, mock_websocket):
        """Test WebSocket message handling - should fail initially."""
        # Mock message handler: dispatch on the "type" field, reply with
        # pong / config_updated / error respectively.
        async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket):
            message_type = data.get("type")

            if message_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": datetime.utcnow().isoformat()
                })
            elif message_type == "update_config":
                config = data.get("config", {})
                await websocket.send_json({
                    "type": "config_updated",
                    "timestamp": datetime.utcnow().isoformat(),
                    "config": config
                })
            else:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Unknown message type: {message_type}"
                })

        # Test ping message
        ping_data = {"type": "ping"}
        await handle_websocket_message("client-001", ping_data, mock_websocket)

        # This will fail initially
        assert len(mock_websocket.messages_sent) == 1
        assert mock_websocket.messages_sent[0]["type"] == "pong"

        # Test config update
        mock_websocket.messages_sent.clear()
        config_data = {
            "type": "update_config",
            "config": {"min_confidence": 0.8, "max_fps": 15}
        }
        await handle_websocket_message("client-001", config_data, mock_websocket)

        # This will fail initially
        assert len(mock_websocket.messages_sent) == 1
        assert mock_websocket.messages_sent[0]["type"] == "config_updated"
        assert mock_websocket.messages_sent[0]["config"]["min_confidence"] == 0.8

    @pytest.mark.asyncio
    async def test_websocket_events_stream_should_fail_initially(self, mock_websocket, mock_connection_manager):
        """Test WebSocket events stream - should fail initially."""
        # Mock events stream handler
        async def mock_events_handler(websocket, event_types, zone_ids):
            await websocket.accept()

            # Parse parameters; None (rather than []) when the filter
            # string is absent, meaning "no filtering".
            event_list = [event.strip() for event in event_types.split(",") if event.strip()] if event_types else None
            zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()] if zone_ids else None

            # Register client
            client_id = await mock_connection_manager.connect(
                websocket=websocket,
                stream_type="events",
                zone_ids=zone_list,
                event_types=event_list
            )

            # Send confirmation
            await websocket.send_json({
                "type": "connection_established",
                "client_id": client_id,
                "timestamp": datetime.utcnow().isoformat(),
                "config": {
                    "event_types": event_list,
                    "zone_ids": zone_list
                }
            })

            return client_id

        # Execute handler
        client_id = await mock_events_handler(mock_websocket, "fall_detection,intrusion", "zone1")

        # This will fail initially
        assert mock_websocket.accept_called
        assert len(mock_websocket.messages_sent) == 1
        assert mock_websocket.messages_sent[0]["type"] == "connection_established"
        assert mock_websocket.messages_sent[0]["config"]["event_types"] == ["fall_detection", "intrusion"]

    @pytest.mark.asyncio
    async def test_websocket_disconnect_handling_should_fail_initially(self, mock_websocket, mock_connection_manager):
        """Test WebSocket disconnect handling - should fail initially."""
        # Mock disconnect scenario
        client_id = "client-001"

        # Simulate disconnect
        disconnect_result = await mock_connection_manager.disconnect(client_id)

        # This will fail initially
        assert disconnect_result is True
        mock_connection_manager.disconnect.assert_called_once_with(client_id)
||||
class TestWebSocketConnectionManager:
    """Test WebSocket connection management.

    Uses a self-contained MockConnectionManager (built in the fixture)
    to verify connect/disconnect bookkeeping and filtered broadcasting.
    """

    @pytest.fixture
    def mock_websocket(self):
        """Create mock WebSocket.

        BUGFIX: this fixture previously existed only inside
        TestWebSocketStreaming. pytest fixtures defined in a test class
        are scoped to that class, so the tests below that request
        ``mock_websocket`` would error with "fixture not found" (unless
        a conftest.py also defines it — confirm). Defining it here makes
        the class self-sufficient.
        """
        return MockWebSocket()

    @pytest.fixture
    def connection_manager(self):
        """Create connection manager for testing."""
        # Mock connection manager implementation
        class MockConnectionManager:
            def __init__(self):
                # client_id -> connection metadata
                self.connections = {}
                # Monotonic counter used to mint client-NNN ids.
                self.client_counter = 0

            async def connect(self, websocket, stream_type, zone_ids=None, **kwargs):
                self.client_counter += 1
                client_id = f"client-{self.client_counter:03d}"
                self.connections[client_id] = {
                    "websocket": websocket,
                    "stream_type": stream_type,
                    "zone_ids": zone_ids or [],
                    "connected_at": datetime.utcnow(),
                    **kwargs
                }
                return client_id

            async def disconnect(self, client_id):
                if client_id in self.connections:
                    del self.connections[client_id]
                    return True
                return False

            async def get_connected_clients(self):
                return list(self.connections.keys())

            async def get_connection_stats(self):
                return {
                    "total_clients": len(self.connections),
                    "active_streams": list(set(conn["stream_type"] for conn in self.connections.values()))
                }

            async def broadcast(self, data, stream_type=None, zone_ids=None):
                # Count how many registered clients match the optional
                # stream-type and zone filters; sending itself is mocked.
                sent_count = 0
                for client_id, conn in self.connections.items():
                    if stream_type and conn["stream_type"] != stream_type:
                        continue
                    if zone_ids and not any(zone in conn["zone_ids"] for zone in zone_ids):
                        continue

                    # Mock sending data
                    sent_count += 1

                return sent_count

        return MockConnectionManager()

    @pytest.mark.asyncio
    async def test_connection_manager_connect_should_fail_initially(self, connection_manager, mock_websocket):
        """Test connection manager connect functionality - should fail initially."""
        client_id = await connection_manager.connect(
            websocket=mock_websocket,
            stream_type="pose",
            zone_ids=["zone1", "zone2"],
            min_confidence=0.7
        )

        # This will fail initially
        assert client_id == "client-001"
        assert client_id in connection_manager.connections
        assert connection_manager.connections[client_id]["stream_type"] == "pose"
        assert connection_manager.connections[client_id]["zone_ids"] == ["zone1", "zone2"]

    @pytest.mark.asyncio
    async def test_connection_manager_disconnect_should_fail_initially(self, connection_manager, mock_websocket):
        """Test connection manager disconnect functionality - should fail initially."""
        # Connect first
        client_id = await connection_manager.connect(
            websocket=mock_websocket,
            stream_type="pose"
        )

        # Disconnect
        result = await connection_manager.disconnect(client_id)

        # This will fail initially
        assert result is True
        assert client_id not in connection_manager.connections

    @pytest.mark.asyncio
    async def test_connection_manager_broadcast_should_fail_initially(self, connection_manager):
        """Test connection manager broadcast functionality - should fail initially."""
        # Connect multiple clients
        ws1 = MockWebSocket()
        ws2 = MockWebSocket()

        client1 = await connection_manager.connect(ws1, "pose", zone_ids=["zone1"])
        client2 = await connection_manager.connect(ws2, "events", zone_ids=["zone2"])

        # Broadcast to pose stream
        sent_count = await connection_manager.broadcast(
            data={"type": "pose_data", "data": {}},
            stream_type="pose"
        )

        # This will fail initially
        assert sent_count == 1

        # Broadcast to specific zone
        sent_count = await connection_manager.broadcast(
            data={"type": "zone_event", "data": {}},
            zone_ids=["zone1"]
        )

        # This will fail initially
        assert sent_count == 1
||||
class TestWebSocketPerformance:
    """Test WebSocket performance characteristics.

    Timing assertions are against in-memory MockWebSocket operations,
    so they check test-harness overhead rather than real network
    performance; thresholds are deliberately loose.
    """

    @pytest.mark.asyncio
    async def test_multiple_concurrent_connections_should_fail_initially(self):
        """Test handling multiple concurrent WebSocket connections - should fail initially."""
        # Mock multiple connections
        connection_count = 10
        connections = []

        for i in range(connection_count):
            mock_ws = MockWebSocket()
            connections.append(mock_ws)

        # Simulate concurrent connections
        async def simulate_connection(websocket, client_id):
            await websocket.accept()
            await websocket.send_json({
                "type": "connection_established",
                "client_id": client_id
            })
            return True

        # Execute concurrent connections via gather (results keep input order)
        tasks = [
            simulate_connection(ws, f"client-{i:03d}")
            for i, ws in enumerate(connections)
        ]

        results = await asyncio.gather(*tasks)

        # This will fail initially
        assert len(results) == connection_count
        assert all(results)
        assert all(ws.accept_called for ws in connections)

    @pytest.mark.asyncio
    async def test_websocket_message_throughput_should_fail_initially(self):
        """Test WebSocket message throughput - should fail initially."""
        mock_ws = MockWebSocket()
        message_count = 100

        # Simulate high-frequency message sending
        start_time = datetime.utcnow()

        for i in range(message_count):
            await mock_ws.send_json({
                "type": "pose_data",
                "frame_id": f"frame-{i:04d}",
                "timestamp": datetime.utcnow().isoformat()
            })

        end_time = datetime.utcnow()
        duration = (end_time - start_time).total_seconds()

        # This will fail initially
        assert len(mock_ws.messages_sent) == message_count
        assert duration < 1.0  # Should handle 100 messages in under 1 second

        # Calculate throughput; guard against a sub-resolution zero duration
        throughput = message_count / duration if duration > 0 else float('inf')
        assert throughput > 100  # Should handle at least 100 messages per second
712
v1/tests/mocks/hardware_mocks.py
Normal file
712
v1/tests/mocks/hardware_mocks.py
Normal file
@@ -0,0 +1,712 @@
|
||||
"""
|
||||
Hardware simulation mocks for testing.
|
||||
|
||||
Provides realistic hardware behavior simulation for routers and sensors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Callable, AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import json
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RouterStatus(Enum):
    """Router status enumeration.

    MockWiFiRouter.connect() drives OFFLINE -> CONNECTING -> ONLINE,
    falling to ERROR on a simulated connection failure; disconnect()
    returns the router to OFFLINE. MAINTENANCE is not set anywhere in
    the visible code — presumably an administrative state; confirm.
    """
    OFFLINE = "offline"
    CONNECTING = "connecting"
    ONLINE = "online"
    ERROR = "error"
    MAINTENANCE = "maintenance"
||||
class SignalQuality(Enum):
    """Signal quality levels, listed from worst to best.

    MockWiFiRouter defaults to GOOD; how the level is derived from CSI
    measurements is not shown in the visible code.
    """
    POOR = "poor"
    FAIR = "fair"
    GOOD = "good"
    EXCELLENT = "excellent"
||||
@dataclass
class RouterConfig:
    """Router configuration for a simulated CSI-capable WiFi router."""
    router_id: str  # unique identifier for this router
    frequency: float = 5.8e9  # 5.8 GHz
    bandwidth: float = 80e6  # 80 MHz
    num_antennas: int = 4
    num_subcarriers: int = 64
    tx_power: float = 20.0  # dBm
    # Physical placement; default_factory avoids a shared mutable default.
    location: Dict[str, float] = field(default_factory=lambda: {"x": 0, "y": 0, "z": 0})
    firmware_version: str = "1.2.3"
||||
class MockWiFiRouter:
    """Mock WiFi router with CSI (Channel State Information) capabilities.

    Simulates the connection lifecycle, CSI streaming at a configurable
    sample rate, heartbeats with temperature drift, and packet corruption,
    so higher layers can be tested without hardware.

    Callbacks registered for "on_status_change", "on_csi_data" and
    "on_error" may be sync or async; callback exceptions are swallowed so
    a faulty listener cannot break the simulation.
    """

    def __init__(self, config: RouterConfig):
        self.config = config
        self.status = RouterStatus.OFFLINE
        self.signal_quality = SignalQuality.GOOD
        self.is_streaming = False
        self.connected_devices = []
        self.csi_data_buffer = []  # most recent CSI samples (capped at 1000)
        self.error_rate = 0.01  # probability a sample is partially corrupted
        self.latency_ms = 5.0
        self.throughput_mbps = 100.0
        self.temperature_celsius = 45.0
        self.uptime_seconds = 0
        self.last_heartbeat = None  # datetime of the last heartbeat, or None
        self.callbacks = {
            "on_status_change": [],
            "on_csi_data": [],
            "on_error": []
        }
        self._streaming_task = None
        self._heartbeat_task = None

    async def connect(self) -> bool:
        """Connect to router.

        Returns True on success.  Connection is allowed from the OFFLINE or
        ERROR state (the latter enables recovery after simulated faults such
        as network partitions); any other state returns False.  Connections
        fail randomly 5% of the time, leaving the router in ERROR.
        """
        # BUG FIX: previously only OFFLINE routers could connect, which made
        # recovery from ERROR impossible -- simulate_network_partition() in
        # MockRouterNetwork puts routers into ERROR and then calls connect().
        if self.status not in (RouterStatus.OFFLINE, RouterStatus.ERROR):
            return False

        self.status = RouterStatus.CONNECTING
        await self._notify_status_change()

        # Simulate connection delay
        await asyncio.sleep(0.1)

        # Simulate occasional connection failures
        if random.random() < 0.05:  # 5% failure rate
            self.status = RouterStatus.ERROR
            await self._notify_error("Connection failed")
            return False

        self.status = RouterStatus.ONLINE
        self.last_heartbeat = datetime.utcnow()
        await self._notify_status_change()

        # Start heartbeat
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

        return True

    async def disconnect(self):
        """Disconnect from router, stopping streaming and heartbeat first."""
        if self.status == RouterStatus.OFFLINE:
            return

        # Stop streaming if active
        if self.is_streaming:
            await self.stop_csi_streaming()

        # Stop heartbeat
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        self.status = RouterStatus.OFFLINE
        await self._notify_status_change()

    async def start_csi_streaming(self, sample_rate: int = 1000) -> bool:
        """Start CSI data streaming at *sample_rate* samples per second.

        Returns False if the router is not ONLINE or already streaming.
        """
        if self.status != RouterStatus.ONLINE:
            return False

        if self.is_streaming:
            return False

        self.is_streaming = True
        self._streaming_task = asyncio.create_task(self._csi_streaming_loop(sample_rate))

        return True

    async def stop_csi_streaming(self):
        """Stop CSI data streaming, cancelling the streaming task."""
        if not self.is_streaming:
            return

        self.is_streaming = False

        if self._streaming_task:
            self._streaming_task.cancel()
            try:
                await self._streaming_task
            except asyncio.CancelledError:
                pass

    async def _csi_streaming_loop(self, sample_rate: int):
        """Generate and publish CSI samples while streaming is enabled."""
        interval = 1.0 / sample_rate

        try:
            while self.is_streaming:
                csi_data = self._generate_csi_sample()

                self.csi_data_buffer.append(csi_data)

                # Keep buffer size manageable
                if len(self.csi_data_buffer) > 1000:
                    self.csi_data_buffer = self.csi_data_buffer[-1000:]

                await self._notify_csi_data(csi_data)

                # Simulate processing delay and +/-10% timing jitter
                actual_interval = interval * random.uniform(0.9, 1.1)
                await asyncio.sleep(actual_interval)

        except asyncio.CancelledError:
            pass

    async def _heartbeat_loop(self):
        """Once per second: refresh heartbeat, track uptime, drift temperature.

        Degrades signal quality and raises an error callback when the
        simulated temperature exceeds 75 C.
        """
        try:
            while self.status == RouterStatus.ONLINE:
                self.last_heartbeat = datetime.utcnow()
                self.uptime_seconds += 1

                # Random walk, clamped to a plausible operating range
                self.temperature_celsius += random.uniform(-1, 1)
                self.temperature_celsius = max(30, min(80, self.temperature_celsius))

                # Check for overheating
                if self.temperature_celsius > 75:
                    self.signal_quality = SignalQuality.POOR
                    await self._notify_error("High temperature warning")

                await asyncio.sleep(1.0)

        except asyncio.CancelledError:
            pass

    def _generate_csi_sample(self) -> Dict[str, Any]:
        """Generate a realistic CSI sample.

        Produces (num_antennas x num_subcarriers) amplitude/phase matrices
        with noise scaled by the current signal quality; with probability
        ``error_rate``, ~10% of matrix entries are zeroed to mimic packet
        corruption.
        """
        amplitude = np.random.uniform(0.2, 0.8, (self.config.num_antennas, self.config.num_subcarriers))
        phase = np.random.uniform(-np.pi, np.pi, (self.config.num_antennas, self.config.num_subcarriers))

        # Noise level depends on current signal quality
        if self.signal_quality == SignalQuality.POOR:
            noise_level = 0.3
        elif self.signal_quality == SignalQuality.FAIR:
            noise_level = 0.2
        elif self.signal_quality == SignalQuality.GOOD:
            noise_level = 0.1
        else:  # EXCELLENT
            noise_level = 0.05

        amplitude += np.random.normal(0, noise_level, amplitude.shape)
        phase += np.random.normal(0, noise_level * np.pi, phase.shape)

        # Clip amplitude to [0, 1]; wrap phase back to [-pi, pi)
        amplitude = np.clip(amplitude, 0, 1)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi

        # Simulate packet errors by zeroing a random subset of entries
        if random.random() < self.error_rate:
            corruption_mask = np.random.random(amplitude.shape) < 0.1
            amplitude[corruption_mask] = 0
            phase[corruption_mask] = 0

        return {
            "timestamp": datetime.utcnow().isoformat(),
            "router_id": self.config.router_id,
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.config.frequency,
            "bandwidth": self.config.bandwidth,
            "num_antennas": self.config.num_antennas,
            "num_subcarriers": self.config.num_subcarriers,
            "signal_quality": self.signal_quality.value,
            "temperature": self.temperature_celsius,
            "tx_power": self.config.tx_power,
            "sequence_number": len(self.csi_data_buffer)
        }

    def register_callback(self, event: str, callback: Callable):
        """Register *callback* for *event*; unknown events are ignored."""
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def unregister_callback(self, event: str, callback: Callable):
        """Unregister *callback* from *event* if present."""
        if event in self.callbacks and callback in self.callbacks[event]:
            self.callbacks[event].remove(callback)

    async def _dispatch(self, event: str, payload):
        """Invoke every callback for *event* (sync or async), swallowing errors.

        Consolidates the dispatch logic previously duplicated in the three
        ``_notify_*`` helpers.
        """
        for callback in self.callbacks[event]:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(payload)
                else:
                    callback(payload)
            except Exception:
                pass  # a faulty listener must not break the simulation

    async def _notify_status_change(self):
        """Notify status change callbacks with the current status."""
        await self._dispatch("on_status_change", self.status)

    async def _notify_csi_data(self, data: Dict[str, Any]):
        """Notify CSI data callbacks with one sample dict."""
        await self._dispatch("on_csi_data", data)

    async def _notify_error(self, error_message: str):
        """Notify error callbacks with a human-readable message."""
        await self._dispatch("on_error", error_message)

    def get_status(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of the router's state."""
        return {
            "router_id": self.config.router_id,
            "status": self.status.value,
            "signal_quality": self.signal_quality.value,
            "is_streaming": self.is_streaming,
            "connected_devices": len(self.connected_devices),
            "temperature": self.temperature_celsius,
            "uptime_seconds": self.uptime_seconds,
            "last_heartbeat": self.last_heartbeat.isoformat() if self.last_heartbeat else None,
            "error_rate": self.error_rate,
            "latency_ms": self.latency_ms,
            "throughput_mbps": self.throughput_mbps,
            "firmware_version": self.config.firmware_version,
            "location": self.config.location
        }

    def set_signal_quality(self, quality: SignalQuality):
        """Set signal quality for testing."""
        self.signal_quality = quality

    def set_error_rate(self, error_rate: float):
        """Set packet-corruption probability for testing, clamped to [0, 1]."""
        self.error_rate = max(0, min(1, error_rate))

    def simulate_interference(self, duration_seconds: float = 5.0):
        """Temporarily force POOR signal quality for *duration_seconds*.

        NOTE(review): uses asyncio.create_task, so this must be called from
        within a running event loop -- confirm at call sites.
        """
        async def interference_task():
            original_quality = self.signal_quality
            self.signal_quality = SignalQuality.POOR
            await asyncio.sleep(duration_seconds)
            self.signal_quality = original_quality

        asyncio.create_task(interference_task())

    def get_csi_buffer(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the CSI data buffer."""
        return self.csi_data_buffer.copy()

    def clear_csi_buffer(self):
        """Clear the CSI data buffer."""
        self.csi_data_buffer.clear()
|
||||
|
||||
|
||||
class MockRouterNetwork:
    """Mock network of WiFi routers.

    Manages a collection of MockWiFiRouter instances, fans out bulk
    connect / stream operations, and forwards per-router events to
    network-level callbacks.  Also provides fault injection helpers
    (partitions, interference sources) for tests.
    """

    def __init__(self):
        self.routers = {}  # router_id -> MockWiFiRouter
        self.network_topology = {}
        self.interference_sources = []
        self.global_callbacks = {
            "on_router_added": [],
            "on_router_removed": [],
            "on_network_event": []
        }

    def add_router(self, config: RouterConfig) -> MockWiFiRouter:
        """Create a router from *config*, register it, and return it.

        Raises:
            ValueError: if a router with the same id already exists.
        """
        if config.router_id in self.routers:
            raise ValueError(f"Router {config.router_id} already exists")

        router = MockWiFiRouter(config)
        self.routers[config.router_id] = router

        # Forward this router's events to network-level listeners
        router.register_callback("on_status_change", self._on_router_status_change)
        router.register_callback("on_error", self._on_router_error)

        for callback in self.global_callbacks["on_router_added"]:
            callback(router)

        return router

    def remove_router(self, router_id: str) -> bool:
        """Remove a router; returns False if it does not exist.

        NOTE(review): a connected router is disconnected via a fire-and-
        forget task, so this must be called from within a running event
        loop -- confirm at call sites.
        """
        if router_id not in self.routers:
            return False

        router = self.routers[router_id]

        # Disconnect if connected
        if router.status != RouterStatus.OFFLINE:
            asyncio.create_task(router.disconnect())

        del self.routers[router_id]

        for callback in self.global_callbacks["on_router_removed"]:
            callback(router_id)

        return True

    def get_router(self, router_id: str) -> Optional[MockWiFiRouter]:
        """Get router by ID, or None if unknown."""
        return self.routers.get(router_id)

    def get_all_routers(self) -> Dict[str, MockWiFiRouter]:
        """Return a shallow copy of the router map."""
        return self.routers.copy()

    async def connect_all_routers(self) -> Dict[str, bool]:
        """Connect all routers concurrently; returns per-router success."""
        results = {}
        tasks = []

        for router_id, router in self.routers.items():
            task = asyncio.create_task(router.connect())
            tasks.append((router_id, task))

        for router_id, task in tasks:
            try:
                result = await task
                results[router_id] = result
            except Exception:
                results[router_id] = False

        return results

    async def disconnect_all_routers(self):
        """Disconnect every non-offline router concurrently."""
        tasks = []

        for router in self.routers.values():
            if router.status != RouterStatus.OFFLINE:
                task = asyncio.create_task(router.disconnect())
                tasks.append(task)

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def start_all_streaming(self, sample_rate: int = 1000) -> Dict[str, bool]:
        """Start CSI streaming on all ONLINE routers; others report False."""
        results = {}

        for router_id, router in self.routers.items():
            if router.status == RouterStatus.ONLINE:
                result = await router.start_csi_streaming(sample_rate)
                results[router_id] = result
            else:
                results[router_id] = False

        return results

    async def stop_all_streaming(self):
        """Stop CSI streaming on every streaming router concurrently."""
        tasks = []

        for router in self.routers.values():
            if router.is_streaming:
                task = asyncio.create_task(router.stop_csi_streaming())
                tasks.append(task)

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def get_network_status(self) -> Dict[str, Any]:
        """Return an aggregate, JSON-serializable network health snapshot."""
        total_routers = len(self.routers)
        online_routers = sum(1 for r in self.routers.values() if r.status == RouterStatus.ONLINE)
        streaming_routers = sum(1 for r in self.routers.values() if r.is_streaming)

        return {
            "total_routers": total_routers,
            "online_routers": online_routers,
            "streaming_routers": streaming_routers,
            "network_health": online_routers / max(total_routers, 1),
            "interference_sources": len(self.interference_sources),
            "timestamp": datetime.utcnow().isoformat()
        }

    def simulate_network_partition(self, router_ids: List[str], duration_seconds: float = 10.0):
        """Force the given routers into ERROR, then reconnect them later.

        NOTE(review): relies on MockWiFiRouter.connect() accepting the
        ERROR state; must be called from within a running event loop.
        """
        async def partition_task():
            affected_routers = [self.routers[rid] for rid in router_ids if rid in self.routers]

            for router in affected_routers:
                if router.status == RouterStatus.ONLINE:
                    router.status = RouterStatus.ERROR
                    await router._notify_status_change()

            await asyncio.sleep(duration_seconds)

            for router in affected_routers:
                if router.status == RouterStatus.ERROR:
                    await router.connect()

        asyncio.create_task(partition_task())

    def add_interference_source(self, location: Dict[str, float], strength: float, frequency: float):
        """Add an interference source, degrading routers within 50 metres."""
        interference = {
            "id": f"interference_{len(self.interference_sources)}",
            "location": location,
            "strength": strength,
            "frequency": frequency,
            "active": True
        }

        self.interference_sources.append(interference)

        # Degrade signal quality of nearby routers based on strength
        for router in self.routers.values():
            distance = self._calculate_distance(router.config.location, location)
            if distance < 50:  # Within 50 meters
                if strength > 0.5:
                    router.set_signal_quality(SignalQuality.POOR)
                elif strength > 0.3:
                    router.set_signal_quality(SignalQuality.FAIR)

    def _calculate_distance(self, loc1: Dict[str, float], loc2: Dict[str, float]) -> float:
        """Euclidean distance between two {x, y, z} locations (missing keys -> 0)."""
        dx = loc1.get("x", 0) - loc2.get("x", 0)
        dy = loc1.get("y", 0) - loc2.get("y", 0)
        dz = loc1.get("z", 0) - loc2.get("z", 0)
        return np.sqrt(dx**2 + dy**2 + dz**2)

    async def _dispatch_network_event(self, event_type: str, payload: Dict[str, Any]):
        """Invoke all on_network_event callbacks, sync or async.

        BUG FIX: callbacks were previously awaited unconditionally, so a
        synchronous callback registered via register_global_callback raised
        TypeError.  Mirrors MockWiFiRouter's dispatch: detect coroutine
        functions and swallow listener errors.
        """
        for callback in self.global_callbacks["on_network_event"]:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(event_type, payload)
                else:
                    callback(event_type, payload)
            except Exception:
                pass  # a faulty listener must not break the simulation

    async def _on_router_status_change(self, status: RouterStatus):
        """Forward a router status change to network-level listeners."""
        await self._dispatch_network_event("router_status_change", {"status": status})

    async def _on_router_error(self, error_message: str):
        """Forward a router error to network-level listeners."""
        await self._dispatch_network_event("router_error", {"error": error_message})

    def register_global_callback(self, event: str, callback: Callable):
        """Register a network-level callback; unknown events are ignored."""
        if event in self.global_callbacks:
            self.global_callbacks[event].append(callback)
|
||||
|
||||
|
||||
class MockSensorArray:
    """Mock sensor array for environmental monitoring.

    Simulates temperature, humidity, pressure, light, motion and sound
    sensors with bounded random drift.  Readings are produced by an async
    monitoring loop and delivered to registered callbacks (sync or async).
    """

    def __init__(self, sensor_id: str, location: Dict[str, float]):
        self.sensor_id = sensor_id
        self.location = location
        self.is_active = False
        # Per-sensor state: current value, unit label, and the valid range
        # used to clamp random drift.
        self.sensors = {
            "temperature": {"value": 22.0, "unit": "celsius", "range": (15, 35)},
            "humidity": {"value": 45.0, "unit": "percent", "range": (30, 70)},
            "pressure": {"value": 1013.25, "unit": "hPa", "range": (980, 1050)},
            "light": {"value": 300.0, "unit": "lux", "range": (0, 1000)},
            "motion": {"value": False, "unit": "boolean", "range": (False, True)},
            "sound": {"value": 35.0, "unit": "dB", "range": (20, 80)}
        }
        self.reading_history = []  # rolling window of the last 1000 readings
        self.callbacks = []

    async def start_monitoring(self, interval_seconds: float = 1.0):
        """Start the monitoring loop; returns False if already active."""
        if self.is_active:
            return False

        self.is_active = True
        asyncio.create_task(self._monitoring_loop(interval_seconds))
        return True

    def stop_monitoring(self):
        """Stop sensor monitoring (the loop exits on its next iteration)."""
        self.is_active = False

    async def _monitoring_loop(self, interval: float):
        """Generate a reading every *interval* seconds while active."""
        try:
            while self.is_active:
                reading = self._generate_sensor_reading()
                self.reading_history.append(reading)

                # Keep history manageable
                if len(self.reading_history) > 1000:
                    self.reading_history = self.reading_history[-1000:]

                # Notify callbacks (sync or async); listener errors are
                # swallowed so one bad callback cannot stop monitoring.
                for callback in self.callbacks:
                    try:
                        if asyncio.iscoroutinefunction(callback):
                            await callback(reading)
                        else:
                            callback(reading)
                    except Exception:
                        pass

                await asyncio.sleep(interval)

        except asyncio.CancelledError:
            pass

    def _generate_sensor_reading(self) -> Dict[str, Any]:
        """Generate one reading for every configured sensor.

        Continuous sensors drift by at most +/-10% of their range per call
        and are clamped to that range.  The motion sensor reports True while
        a simulated motion event is active, otherwise a 10% random chance.
        """
        reading = {
            "sensor_id": self.sensor_id,
            "timestamp": datetime.utcnow().isoformat(),
            "location": self.location,
            "readings": {}
        }

        for sensor_name, config in self.sensors.items():
            if sensor_name == "motion":
                # BUG FIX: the simulated motion state set by
                # simulate_event("motion_detected") was previously ignored,
                # so simulated motion events had no observable effect.
                reading["readings"][sensor_name] = bool(config["value"]) or random.random() < 0.1
            else:
                # Continuous sensors with drift
                current_value = config["value"]
                min_val, max_val = config["range"]

                # Add small random drift proportional to the sensor's range
                drift = random.uniform(-0.1, 0.1) * (max_val - min_val)
                new_value = current_value + drift

                # Keep within range
                new_value = max(min_val, min(max_val, new_value))

                config["value"] = new_value
                reading["readings"][sensor_name] = {
                    "value": round(new_value, 2),
                    "unit": config["unit"]
                }

        return reading

    def register_callback(self, callback: Callable):
        """Register a per-reading callback (sync or async)."""
        self.callbacks.append(callback)

    def unregister_callback(self, callback: Callable):
        """Unregister a callback if present."""
        if callback in self.callbacks:
            self.callbacks.remove(callback)

    def get_latest_reading(self) -> Optional[Dict[str, Any]]:
        """Return the most recent reading, or None if none recorded yet."""
        return self.reading_history[-1] if self.reading_history else None

    def get_reading_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to *limit* most recent readings, oldest first."""
        return self.reading_history[-limit:]

    def simulate_event(self, event_type: str, duration_seconds: float = 5.0):
        """Simulate an environmental event for *duration_seconds*.

        Supported event types: "motion_detected", "temperature_spike",
        "loud_noise".  NOTE(review): uses asyncio.create_task, so this must
        be called from within a running event loop.
        """
        async def event_task():
            if event_type == "motion_detected":
                self.sensors["motion"]["value"] = True
                await asyncio.sleep(duration_seconds)
                self.sensors["motion"]["value"] = False

            elif event_type == "temperature_spike":
                original_temp = self.sensors["temperature"]["value"]
                self.sensors["temperature"]["value"] = min(35, original_temp + 10)
                await asyncio.sleep(duration_seconds)
                self.sensors["temperature"]["value"] = original_temp

            elif event_type == "loud_noise":
                original_sound = self.sensors["sound"]["value"]
                self.sensors["sound"]["value"] = min(80, original_sound + 20)
                await asyncio.sleep(duration_seconds)
                self.sensors["sound"]["value"] = original_sound

        asyncio.create_task(event_task())
|
||||
|
||||
|
||||
# Utility functions for creating test hardware setups
|
||||
def create_test_router_network(num_routers: int = 3) -> MockRouterNetwork:
    """Build a MockRouterNetwork populated with *num_routers* routers.

    Routers are named ``router_000``, ``router_001``, ... and placed 10 m
    apart along the x axis at ceiling height (z = 2.5 m).
    """
    network = MockRouterNetwork()

    for index in range(num_routers):
        network.add_router(
            RouterConfig(
                router_id=f"router_{index:03d}",
                location={"x": index * 10, "y": 0, "z": 2.5},
            )
        )

    return network
|
||||
|
||||
|
||||
def create_test_sensor_array(num_sensors: int = 2) -> List[MockSensorArray]:
    """Create *num_sensors* MockSensorArray instances for tests.

    Sensors are named ``sensor_000``, ``sensor_001``, ... and spaced 5 m
    apart along the x axis at y = 5, z = 1.0.
    """
    return [
        MockSensorArray(
            sensor_id=f"sensor_{idx:03d}",
            location={"x": idx * 5, "y": 5, "z": 1.0},
        )
        for idx in range(num_sensors)
    ]
|
||||
|
||||
|
||||
async def setup_test_hardware_environment() -> Dict[str, Any]:
    """Bring up a complete mock hardware environment for tests.

    Creates a three-router network and two sensor arrays, connects every
    router, and starts sensor monitoring at 1 Hz.  Returns a dict with the
    created objects, per-device setup results, and a setup timestamp.
    """
    router_network = create_test_router_network(3)
    sensor_arrays = create_test_sensor_array(2)

    # Connect routers first so streaming/monitoring starts on a live network
    router_results = await router_network.connect_all_routers()

    # Start all sensor monitors concurrently
    sensor_results = await asyncio.gather(
        *(asyncio.create_task(sensor.start_monitoring(1.0)) for sensor in sensor_arrays)
    )

    return {
        "router_network": router_network,
        "sensor_arrays": sensor_arrays,
        "router_connection_results": router_results,
        "sensor_start_results": sensor_results,
        "setup_timestamp": datetime.utcnow().isoformat(),
    }
|
||||
|
||||
|
||||
async def teardown_test_hardware_environment(environment: Dict[str, Any]):
    """Tear down an environment built by ``setup_test_hardware_environment``.

    Stops every sensor's monitoring loop first, then disconnects all
    routers in the network.
    """
    for sensor_array in environment["sensor_arrays"]:
        sensor_array.stop_monitoring()

    await environment["router_network"].disconnect_all_routers()
|
||||
649
v1/tests/performance/test_api_throughput.py
Normal file
649
v1/tests/performance/test_api_throughput.py
Normal file
@@ -0,0 +1,649 @@
|
||||
"""
|
||||
Performance tests for API throughput and load testing.
|
||||
|
||||
Tests API endpoint performance under various load conditions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import statistics
|
||||
|
||||
|
||||
class MockAPIServer:
    """Mock API server for load testing.

    Simulates per-endpoint processing latency and optional rate limiting,
    and collects response-time / error statistics so throughput tests can
    assert on latency percentiles, error rates and concurrency.
    """

    def __init__(self):
        self.request_count = 0
        self.response_times = []  # latency in ms of requests that completed processing
        self.error_count = 0
        self.concurrent_requests = 0
        self.max_concurrent = 0  # high-water mark of in-flight requests
        self.is_running = False
        self.rate_limit_enabled = False
        self.rate_limit_per_second = 100
        self.request_timestamps = []  # epoch times of every request, for RPS/rate-limit

    async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
        """Handle one API request.

        Returns a dict with ``status`` (200 / 429 / 500), ``data`` or
        ``error``, and ``response_time_ms``.  Rate-limited requests are
        counted as errors and skip processing entirely.
        """
        start_time = time.time()
        self.concurrent_requests += 1
        self.max_concurrent = max(self.max_concurrent, self.concurrent_requests)
        self.request_count += 1
        self.request_timestamps.append(start_time)

        try:
            # Rate limiting: reject if more requests than allowed arrived
            # in the trailing one-second window (including this one).
            if self.rate_limit_enabled:
                recent_requests = [
                    ts for ts in self.request_timestamps
                    if start_time - ts <= 1.0
                ]
                if len(recent_requests) > self.rate_limit_per_second:
                    self.error_count += 1
                    return {
                        "status": 429,
                        "error": "Rate limit exceeded",
                        "response_time_ms": 1.0
                    }

            # Simulate endpoint-specific processing time
            processing_time = self._get_processing_time(endpoint, method)
            await asyncio.sleep(processing_time)

            response = self._generate_response(endpoint, method, data)

            end_time = time.time()
            response_time = (end_time - start_time) * 1000
            self.response_times.append(response_time)

            return {
                "status": 200,
                "data": response,
                "response_time_ms": response_time
            }

        except Exception as e:
            self.error_count += 1
            return {
                "status": 500,
                "error": str(e),
                "response_time_ms": (time.time() - start_time) * 1000
            }
        finally:
            self.concurrent_requests -= 1

    def _get_processing_time(self, endpoint: str, method: str) -> float:
        """Return the simulated processing time (seconds) for *endpoint*.

        Unknown endpoints default to 10 ms; a +/-20% uniform jitter is
        applied to every request.
        """
        processing_times = {
            "/health": 0.001,
            "/pose/detect": 0.05,
            "/pose/stream": 0.02,
            "/auth/login": 0.01,
            "/auth/refresh": 0.005,
            "/config": 0.003
        }

        base_time = processing_times.get(endpoint, 0.01)

        # Add some variance
        return base_time * np.random.uniform(0.8, 1.2)

    def _generate_response(self, endpoint: str, method: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a canned response payload for *endpoint*."""
        if endpoint == "/health":
            return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}

        elif endpoint == "/pose/detect":
            return {
                "persons": [
                    {
                        "person_id": "person_1",
                        "confidence": 0.85,
                        "bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
                        "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
                    }
                ],
                "processing_time_ms": 45.2,
                "model_version": "v1.0"
            }

        elif endpoint == "/auth/login":
            return {
                "access_token": "mock_access_token",
                "refresh_token": "mock_refresh_token",
                "expires_in": 3600
            }

        else:
            return {"message": "Success", "endpoint": endpoint, "method": method}

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return aggregated performance statistics.

        ``response_times`` only records requests that completed processing;
        rate-limited requests contribute to ``error_count`` but not to the
        latency percentiles.
        """
        if not self.response_times:
            # BUG FIX: error_rate was hard-coded to 0 here, so a run in
            # which every request was rate-limited (or failed before being
            # timed) reported a perfect error rate despite error_count > 0.
            return {
                "total_requests": self.request_count,
                "error_count": self.error_count,
                "error_rate": self.error_count / self.request_count if self.request_count else 0,
                "avg_response_time_ms": 0,
                "median_response_time_ms": 0,
                "p95_response_time_ms": 0,
                "p99_response_time_ms": 0,
                "max_concurrent_requests": self.max_concurrent,
                "requests_per_second": 0
            }

        return {
            "total_requests": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / self.request_count,
            "avg_response_time_ms": statistics.mean(self.response_times),
            "median_response_time_ms": statistics.median(self.response_times),
            "p95_response_time_ms": np.percentile(self.response_times, 95),
            "p99_response_time_ms": np.percentile(self.response_times, 99),
            "max_concurrent_requests": self.max_concurrent,
            "requests_per_second": self._calculate_rps()
        }

    def _calculate_rps(self) -> float:
        """Requests per second over the observed timestamp span."""
        if len(self.request_timestamps) < 2:
            return 0

        duration = self.request_timestamps[-1] - self.request_timestamps[0]
        return len(self.request_timestamps) / max(duration, 0.001)

    def enable_rate_limiting(self, requests_per_second: int):
        """Enable rate limiting at *requests_per_second*."""
        self.rate_limit_enabled = True
        self.rate_limit_per_second = requests_per_second

    def reset_stats(self):
        """Reset all performance counters and histories."""
        self.request_count = 0
        self.response_times = []
        self.error_count = 0
        self.concurrent_requests = 0
        self.max_concurrent = 0
        self.request_timestamps = []
|
||||
|
||||
|
||||
class TestAPIThroughput:
|
||||
"""Test API throughput under various conditions."""
|
||||
|
||||
@pytest.fixture
|
||||
def api_server(self):
|
||||
"""Create mock API server."""
|
||||
return MockAPIServer()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_request_performance_should_fail_initially(self, api_server):
|
||||
"""Test single request performance - should fail initially."""
|
||||
start_time = time.time()
|
||||
response = await api_server.handle_request("/health")
|
||||
end_time = time.time()
|
||||
|
||||
response_time = (end_time - start_time) * 1000
|
||||
|
||||
# This will fail initially
|
||||
assert response["status"] == 200
|
||||
assert response_time < 50 # Should respond within 50ms
|
||||
assert response["response_time_ms"] > 0
|
||||
|
||||
stats = api_server.get_performance_stats()
|
||||
assert stats["total_requests"] == 1
|
||||
assert stats["error_count"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_request_handling_should_fail_initially(self, api_server):
|
||||
"""Test concurrent request handling - should fail initially."""
|
||||
# Send multiple concurrent requests
|
||||
concurrent_requests = 10
|
||||
tasks = []
|
||||
|
||||
for i in range(concurrent_requests):
|
||||
task = asyncio.create_task(api_server.handle_request("/health"))
|
||||
tasks.append(task)
|
||||
|
||||
start_time = time.time()
|
||||
responses = await asyncio.gather(*tasks)
|
||||
end_time = time.time()
|
||||
|
||||
total_time = (end_time - start_time) * 1000
|
||||
|
||||
# This will fail initially
|
||||
assert len(responses) == concurrent_requests
|
||||
assert all(r["status"] == 200 for r in responses)
|
||||
|
||||
# All requests should complete within reasonable time
|
||||
assert total_time < 200 # Should complete within 200ms
|
||||
|
||||
stats = api_server.get_performance_stats()
|
||||
assert stats["total_requests"] == concurrent_requests
|
||||
assert stats["max_concurrent_requests"] <= concurrent_requests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sustained_load_performance_should_fail_initially(self, api_server):
|
||||
"""Test sustained load performance - should fail initially."""
|
||||
duration_seconds = 3
|
||||
target_rps = 50 # 50 requests per second
|
||||
|
||||
async def send_requests():
|
||||
"""Send requests at target rate."""
|
||||
interval = 1.0 / target_rps
|
||||
end_time = time.time() + duration_seconds
|
||||
|
||||
while time.time() < end_time:
|
||||
await api_server.handle_request("/health")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
start_time = time.time()
|
||||
await send_requests()
|
||||
actual_duration = time.time() - start_time
|
||||
|
||||
stats = api_server.get_performance_stats()
|
||||
actual_rps = stats["requests_per_second"]
|
||||
|
||||
# This will fail initially
|
||||
assert actual_rps >= target_rps * 0.8 # Within 80% of target
|
||||
assert stats["error_rate"] < 0.05 # Less than 5% error rate
|
||||
assert stats["avg_response_time_ms"] < 100 # Average response time under 100ms
|
||||
|
||||
@pytest.mark.asyncio
async def test_different_endpoint_performance_should_fail_initially(self, api_server):
    """Compare per-endpoint latency profiles.

    Hits four representative endpoints ten times each, summarises their
    response times, and checks both the relative ordering (the trivial
    health check must beat pose detection) and absolute latency budgets.
    """
    probes_per_endpoint = 10
    targets = ["/health", "/pose/detect", "/auth/login", "/config"]

    timings = {}
    for path in targets:
        samples = []
        for _ in range(probes_per_endpoint):
            reply = await api_server.handle_request(path)
            samples.append(reply["response_time_ms"])
        timings[path] = {
            "avg_response_time": statistics.mean(samples),
            "min_response_time": min(samples),
            "max_response_time": max(samples),
        }

    # The cheap liveness probe must be faster than ML-backed pose detection.
    assert (timings["/health"]["avg_response_time"]
            < timings["/pose/detect"]["avg_response_time"])

    # Every endpoint must stay inside its latency budget.
    for summary in timings.values():
        assert summary["avg_response_time"] < 200  # avg under 200ms
        assert summary["max_response_time"] < 500  # worst case under 500ms
@pytest.mark.asyncio
async def test_rate_limiting_behavior_should_fail_initially(self, api_server):
    """Flood past the configured rate limit and verify 429s are produced."""
    api_server.enable_rate_limiting(requests_per_second=10)

    burst = 20  # twice the allowed rate, fired concurrently
    replies = await asyncio.gather(
        *(asyncio.create_task(api_server.handle_request("/health"))
          for _ in range(burst))
    )

    accepted = [r for r in replies if r["status"] == 200]
    throttled = [r for r in replies if r["status"] == 429]

    # The limiter must let some traffic through while rejecting the excess,
    # and every request must resolve to exactly one of the two outcomes.
    assert accepted
    assert throttled
    assert len(accepted) + len(throttled) == burst

    # Rejected requests are counted as errors by the server's statistics.
    assert api_server.get_performance_stats()["error_count"] > 0
class TestAPILoadTesting:
    """Exercise the API server under heavy and mixed load."""

    @pytest.fixture
    def load_test_server(self):
        """Fresh mock server instance for each load test."""
        return MockAPIServer()

    @pytest.mark.asyncio
    async def test_high_concurrency_load_should_fail_initially(self, load_test_server):
        """Many simulated users issuing small request bursts concurrently."""
        users = 50
        per_user = 5

        async def simulate_user(_uid):
            # Each "user" performs a short session of spaced-out requests.
            replies = []
            for _ in range(per_user):
                replies.append(await load_test_server.handle_request("/health"))
                await asyncio.sleep(0.01)  # think-time between requests
            return replies

        sessions = await asyncio.gather(*(simulate_user(u) for u in range(users)))

        # Every simulated user session must complete.
        assert len(sessions) == users

        stats = load_test_server.get_performance_stats()
        assert stats["total_requests"] == users * per_user
        assert stats["error_rate"] < 0.1           # under 10% failures
        assert stats["requests_per_second"] > 100  # at least 100 RPS overall

    @pytest.mark.asyncio
    async def test_mixed_endpoint_load_should_fail_initially(self, load_test_server):
        """Realistic traffic mix across all endpoints, fired concurrently."""
        # (endpoint, share of traffic) — shares sum to 1.0
        traffic_profile = [
            ("/health", 0.4),
            ("/pose/detect", 0.3),
            ("/auth/login", 0.1),
            ("/config", 0.2),
        ]
        request_count = 100

        def pick_endpoint():
            # Sample an endpoint according to the cumulative distribution.
            draw = np.random.random()
            running = 0.0
            for path, share in traffic_profile:
                running += share
                if draw <= running:
                    return path
            return traffic_profile[-1][0]  # float-rounding fallback (unreachable)

        tasks = [
            asyncio.create_task(load_test_server.handle_request(pick_endpoint()))
            for _ in range(request_count)
        ]
        replies = await asyncio.gather(*tasks)

        assert len(replies) == request_count

        ok = [r for r in replies if r["status"] == 200]
        assert len(ok) >= request_count * 0.9  # at least 90% must succeed

        stats = load_test_server.get_performance_stats()
        assert stats["requests_per_second"] > 50
        assert stats["avg_response_time_ms"] < 150

    @pytest.mark.asyncio
    async def test_stress_testing_should_fail_initially(self, load_test_server):
        """Step the concurrency up and confirm graceful degradation."""
        results = {}

        for level in (10, 25, 50, 100, 200):
            load_test_server.reset_stats()

            began = time.time()
            await asyncio.gather(
                *(load_test_server.handle_request("/health") for _ in range(level))
            )
            elapsed = time.time() - began

            stats = load_test_server.get_performance_stats()
            results[level] = {
                "duration": elapsed,
                "rps": stats["requests_per_second"],
                "error_rate": stats["error_rate"],
                "avg_response_time": stats["avg_response_time_ms"],
                "p95_response_time": stats["p95_response_time_ms"],
            }

        # Even at the heaviest level the server should degrade, not collapse.
        for metrics in results.values():
            assert metrics["error_rate"] < 0.2          # under 20% failures
            assert metrics["avg_response_time"] < 1000  # under 1s average

        # Latency is expected to grow (weakly) with offered load.
        assert results[10]["avg_response_time"] <= results[200]["avg_response_time"]

    @pytest.mark.asyncio
    async def test_memory_usage_under_load_should_fail_initially(self, load_test_server):
        """Sustained load must not grow process memory unboundedly."""
        import psutil
        import os

        proc = psutil.Process(os.getpid())
        baseline_rss = proc.memory_info().rss

        run_for = 5    # seconds
        rate = 100     # requests per second
        gap = 1.0 / rate
        stop_at = time.time() + run_for
        while time.time() < stop_at:
            await load_test_server.handle_request("/pose/detect")
            await asyncio.sleep(gap)

        grown = proc.memory_info().rss - baseline_rss

        # RSS growth over the whole run must stay under 100 MB.
        assert grown < 100 * 1024 * 1024

        # Sanity: the load generator achieved most of its target rate.
        stats = load_test_server.get_performance_stats()
        assert stats["total_requests"] > run_for * rate * 0.8
class TestAPIPerformanceOptimization:
    """Verify that classic optimisation techniques actually pay off."""

    @pytest.mark.asyncio
    async def test_response_caching_effect_should_fail_initially(self):
        """A response cache must serve repeat requests faster."""

        class CachedAPIServer(MockAPIServer):
            """MockAPIServer with a naive in-memory response cache."""

            def __init__(self):
                super().__init__()
                self.cache = {}
                self.cache_hits = 0
                self.cache_misses = 0

            async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
                key = f"{method}:{endpoint}"

                hit = self.cache.get(key)
                if hit is not None:
                    self.cache_hits += 1
                    served = hit.copy()
                    served["response_time_ms"] = 1.0  # cache hits are ~free
                    return served

                self.cache_misses += 1
                fresh = await super().handle_request(endpoint, method, data)
                if fresh["status"] == 200:  # only cache successful responses
                    self.cache[key] = fresh.copy()
                return fresh

        server = CachedAPIServer()

        first = await server.handle_request("/health")   # cache miss
        second = await server.handle_request("/health")  # cache hit

        assert first["status"] == 200
        assert second["status"] == 200
        # The cached reply must be strictly faster than the computed one.
        assert second["response_time_ms"] < first["response_time_ms"]
        assert server.cache_hits == 1
        assert server.cache_misses == 1

    @pytest.mark.asyncio
    async def test_connection_pooling_effect_should_fail_initially(self):
        """Pooled connections amortise the per-connection setup cost."""

        class ConnectionPoolServer(MockAPIServer):
            """MockAPIServer that charges a setup fee until its pool is full."""

            def __init__(self, pool_size: int = 10):
                super().__init__()
                self.pool_size = pool_size
                self.active_connections = 0
                self.connection_overhead = 0.01  # 10ms to open a connection

            async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
                if self.active_connections < self.pool_size:
                    # Pool not yet warm: pay the connection setup cost once.
                    await asyncio.sleep(self.connection_overhead)
                    self.active_connections += 1
                try:
                    return await super().handle_request(endpoint, method, data)
                finally:
                    pass  # connection returns to the pool, never closed

        server = ConnectionPoolServer(pool_size=5)

        fan_out = 10  # deliberately exceeds the pool size
        began = time.time()
        replies = await asyncio.gather(
            *(server.handle_request("/health") for _ in range(fan_out))
        )
        elapsed_ms = (time.time() - began) * 1000

        assert len(replies) == fan_out
        assert all(r["status"] == 200 for r in replies)
        # Pooling must keep the whole burst under half a second.
        assert elapsed_ms < 500

    @pytest.mark.asyncio
    async def test_request_batching_performance_should_fail_initially(self):
        """Batched handling should beat the same requests issued one-by-one."""

        class BatchingServer(MockAPIServer):
            """MockAPIServer that can process a whole batch in one pass."""

            def __init__(self):
                super().__init__()
                self.batch_size = 5
                self.pending_requests = []
                self.batch_processing = False

            async def handle_batch_request(self, requests: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                """Process *requests* together, sharing one fixed overhead."""
                await asyncio.sleep(0.01)  # single 10ms overhead for the batch

                out = []
                for item in requests:
                    # Per-item work is halved when amortised inside a batch.
                    cost = self._get_processing_time(item["endpoint"], item["method"]) * 0.5
                    await asyncio.sleep(cost)
                    out.append({
                        "status": 200,
                        "data": self._generate_response(item["endpoint"], item["method"], item.get("data")),
                        "response_time_ms": cost * 1000,
                    })
                return out

        server = BatchingServer()
        n = 5

        # Path A: the same requests issued individually, concurrently.
        began = time.time()
        solo_replies = await asyncio.gather(
            *(server.handle_request("/health") for _ in range(n))
        )
        solo_ms = (time.time() - began) * 1000

        # Path B: one batch carrying all five requests.
        batch = [{"endpoint": "/health", "method": "GET"} for _ in range(n)]
        began = time.time()
        batch_replies = await server.handle_batch_request(batch)
        batch_ms = (time.time() - began) * 1000

        assert len(solo_replies) == n
        assert len(batch_replies) == n
        assert batch_ms < solo_ms  # batching must win
        assert all(r["status"] == 200 for r in batch_replies)
507
v1/tests/performance/test_inference_speed.py
Normal file
507
v1/tests/performance/test_inference_speed.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Performance tests for ML model inference speed.
|
||||
|
||||
Tests pose estimation model performance, throughput, and optimization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import psutil
|
||||
import os
|
||||
|
||||
|
||||
class MockPoseModel:
    """Stand-in pose-estimation model with tunable, simulated latency.

    The ``model_complexity`` knob ("lightweight" | "standard" |
    "high_accuracy") controls both the simulated load time and the
    per-sample inference latency, so performance tests can compare
    model tiers without real weights.
    """

    def __init__(self, model_complexity: str = "standard"):
        self.model_complexity = model_complexity
        self.is_loaded = False           # predict() refuses to run before load_model()
        self.inference_count = 0         # samples processed so far
        self.total_inference_time = 0.0  # wall-clock seconds spent in predict()
        self.batch_size = 1

        # Simulated per-sample latency for each complexity tier (seconds);
        # unknown tiers fall back to the "standard" cost.
        self.base_inference_time = {
            "lightweight": 0.02,   # 20ms
            "standard": 0.05,      # 50ms
            "high_accuracy": 0.15, # 150ms
        }.get(model_complexity, 0.05)

    async def load_model(self):
        """Simulate loading weights; heavier tiers take longer."""
        await asyncio.sleep(
            {"lightweight": 0.5, "standard": 2.0, "high_accuracy": 5.0}
            .get(self.model_complexity, 2.0)
        )
        self.is_loaded = True

    async def predict(self, features: np.ndarray) -> Dict[str, Any]:
        """Run (simulated) inference and return mock pose predictions.

        A leading batch axis is assumed when *features* has more than two
        dimensions; otherwise the input is treated as a single sample.

        Raises:
            RuntimeError: if called before :meth:`load_model`.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        began = time.time()

        batch = features.shape[0] if len(features.shape) > 2 else 1
        # Simulated compute: linear in batch size with ±20% jitter.
        await asyncio.sleep(self.base_inference_time * batch * np.random.uniform(0.8, 1.2))

        took = time.time() - began
        self.inference_count += batch
        self.total_inference_time += took

        predictions = [
            {
                "person_id": f"person_{i}",
                "confidence": np.random.uniform(0.5, 0.95),
                # 17 COCO-style keypoints as (x, y, confidence) triples.
                "keypoints": np.random.rand(17, 3).tolist(),
                "bounding_box": {
                    "x": np.random.uniform(0, 640),
                    "y": np.random.uniform(0, 480),
                    "width": np.random.uniform(50, 200),
                    "height": np.random.uniform(100, 300),
                },
            }
            for i in range(batch)
        ]

        return {
            "predictions": predictions,
            "inference_time_ms": took * 1000,
            "model_complexity": self.model_complexity,
            "batch_size": batch,
        }

    def get_performance_stats(self) -> Dict[str, Any]:
        """Summarise the inference counters collected so far."""
        per_sample = (
            self.total_inference_time / self.inference_count
            if self.inference_count else 0
        )
        return {
            "total_inferences": self.inference_count,
            "total_time_seconds": self.total_inference_time,
            "average_inference_time_ms": per_sample * 1000,
            "throughput_fps": 1.0 / per_sample if per_sample > 0 else 0,
            "model_complexity": self.model_complexity,
        }
class TestInferenceSpeed:
    """Latency checks across model tiers, batch sizes, and sustained runs."""

    @pytest.fixture
    def lightweight_model(self):
        """Fastest model tier."""
        return MockPoseModel("lightweight")

    @pytest.fixture
    def standard_model(self):
        """Default model tier."""
        return MockPoseModel("standard")

    @pytest.fixture
    def high_accuracy_model(self):
        """Slowest, most accurate tier."""
        return MockPoseModel("high_accuracy")

    @pytest.fixture
    def sample_features(self):
        """A single 64x32 feature map."""
        return np.random.rand(64, 32)

    @pytest.mark.asyncio
    async def test_single_inference_speed_should_fail_initially(self, standard_model, sample_features):
        """One inference on the standard tier must finish within 100ms."""
        await standard_model.load_model()

        began = time.time()
        outcome = await standard_model.predict(sample_features)
        elapsed_ms = (time.time() - began) * 1000

        assert elapsed_ms < 100  # hard latency budget
        assert outcome["inference_time_ms"] > 0
        assert outcome["predictions"]
        assert outcome["model_complexity"] == "standard"

    @pytest.mark.asyncio
    async def test_model_complexity_comparison_should_fail_initially(self, sample_features):
        """Latency must rank lightweight < standard < high_accuracy."""
        tiers = ("lightweight", "standard", "high_accuracy")
        models = {tier: MockPoseModel(tier) for tier in tiers}
        for model in models.values():
            await model.load_model()

        timings = {}
        for tier, model in models.items():
            began = time.time()
            outcome = await model.predict(sample_features)
            timings[tier] = {
                "inference_time_ms": (time.time() - began) * 1000,
                "result": outcome,
            }

        # Cheaper tiers must be measurably faster.
        assert timings["lightweight"]["inference_time_ms"] < timings["standard"]["inference_time_ms"]
        assert timings["standard"]["inference_time_ms"] < timings["high_accuracy"]["inference_time_ms"]

        # Even the heaviest tier has a hard 500ms ceiling.
        for tier_timing in timings.values():
            assert tier_timing["inference_time_ms"] < 500

    @pytest.mark.asyncio
    async def test_batch_inference_performance_should_fail_initially(self, standard_model):
        """Per-sample cost should shrink as the batch grows."""
        await standard_model.load_model()

        profile = {}
        for size in (1, 4, 8, 16):
            batch = np.random.rand(size, 64, 32)

            began = time.time()
            outcome = await standard_model.predict(batch)
            whole_ms = (time.time() - began) * 1000

            per_sample_ms = whole_ms / size
            profile[size] = {
                "total_time_ms": whole_ms,
                "per_sample_time_ms": per_sample_ms,
                "throughput_fps": 1000 / per_sample_ms,
                "predictions": len(outcome["predictions"]),
            }

        # Amortisation: bigger batches must cost less per sample.
        assert profile[1]["per_sample_time_ms"] > profile[4]["per_sample_time_ms"]
        assert profile[4]["per_sample_time_ms"] > profile[8]["per_sample_time_ms"]

        # Exactly one prediction per sample in every batch.
        for size, entry in profile.items():
            assert entry["predictions"] == size

    @pytest.mark.asyncio
    async def test_sustained_inference_performance_should_fail_initially(self, standard_model, sample_features):
        """Fifty back-to-back inferences must stay fast and consistent."""
        await standard_model.load_model()

        runs = 50
        latencies = []
        for _ in range(runs):
            began = time.time()
            await standard_model.predict(sample_features)
            latencies.append((time.time() - began) * 1000)

        mean_ms = np.mean(latencies)
        assert mean_ms < 100                    # fast on average
        assert np.std(latencies) < 20           # and consistent
        assert np.max(latencies) < mean_ms * 2  # with no large outliers

        stats = standard_model.get_performance_stats()
        assert stats["total_inferences"] == runs
        assert stats["throughput_fps"] > 10  # at least 10 FPS sustained
class TestInferenceOptimization:
    """Test inference optimization techniques."""

    # NOTE(fix): these fixtures previously existed only on TestInferenceSpeed.
    # pytest class fixtures are visible only to tests in that class (and its
    # subclasses), so the tests below failed collection with
    # "fixture 'standard_model' not found" — define them locally.
    @pytest.fixture
    def standard_model(self):
        """Standard-tier model under test."""
        return MockPoseModel("standard")

    @pytest.fixture
    def sample_features(self):
        """A single 64x32 feature map."""
        return np.random.rand(64, 32)

    @pytest.mark.asyncio
    async def test_model_warmup_effect_should_fail_initially(self, standard_model, sample_features):
        """Warmed-up inferences must not be slower than the cold start."""
        await standard_model.load_model()

        # Cold start: the very first inference after loading.
        start = time.time()
        await standard_model.predict(sample_features)
        cold_ms = (time.time() - start) * 1000

        # Warmed up: subsequent inferences.
        warm_ms = []
        for _ in range(5):
            start = time.time()
            await standard_model.predict(sample_features)
            warm_ms.append((time.time() - start) * 1000)

        avg_warm_ms = np.mean(warm_ms)

        assert avg_warm_ms <= cold_ms  # warmup should only ever help
        assert cold_ms > 0
        assert avg_warm_ms > 0

    @pytest.mark.asyncio
    async def test_concurrent_inference_performance_should_fail_initially(self, sample_features):
        """Independent models running concurrently must all stay fast."""
        models = [MockPoseModel("standard") for _ in range(3)]
        for model in models:
            await model.load_model()

        async def timed_predict(model):
            # Measure wall-clock latency of one inference, in milliseconds.
            start = time.time()
            await model.predict(sample_features)
            return (time.time() - start) * 1000

        # (fix: the original genexp used `time` as a loop variable, shadowing
        # the `time` module — renamed to `ms` here.)
        latencies = await asyncio.gather(*(timed_predict(m) for m in models))

        assert len(latencies) == 3
        assert all(ms > 0 for ms in latencies)
        # Concurrency should not blow past 200ms per inference on average.
        assert np.mean(latencies) < 200

    @pytest.mark.asyncio
    async def test_memory_usage_during_inference_should_fail_initially(self, standard_model, sample_features):
        """Repeated inference must not leak memory."""
        process = psutil.Process(os.getpid())

        await standard_model.load_model()
        initial_rss = process.memory_info().rss

        for i in range(20):
            await standard_model.predict(sample_features)

            # Spot-check memory growth every 5 inferences.
            if i % 5 == 0:
                grown = process.memory_info().rss - initial_rss
                assert grown < 50 * 1024 * 1024  # < 50MB transient growth

        total_grown = process.memory_info().rss - initial_rss
        assert total_grown < 100 * 1024 * 1024  # < 100MB overall
class TestInferenceAccuracy:
    """Test inference accuracy and quality metrics."""

    # NOTE(fix): fixtures defined on TestInferenceSpeed are class-scoped and
    # invisible here; without local definitions these tests fail collection
    # with "fixture 'standard_model' not found".
    @pytest.fixture
    def standard_model(self):
        """Standard-tier model under test."""
        return MockPoseModel("standard")

    @pytest.fixture
    def sample_features(self):
        """A single 64x32 feature map."""
        return np.random.rand(64, 32)

    @pytest.mark.asyncio
    async def test_prediction_consistency_should_fail_initially(self, standard_model, sample_features):
        """Repeated inference on identical input must be structurally stable."""
        await standard_model.load_model()

        results = [await standard_model.predict(sample_features) for _ in range(5)]

        # Every run must return the same result shape.
        for result in results:
            assert "predictions" in result
            assert "inference_time_ms" in result
            assert len(result["predictions"]) > 0

        # Latency jitter should stay well-bounded relative to the mean.
        times = [r["inference_time_ms"] for r in results]
        assert np.std(times) < np.mean(times) * 0.5

    @pytest.mark.asyncio
    async def test_confidence_score_distribution_should_fail_initially(self, standard_model, sample_features):
        """Confidence scores must be valid probabilities with a sane mean."""
        await standard_model.load_model()

        confidences = []
        for _ in range(20):
            result = await standard_model.predict(sample_features)
            confidences.extend(p["confidence"] for p in result["predictions"])

        if confidences:  # only meaningful when predictions were produced
            assert all(0.0 <= c <= 1.0 for c in confidences)
            # Average confidence should be neither degenerate nor saturated.
            assert 0.3 <= np.mean(confidences) <= 0.95

    @pytest.mark.asyncio
    async def test_keypoint_detection_quality_should_fail_initially(self, standard_model, sample_features):
        """Each prediction must carry 17 well-formed (x, y, conf) keypoints."""
        await standard_model.load_model()

        result = await standard_model.predict(sample_features)

        for prediction in result["predictions"]:
            keypoints = prediction["keypoints"]
            assert len(keypoints) == 17  # COCO-style skeleton

            for keypoint in keypoints:
                assert len(keypoint) == 3  # (x, y, confidence)
                x, y, conf = keypoint
                assert isinstance(x, (int, float))
                assert isinstance(y, (int, float))
                assert 0.0 <= conf <= 1.0
class TestInferenceScaling:
    """Test inference scaling characteristics."""

    # NOTE(fix): see TestInferenceSpeed — pytest class fixtures do not carry
    # across classes, so the fixtures these tests request must live here or
    # collection fails with "fixture not found".
    @pytest.fixture
    def standard_model(self):
        """Standard-tier model under test."""
        return MockPoseModel("standard")

    @pytest.fixture
    def sample_features(self):
        """A single 64x32 feature map."""
        return np.random.rand(64, 32)

    @pytest.mark.asyncio
    async def test_input_size_scaling_should_fail_initially(self, standard_model):
        """Latency should grow (weakly) with input size."""
        await standard_model.load_model()

        results = {}
        for height, width in ((32, 16), (64, 32), (128, 64), (256, 128)):
            features = np.random.rand(height, width)

            start = time.time()
            outcome = await standard_model.predict(features)
            elapsed_ms = (time.time() - start) * 1000

            results[height * width] = {
                "inference_time_ms": elapsed_ms,
                "dimensions": (height, width),
                "predictions": len(outcome["predictions"]),
            }

        # Larger inputs should not be dramatically faster than smaller ones.
        sizes = sorted(results)
        for small, large in zip(sizes, sizes[1:]):
            ratio = results[large]["inference_time_ms"] / results[small]["inference_time_ms"]
            assert ratio >= 0.8  # allow jitter, forbid strong inversions

    @pytest.mark.asyncio
    async def test_throughput_under_load_should_fail_initially(self, standard_model, sample_features):
        """Sustained serial load must keep a usable frame rate."""
        await standard_model.load_model()

        run_for = 5  # seconds of back-to-back inference
        start = time.time()
        completed = 0
        while time.time() - start < run_for:
            await standard_model.predict(sample_features)
            completed += 1

        throughput = completed / (time.time() - start)

        assert throughput > 5  # at least 5 FPS sustained
        assert completed > 20  # i.e. more than 20 inferences in 5 seconds

        stats = standard_model.get_performance_stats()
        assert stats["total_inferences"] >= completed
        assert stats["throughput_fps"] > 0
@pytest.mark.benchmark
class TestInferenceBenchmarks:
    """Benchmark tests for inference performance.

    NOTE(fix): both tests previously requested the ``benchmark`` fixture
    (pytest-benchmark) but never used it; requesting it makes collection
    fail when the plugin is not installed, so the parameter was removed.
    """

    @pytest.mark.asyncio
    async def test_benchmark_lightweight_model_should_fail_initially(self):
        """Lightweight tier: a single inference must stay under 50ms."""
        model = MockPoseModel("lightweight")
        await model.load_model()
        features = np.random.rand(64, 32)

        result = await model.predict(features)
        assert result["inference_time_ms"] < 50  # lightweight latency budget

    @pytest.mark.asyncio
    async def test_benchmark_batch_processing_should_fail_initially(self):
        """Standard tier: a batch of 8 must be processed efficiently."""
        model = MockPoseModel("standard")
        await model.load_model()

        result = await model.predict(np.random.rand(8, 64, 32))
        assert len(result["predictions"]) == 8   # one prediction per sample
        assert result["inference_time_ms"] < 200  # batch efficiency budget
264
v1/tests/unit/test_csi_extractor.py
Normal file
264
v1/tests/unit/test_csi_extractor.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.hardware.csi_extractor import CSIExtractor, CSIExtractionError
|
||||
|
||||
|
||||
class TestCSIExtractor:
    """Test suite for CSI Extractor following London School TDD principles"""

    @pytest.fixture
    def mock_config(self):
        """Configuration for CSI extractor"""
        return {
            'interface': 'wlan0',
            'channel': 6,
            'bandwidth': 20,
            'sample_rate': 1000,
            'buffer_size': 1024,
            'extraction_timeout': 5.0
        }

    @pytest.fixture
    def mock_router_interface(self):
        """Mock router interface for testing"""
        mock_router = Mock()
        mock_router.is_connected = True
        mock_router.execute_command = Mock()
        return mock_router

    @pytest.fixture
    def csi_extractor(self, mock_config, mock_router_interface):
        """Create CSI extractor instance for testing"""
        return CSIExtractor(mock_config, mock_router_interface)

    @pytest.fixture
    def mock_csi_data(self):
        """Generate synthetic CSI data for testing"""
        # Simulate CSI data: complex values for multiple subcarriers
        num_subcarriers = 56
        num_antennas = 3
        amplitude = np.random.uniform(0.1, 2.0, (num_antennas, num_subcarriers))
        phase = np.random.uniform(-np.pi, np.pi, (num_antennas, num_subcarriers))
        return amplitude * np.exp(1j * phase)

    @staticmethod
    def _activate_extraction(csi_extractor, mock_router_interface):
        """Arrange helper: stub a healthy router and start extraction.

        Deduplicates the arrange block shared by every test that needs an
        extractor already in the extracting state.
        """
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"
        csi_extractor.start_extraction()

    def test_extractor_initialization_creates_correct_configuration(self, mock_config, mock_router_interface):
        """Test that CSI extractor initializes with correct configuration"""
        # Act
        extractor = CSIExtractor(mock_config, mock_router_interface)

        # Assert
        assert extractor is not None
        assert extractor.interface == mock_config['interface']
        assert extractor.channel == mock_config['channel']
        assert extractor.bandwidth == mock_config['bandwidth']
        assert extractor.sample_rate == mock_config['sample_rate']
        assert extractor.buffer_size == mock_config['buffer_size']
        assert extractor.extraction_timeout == mock_config['extraction_timeout']
        assert extractor.router_interface == mock_router_interface
        assert not extractor.is_extracting

    def test_start_extraction_configures_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that start_extraction configures monitor mode"""
        # Arrange (explicit rather than via helper: the start call is the act under test)
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"

        # Act
        result = csi_extractor.start_extraction()

        # Assert
        assert result is True
        assert csi_extractor.is_extracting is True
        mock_router_interface.enable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_start_extraction_handles_monitor_mode_failure(self, csi_extractor, mock_router_interface):
        """Test that start_extraction handles monitor mode configuration failure"""
        # Arrange
        mock_router_interface.enable_monitor_mode.return_value = False

        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.start_extraction()

        assert csi_extractor.is_extracting is False

    def test_stop_extraction_disables_monitor_mode(self, csi_extractor, mock_router_interface):
        """Test that stop_extraction disables monitor mode"""
        # Arrange
        mock_router_interface.disable_monitor_mode.return_value = True
        self._activate_extraction(csi_extractor, mock_router_interface)

        # Act
        result = csi_extractor.stop_extraction()

        # Assert
        assert result is True
        assert csi_extractor.is_extracting is False
        mock_router_interface.disable_monitor_mode.assert_called_once_with(csi_extractor.interface)

    def test_extract_csi_data_returns_valid_format(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that extract_csi_data returns data in valid format"""
        # Arrange: mock the CSI data extraction
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            self._activate_extraction(csi_extractor, mock_router_interface)

            # Act
            csi_data = csi_extractor.extract_csi_data()

            # Assert
            assert csi_data is not None
            assert isinstance(csi_data, np.ndarray)
            assert csi_data.dtype == np.complex128
            assert csi_data.shape == mock_csi_data.shape

    def test_extract_csi_data_requires_active_extraction(self, csi_extractor):
        """Test that extract_csi_data requires active extraction"""
        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_extract_csi_data_handles_timeout(self, csi_extractor, mock_router_interface):
        """Test that extract_csi_data handles extraction timeout"""
        # Arrange (explicit: execute_command needs a side_effect sequence,
        # so the shared success-path helper does not apply here)
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.execute_command.side_effect = [
            "CSI extraction started",
            Exception("Timeout")
        ]

        csi_extractor.start_extraction()

        # Act & Assert
        with pytest.raises(CSIExtractionError):
            csi_extractor.extract_csi_data()

    def test_convert_to_tensor_produces_correct_format(self, csi_extractor, mock_csi_data):
        """Test that convert_to_tensor produces correctly formatted tensor"""
        # Act
        tensor = csi_extractor.convert_to_tensor(mock_csi_data)

        # Assert
        assert isinstance(tensor, torch.Tensor)
        assert tensor.dtype == torch.float32
        assert tensor.shape[0] == mock_csi_data.shape[0] * 2  # Real and imaginary parts
        assert tensor.shape[1] == mock_csi_data.shape[1]

    def test_convert_to_tensor_handles_invalid_input(self, csi_extractor):
        """Test that convert_to_tensor handles invalid input"""
        # Arrange
        invalid_data = "not an array"

        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.convert_to_tensor(invalid_data)

    def test_get_extraction_stats_returns_valid_statistics(self, csi_extractor, mock_router_interface):
        """Test that get_extraction_stats returns valid statistics"""
        # Arrange
        self._activate_extraction(csi_extractor, mock_router_interface)

        # Act
        stats = csi_extractor.get_extraction_stats()

        # Assert
        assert stats is not None
        assert isinstance(stats, dict)
        assert 'samples_extracted' in stats
        assert 'extraction_rate' in stats
        assert 'buffer_utilization' in stats
        assert 'last_extraction_time' in stats

    def test_set_channel_configures_wifi_channel(self, csi_extractor, mock_router_interface):
        """Test that set_channel configures WiFi channel"""
        # Arrange
        new_channel = 11
        mock_router_interface.execute_command.return_value = f"Channel set to {new_channel}"

        # Act
        result = csi_extractor.set_channel(new_channel)

        # Assert
        assert result is True
        assert csi_extractor.channel == new_channel
        mock_router_interface.execute_command.assert_called()

    def test_set_channel_validates_channel_range(self, csi_extractor):
        """Test that set_channel validates channel range"""
        # Act & Assert
        with pytest.raises(ValueError):
            csi_extractor.set_channel(0)  # Invalid channel

        with pytest.raises(ValueError):
            csi_extractor.set_channel(15)  # Invalid channel

    def test_extractor_supports_context_manager(self, csi_extractor, mock_router_interface):
        """Test that CSI extractor supports context manager protocol"""
        # Arrange (explicit: __enter__ performs the start, so the helper would
        # start extraction prematurely)
        mock_router_interface.enable_monitor_mode.return_value = True
        mock_router_interface.disable_monitor_mode.return_value = True
        mock_router_interface.execute_command.return_value = "CSI extraction started"

        # Act
        with csi_extractor as extractor:
            # Assert
            assert extractor.is_extracting is True

        # Assert - extraction should be stopped after context
        assert csi_extractor.is_extracting is False

    def test_extractor_validates_configuration(self, mock_router_interface):
        """Test that CSI extractor validates configuration parameters"""
        # Arrange
        invalid_config = {
            'interface': '',  # Invalid interface
            'channel': 6,
            'bandwidth': 20
        }

        # Act & Assert
        with pytest.raises(ValueError):
            CSIExtractor(invalid_config, mock_router_interface)

    def test_parse_csi_output_processes_raw_data(self, csi_extractor):
        """Test that _parse_csi_output processes raw CSI data correctly"""
        # Arrange
        raw_output = "CSI_DATA: 1.5+0.5j,2.0-1.0j,0.8+1.2j"

        # Act
        parsed_data = csi_extractor._parse_csi_output(raw_output)

        # Assert
        assert parsed_data is not None
        assert isinstance(parsed_data, np.ndarray)
        assert parsed_data.dtype == np.complex128

    def test_buffer_management_handles_overflow(self, csi_extractor, mock_router_interface, mock_csi_data):
        """Test that buffer management handles overflow correctly"""
        # Arrange
        with patch.object(csi_extractor, '_parse_csi_output', return_value=mock_csi_data):
            self._activate_extraction(csi_extractor, mock_router_interface)

            # Fill buffer beyond capacity
            for _ in range(csi_extractor.buffer_size + 10):
                csi_extractor._add_to_buffer(mock_csi_data)

            # Act
            stats = csi_extractor.get_extraction_stats()

            # Assert
            assert stats['buffer_utilization'] <= 1.0  # Should not exceed 100%
|
||||
588
v1/tests/unit/test_csi_extractor_direct.py
Normal file
588
v1/tests/unit/test_csi_extractor_direct.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""Direct tests for CSI extractor avoiding import issues."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Add src to path for direct import
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
|
||||
|
||||
# Import the CSI extractor module directly
|
||||
from src.hardware.csi_extractor import (
|
||||
CSIExtractor,
|
||||
CSIParseError,
|
||||
CSIData,
|
||||
ESP32CSIParser,
|
||||
RouterCSIParser,
|
||||
CSIValidationError
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIExtractorDirect:
    """Test CSI extractor with direct imports."""

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def esp32_config(self):
        """ESP32 configuration for testing."""
        return {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0,
            'validation_enabled': True,
            'retry_attempts': 3
        }

    @pytest.fixture
    def router_config(self):
        """Router configuration for testing."""
        return {
            'hardware_type': 'router',
            'sampling_rate': 50,
            'buffer_size': 512,
            'timeout': 10.0,
            'validation_enabled': False,
            'retry_attempts': 1
        }

    @staticmethod
    def _make_csi_data(**overrides):
        """Build a valid CSIData payload; keyword overrides inject per-test defects.

        Centralizes the otherwise copy-pasted CSIData constructions so each
        negative validation test states only the single field it corrupts.
        """
        fields = {
            'timestamp': datetime.now(timezone.utc),
            'amplitude': np.random.rand(3, 56),
            'phase': np.random.rand(3, 56),
            'frequency': 2.4e9,
            'bandwidth': 20e6,
            'num_subcarriers': 56,
            'num_antennas': 3,
            'snr': 15.5,
            'metadata': {}
        }
        fields.update(overrides)
        return CSIData(**fields)

    @pytest.fixture
    def sample_csi_data(self):
        """Sample CSI data for testing."""
        return self._make_csi_data(metadata={'source': 'esp32', 'channel': 6})

    # Initialization tests
    def test_should_initialize_with_valid_config(self, esp32_config, mock_logger):
        """Should initialize CSI extractor with valid configuration."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

        assert extractor.config == esp32_config
        assert extractor.logger == mock_logger
        assert extractor.is_connected == False
        assert extractor.hardware_type == 'esp32'

    def test_should_create_esp32_parser(self, esp32_config, mock_logger):
        """Should create ESP32 parser when hardware_type is esp32."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

        assert isinstance(extractor.parser, ESP32CSIParser)

    def test_should_create_router_parser(self, router_config, mock_logger):
        """Should create router parser when hardware_type is router."""
        extractor = CSIExtractor(config=router_config, logger=mock_logger)

        assert isinstance(extractor.parser, RouterCSIParser)
        assert extractor.hardware_type == 'router'

    def test_should_raise_error_for_unsupported_hardware(self, mock_logger):
        """Should raise error for unsupported hardware type."""
        invalid_config = {
            'hardware_type': 'unsupported',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': 5.0
        }

        with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    # Configuration validation tests
    def test_config_validation_missing_fields(self, mock_logger):
        """Should validate required configuration fields."""
        invalid_config = {'invalid': 'config'}

        with pytest.raises(ValueError, match="Missing required configuration"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_negative_sampling_rate(self, mock_logger):
        """Should validate sampling_rate is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': -1,
            'buffer_size': 1024,
            'timeout': 5.0
        }

        with pytest.raises(ValueError, match="sampling_rate must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_zero_buffer_size(self, mock_logger):
        """Should validate buffer_size is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 0,
            'timeout': 5.0
        }

        with pytest.raises(ValueError, match="buffer_size must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    def test_config_validation_negative_timeout(self, mock_logger):
        """Should validate timeout is positive."""
        invalid_config = {
            'hardware_type': 'esp32',
            'sampling_rate': 100,
            'buffer_size': 1024,
            'timeout': -1.0
        }

        with pytest.raises(ValueError, match="timeout must be positive"):
            CSIExtractor(config=invalid_config, logger=mock_logger)

    # Connection tests
    @pytest.mark.asyncio
    async def test_should_establish_connection_successfully(self, esp32_config, mock_logger):
        """Should establish connection to hardware successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

        with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
            mock_connect.return_value = True

            result = await extractor.connect()

            assert result == True
            assert extractor.is_connected == True
            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_should_handle_connection_failure(self, esp32_config, mock_logger):
        """Should handle connection failure gracefully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

        with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
            mock_connect.side_effect = ConnectionError("Hardware not found")

            result = await extractor.connect()

            assert result == False
            assert extractor.is_connected == False
            extractor.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_should_disconnect_properly(self, esp32_config, mock_logger):
        """Should disconnect from hardware properly."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True

        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_disconnect:
            await extractor.disconnect()

            assert extractor.is_connected == False
            mock_disconnect.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
        """Should handle disconnect when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False

        with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
            await extractor.disconnect()

            # Should not call close when not connected
            mock_close.assert_not_called()
            assert extractor.is_connected == False

    # Data extraction tests
    @pytest.mark.asyncio
    async def test_should_extract_csi_data_successfully(self, esp32_config, mock_logger, sample_csi_data):
        """Should extract CSI data successfully from hardware."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True

        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
                mock_read.return_value = b"raw_csi_data"

                result = await extractor.extract_csi()

                assert result == sample_csi_data
                mock_read.assert_called_once()
                mock_parse.assert_called_once_with(b"raw_csi_data")

    @pytest.mark.asyncio
    async def test_should_handle_extraction_failure_when_not_connected(self, esp32_config, mock_logger):
        """Should handle extraction failure when not connected."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = False

        with pytest.raises(CSIParseError, match="Not connected to hardware"):
            await extractor.extract_csi()

    @pytest.mark.asyncio
    async def test_should_retry_on_temporary_failure(self, esp32_config, mock_logger, sample_csi_data):
        """Should retry extraction on temporary failure."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True

        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse') as mock_parse:
                # First two calls fail, third succeeds
                mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
                mock_parse.return_value = sample_csi_data

                result = await extractor.extract_csi()

                assert result == sample_csi_data
                assert mock_read.call_count == 3

    @pytest.mark.asyncio
    async def test_extract_with_validation_disabled(self, esp32_config, mock_logger, sample_csi_data):
        """Should skip validation when disabled."""
        esp32_config['validation_enabled'] = False
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True

        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
                with patch.object(extractor, 'validate_csi_data') as mock_validate:
                    mock_read.return_value = b"raw_data"

                    result = await extractor.extract_csi()

                    assert result == sample_csi_data
                    mock_validate.assert_not_called()

    @pytest.mark.asyncio
    async def test_extract_max_retries_exceeded(self, esp32_config, mock_logger):
        """Should raise error after max retries exceeded."""
        esp32_config['retry_attempts'] = 2
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True

        with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
            mock_read.side_effect = ConnectionError("Connection failed")

            with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
                await extractor.extract_csi()

            assert mock_read.call_count == 2

    # Validation tests
    def test_should_validate_csi_data_successfully(self, esp32_config, mock_logger, sample_csi_data):
        """Should validate CSI data successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

        result = extractor.validate_csi_data(sample_csi_data)

        assert result == True

    def test_validation_empty_amplitude(self, esp32_config, mock_logger):
        """Should raise validation error for empty amplitude."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(amplitude=np.array([]))

        with pytest.raises(CSIValidationError, match="Empty amplitude data"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_empty_phase(self, esp32_config, mock_logger):
        """Should raise validation error for empty phase."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(phase=np.array([]))

        with pytest.raises(CSIValidationError, match="Empty phase data"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_frequency(self, esp32_config, mock_logger):
        """Should raise validation error for invalid frequency."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(frequency=0)

        with pytest.raises(CSIValidationError, match="Invalid frequency"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_bandwidth(self, esp32_config, mock_logger):
        """Should raise validation error for invalid bandwidth."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(bandwidth=0)

        with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_subcarriers(self, esp32_config, mock_logger):
        """Should raise validation error for invalid subcarriers."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(num_subcarriers=0)

        with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_invalid_antennas(self, esp32_config, mock_logger):
        """Should raise validation error for invalid antennas."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(num_antennas=0)

        with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_snr_too_low(self, esp32_config, mock_logger):
        """Should raise validation error for SNR too low."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(snr=-100)

        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(invalid_data)

    def test_validation_snr_too_high(self, esp32_config, mock_logger):
        """Should raise validation error for SNR too high."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        invalid_data = self._make_csi_data(snr=100)

        with pytest.raises(CSIValidationError, match="Invalid SNR value"):
            extractor.validate_csi_data(invalid_data)

    # Streaming tests
    @pytest.mark.asyncio
    async def test_should_start_streaming_successfully(self, esp32_config, mock_logger, sample_csi_data):
        """Should start CSI data streaming successfully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()

        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.return_value = sample_csi_data

            # Start streaming with limited iterations to avoid infinite loop
            streaming_task = asyncio.create_task(extractor.start_streaming(callback))
            await asyncio.sleep(0.1)  # Let it run briefly
            extractor.stop_streaming()
            await streaming_task

            callback.assert_called()

    @pytest.mark.asyncio
    async def test_should_stop_streaming_gracefully(self, esp32_config, mock_logger):
        """Should stop streaming gracefully."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_streaming = True

        extractor.stop_streaming()

        assert extractor.is_streaming == False

    @pytest.mark.asyncio
    async def test_streaming_with_exception(self, esp32_config, mock_logger):
        """Should handle exceptions during streaming."""
        extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
        extractor.is_connected = True
        callback = Mock()

        with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
            mock_extract.side_effect = Exception("Extraction error")

            # Start streaming and let it handle the exception
            streaming_task = asyncio.create_task(extractor.start_streaming(callback))
            await asyncio.sleep(0.1)  # Let it run briefly and hit the exception
            await streaming_task

            # Should log error and stop streaming
            assert extractor.is_streaming == False
            extractor.logger.error.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestESP32CSIParserDirect:
    """Test ESP32 CSI parser with direct imports."""

    @pytest.fixture
    def parser(self):
        """ESP32 CSI parser instance under test."""
        return ESP32CSIParser()

    @pytest.fixture
    def raw_esp32_data(self):
        """Well-formed raw ESP32 CSI frame used by the happy-path test."""
        return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"

    def test_should_parse_valid_esp32_data(self, parser, raw_esp32_data):
        """Should parse valid ESP32 CSI data successfully."""
        parsed = parser.parse(raw_esp32_data)

        # Every field of the frame must survive the round trip.
        assert isinstance(parsed, CSIData)
        assert parsed.num_antennas == 3
        assert parsed.num_subcarriers == 56
        assert parsed.frequency == 2400000000  # 2.4 GHz
        assert parsed.bandwidth == 20000000  # 20 MHz
        assert parsed.snr == 15.5

    def test_should_handle_malformed_data(self, parser):
        """Should handle malformed ESP32 data gracefully."""
        with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
            parser.parse(b"INVALID_DATA")

    def test_should_handle_empty_data(self, parser):
        """Should handle empty data gracefully."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

    def test_parse_with_value_error(self, parser):
        """Should handle ValueError during parsing."""
        bad_timestamp_frame = b"CSI_DATA:invalid_timestamp,3,56,2400,20,15.5"

        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(bad_timestamp_frame)

    def test_parse_with_index_error(self, parser):
        """Should handle IndexError during parsing."""
        truncated_frame = b"CSI_DATA:1234567890"  # Missing fields

        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(truncated_frame)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterCSIParserDirect:
|
||||
"""Test Router CSI parser with direct imports."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create Router CSI parser for testing."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_should_parse_atheros_format(self, parser):
|
||||
"""Should parse Atheros CSI format successfully."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
with patch.object(parser, '_parse_atheros_format', return_value=Mock(spec=CSIData)) as mock_parse:
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
mock_parse.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
def test_should_handle_unknown_format(self, parser):
|
||||
"""Should handle unknown router format gracefully."""
|
||||
unknown_data = b"UNKNOWN_FORMAT:data"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Unknown router CSI format"):
|
||||
parser.parse(unknown_data)
|
||||
|
||||
def test_parse_atheros_format_directly(self, parser):
|
||||
"""Should parse Atheros format directly."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.metadata['source'] == 'atheros_router'
|
||||
|
||||
def test_should_handle_empty_data_router(self, parser):
|
||||
"""Should handle empty data gracefully."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
275
v1/tests/unit/test_csi_extractor_tdd.py
Normal file
275
v1/tests/unit/test_csi_extractor_tdd.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Test-Driven Development tests for CSI extractor using London School approach."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.hardware.csi_extractor import (
|
||||
CSIExtractor,
|
||||
CSIParseError,
|
||||
CSIData,
|
||||
ESP32CSIParser,
|
||||
RouterCSIParser,
|
||||
CSIValidationError
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractor:
|
||||
"""Test CSI extractor using London School TDD - focus on interactions and behavior."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Mock configuration for CSI extractor."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def csi_extractor(self, mock_config, mock_logger):
|
||||
"""Create CSI extractor instance for testing."""
|
||||
return CSIExtractor(config=mock_config, logger=mock_logger)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
def test_should_initialize_with_valid_config(self, mock_config, mock_logger):
|
||||
"""Should initialize CSI extractor with valid configuration."""
|
||||
extractor = CSIExtractor(config=mock_config, logger=mock_logger)
|
||||
|
||||
assert extractor.config == mock_config
|
||||
assert extractor.logger == mock_logger
|
||||
assert extractor.is_connected == False
|
||||
assert extractor.hardware_type == 'esp32'
|
||||
|
||||
def test_should_raise_error_with_invalid_config(self, mock_logger):
|
||||
"""Should raise error when initialized with invalid configuration."""
|
||||
invalid_config = {'invalid': 'config'}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_should_create_appropriate_parser(self, mock_config, mock_logger):
|
||||
"""Should create appropriate parser based on hardware type."""
|
||||
extractor = CSIExtractor(config=mock_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, ESP32CSIParser)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_establish_connection_successfully(self, csi_extractor):
|
||||
"""Should establish connection to hardware successfully."""
|
||||
with patch.object(csi_extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.return_value = True
|
||||
|
||||
result = await csi_extractor.connect()
|
||||
|
||||
assert result == True
|
||||
assert csi_extractor.is_connected == True
|
||||
mock_connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_connection_failure(self, csi_extractor):
|
||||
"""Should handle connection failure gracefully."""
|
||||
with patch.object(csi_extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_connect:
|
||||
mock_connect.side_effect = ConnectionError("Hardware not found")
|
||||
|
||||
result = await csi_extractor.connect()
|
||||
|
||||
assert result == False
|
||||
assert csi_extractor.is_connected == False
|
||||
csi_extractor.logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_disconnect_properly(self, csi_extractor):
|
||||
"""Should disconnect from hardware properly."""
|
||||
csi_extractor.is_connected = True
|
||||
|
||||
with patch.object(csi_extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_disconnect:
|
||||
await csi_extractor.disconnect()
|
||||
|
||||
assert csi_extractor.is_connected == False
|
||||
mock_disconnect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_extract_csi_data_successfully(self, csi_extractor, sample_csi_data):
|
||||
"""Should extract CSI data successfully from hardware."""
|
||||
csi_extractor.is_connected = True
|
||||
|
||||
with patch.object(csi_extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(csi_extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
|
||||
mock_read.return_value = b"raw_csi_data"
|
||||
|
||||
result = await csi_extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_read.assert_called_once()
|
||||
mock_parse.assert_called_once_with(b"raw_csi_data")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_handle_extraction_failure_when_not_connected(self, csi_extractor):
|
||||
"""Should handle extraction failure when not connected."""
|
||||
csi_extractor.is_connected = False
|
||||
|
||||
with pytest.raises(CSIParseError, match="Not connected to hardware"):
|
||||
await csi_extractor.extract_csi()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_retry_on_temporary_failure(self, csi_extractor, sample_csi_data):
|
||||
"""Should retry extraction on temporary failure."""
|
||||
csi_extractor.is_connected = True
|
||||
|
||||
with patch.object(csi_extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(csi_extractor.parser, 'parse') as mock_parse:
|
||||
# First two calls fail, third succeeds
|
||||
mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
|
||||
mock_parse.return_value = sample_csi_data
|
||||
|
||||
result = await csi_extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
assert mock_read.call_count == 3
|
||||
|
||||
def test_should_validate_csi_data_successfully(self, csi_extractor, sample_csi_data):
|
||||
"""Should validate CSI data successfully."""
|
||||
result = csi_extractor.validate_csi_data(sample_csi_data)
|
||||
|
||||
assert result == True
|
||||
|
||||
def test_should_reject_invalid_csi_data(self, csi_extractor):
|
||||
"""Should reject CSI data with invalid structure."""
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]), # Empty array
|
||||
phase=np.array([]),
|
||||
frequency=0, # Invalid frequency
|
||||
bandwidth=0,
|
||||
num_subcarriers=0,
|
||||
num_antennas=0,
|
||||
snr=-100, # Invalid SNR
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError):
|
||||
csi_extractor.validate_csi_data(invalid_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_start_streaming_successfully(self, csi_extractor, sample_csi_data):
|
||||
"""Should start CSI data streaming successfully."""
|
||||
csi_extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(csi_extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.return_value = sample_csi_data
|
||||
|
||||
# Start streaming with limited iterations to avoid infinite loop
|
||||
streaming_task = asyncio.create_task(csi_extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly
|
||||
csi_extractor.stop_streaming()
|
||||
await streaming_task
|
||||
|
||||
callback.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_stop_streaming_gracefully(self, csi_extractor):
|
||||
"""Should stop streaming gracefully."""
|
||||
csi_extractor.is_streaming = True
|
||||
|
||||
csi_extractor.stop_streaming()
|
||||
|
||||
assert csi_extractor.is_streaming == False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestESP32CSIParser:
|
||||
"""Test ESP32 CSI parser using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create ESP32 CSI parser for testing."""
|
||||
return ESP32CSIParser()
|
||||
|
||||
@pytest.fixture
|
||||
def raw_esp32_data(self):
|
||||
"""Sample raw ESP32 CSI data."""
|
||||
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
|
||||
def test_should_parse_valid_esp32_data(self, parser, raw_esp32_data):
|
||||
"""Should parse valid ESP32 CSI data successfully."""
|
||||
result = parser.parse(raw_esp32_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.num_antennas == 3
|
||||
assert result.num_subcarriers == 56
|
||||
assert result.frequency == 2400000000 # 2.4 GHz
|
||||
assert result.bandwidth == 20000000 # 20 MHz
|
||||
assert result.snr == 15.5
|
||||
|
||||
def test_should_handle_malformed_data(self, parser):
|
||||
"""Should handle malformed ESP32 data gracefully."""
|
||||
malformed_data = b"INVALID_DATA"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
|
||||
parser.parse(malformed_data)
|
||||
|
||||
def test_should_handle_empty_data(self, parser):
|
||||
"""Should handle empty data gracefully."""
|
||||
with pytest.raises(CSIParseError, match="Empty data received"):
|
||||
parser.parse(b"")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterCSIParser:
|
||||
"""Test Router CSI parser using London School TDD."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create Router CSI parser for testing."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_should_parse_atheros_format(self, parser):
|
||||
"""Should parse Atheros CSI format successfully."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
with patch.object(parser, '_parse_atheros_format', return_value=Mock(spec=CSIData)) as mock_parse:
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
mock_parse.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
def test_should_handle_unknown_format(self, parser):
|
||||
"""Should handle unknown router format gracefully."""
|
||||
unknown_data = b"UNKNOWN_FORMAT:data"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Unknown router CSI format"):
|
||||
parser.parse(unknown_data)
|
||||
386
v1/tests/unit/test_csi_extractor_tdd_complete.py
Normal file
386
v1/tests/unit/test_csi_extractor_tdd_complete.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""Complete TDD tests for CSI extractor with 100% coverage."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.hardware.csi_extractor import (
|
||||
CSIExtractor,
|
||||
CSIParseError,
|
||||
CSIData,
|
||||
ESP32CSIParser,
|
||||
RouterCSIParser,
|
||||
CSIValidationError
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractorComplete:
|
||||
"""Complete CSI extractor tests for 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def esp32_config(self):
|
||||
"""ESP32 configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_config(self):
|
||||
"""Router configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'router',
|
||||
'sampling_rate': 50,
|
||||
'buffer_size': 512,
|
||||
'timeout': 10.0,
|
||||
'validation_enabled': False,
|
||||
'retry_attempts': 1
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
def test_should_create_router_parser(self, router_config, mock_logger):
|
||||
"""Should create router parser when hardware_type is router."""
|
||||
extractor = CSIExtractor(config=router_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, RouterCSIParser)
|
||||
assert extractor.hardware_type == 'router'
|
||||
|
||||
def test_should_raise_error_for_unsupported_hardware(self, mock_logger):
|
||||
"""Should raise error for unsupported hardware type."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'unsupported',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_negative_sampling_rate(self, mock_logger):
|
||||
"""Should validate sampling_rate is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': -1,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="sampling_rate must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_zero_buffer_size(self, mock_logger):
|
||||
"""Should validate buffer_size is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 0,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="buffer_size must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_config_validation_negative_timeout(self, mock_logger):
|
||||
"""Should validate timeout is positive."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': -1.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should handle disconnect when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
# Should not call close when not connected
|
||||
mock_close.assert_not_called()
|
||||
assert extractor.is_connected == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_with_validation_disabled(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should skip validation when disabled."""
|
||||
esp32_config['validation_enabled'] = False
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data) as mock_parse:
|
||||
with patch.object(extractor, 'validate_csi_data') as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_max_retries_exceeded(self, esp32_config, mock_logger):
|
||||
"""Should raise error after max retries exceeded."""
|
||||
esp32_config['retry_attempts'] = 2
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
mock_read.side_effect = ConnectionError("Connection failed")
|
||||
|
||||
with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
assert mock_read.call_count == 2
|
||||
|
||||
def test_validation_empty_amplitude(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for empty amplitude."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty amplitude data"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_empty_phase(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for empty phase."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.array([]),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty phase data"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_frequency(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid frequency."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=0,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid frequency"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_bandwidth(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid bandwidth."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=0,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_subcarriers(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid subcarriers."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=0,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_invalid_antennas(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for invalid antennas."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=0,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_snr_too_low(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for SNR too low."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=-100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
def test_validation_snr_too_high(self, esp32_config, mock_logger):
|
||||
"""Should raise validation error for SNR too high."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
invalid_data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=100,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid SNR value"):
|
||||
extractor.validate_csi_data(invalid_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_with_exception(self, esp32_config, mock_logger):
|
||||
"""Should handle exceptions during streaming."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
callback = Mock()
|
||||
|
||||
with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
|
||||
mock_extract.side_effect = Exception("Extraction error")
|
||||
|
||||
# Start streaming and let it handle the exception
|
||||
streaming_task = asyncio.create_task(extractor.start_streaming(callback))
|
||||
await asyncio.sleep(0.1) # Let it run briefly and hit the exception
|
||||
await streaming_task
|
||||
|
||||
# Should log error and stop streaming
|
||||
assert extractor.is_streaming == False
|
||||
extractor.logger.error.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestESP32CSIParserComplete:
|
||||
"""Complete ESP32 CSI parser tests for 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create ESP32 CSI parser for testing."""
|
||||
return ESP32CSIParser()
|
||||
|
||||
def test_parse_with_value_error(self, parser):
|
||||
"""Should handle ValueError during parsing."""
|
||||
invalid_data = b"CSI_DATA:invalid_timestamp,3,56,2400,20,15.5"
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(invalid_data)
|
||||
|
||||
def test_parse_with_index_error(self, parser):
|
||||
"""Should handle IndexError during parsing."""
|
||||
invalid_data = b"CSI_DATA:1234567890" # Missing fields
|
||||
|
||||
with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
|
||||
parser.parse(invalid_data)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestRouterCSIParserComplete:
|
||||
"""Complete Router CSI parser tests for 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
"""Create Router CSI parser for testing."""
|
||||
return RouterCSIParser()
|
||||
|
||||
def test_parse_atheros_format_directly(self, parser):
|
||||
"""Should parse Atheros format directly."""
|
||||
raw_data = b"ATHEROS_CSI:mock_data"
|
||||
|
||||
result = parser.parse(raw_data)
|
||||
|
||||
assert isinstance(result, CSIData)
|
||||
assert result.metadata['source'] == 'atheros_router'
|
||||
87
v1/tests/unit/test_csi_processor.py
Normal file
87
v1/tests/unit/test_csi_processor.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
|
||||
|
||||
class TestCSIProcessor:
|
||||
"""Test suite for CSI processor following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_csi_data(self):
|
||||
"""Generate synthetic CSI data for testing"""
|
||||
# Simple raw CSI data array for testing
|
||||
return np.random.uniform(0.1, 2.0, (3, 56, 100))
|
||||
|
||||
@pytest.fixture
|
||||
def csi_processor(self):
|
||||
"""Create CSI processor instance for testing"""
|
||||
return CSIProcessor()
|
||||
|
||||
def test_process_csi_data_returns_normalized_output(self, csi_processor, mock_csi_data):
|
||||
"""Test that CSI processing returns properly normalized output"""
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == mock_csi_data.shape
|
||||
|
||||
# Verify normalization - mean should be close to 0, std close to 1
|
||||
assert abs(result.mean()) < 0.1
|
||||
assert abs(result.std() - 1.0) < 0.1
|
||||
|
||||
def test_process_csi_data_handles_invalid_input(self, csi_processor):
|
||||
"""Test that CSI processor handles invalid input gracefully"""
|
||||
# Arrange
|
||||
invalid_data = np.array([])
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Raw CSI data cannot be empty"):
|
||||
csi_processor.process_raw_csi(invalid_data)
|
||||
|
||||
def test_process_csi_data_removes_nan_values(self, csi_processor, mock_csi_data):
|
||||
"""Test that CSI processor removes NaN values from input"""
|
||||
# Arrange
|
||||
mock_csi_data[0, 0, 0] = np.nan
|
||||
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
|
||||
# Assert
|
||||
assert not np.isnan(result).any()
|
||||
|
||||
def test_process_csi_data_applies_temporal_filtering(self, csi_processor, mock_csi_data):
|
||||
"""Test that temporal filtering is applied to CSI data"""
|
||||
# Arrange - Add noise to make filtering effect visible
|
||||
noisy_data = mock_csi_data + np.random.normal(0, 0.1, mock_csi_data.shape)
|
||||
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(noisy_data)
|
||||
|
||||
# Assert - Result should be normalized
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.shape == noisy_data.shape
|
||||
|
||||
def test_process_csi_data_preserves_metadata(self, csi_processor, mock_csi_data):
|
||||
"""Test that metadata is preserved during processing"""
|
||||
# Act
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
|
||||
# Assert - For now, just verify processing works
|
||||
assert result is not None
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
def test_process_csi_data_performance_requirement(self, csi_processor, mock_csi_data):
|
||||
"""Test that CSI processing meets performance requirements (<10ms)"""
|
||||
import time
|
||||
|
||||
# Act
|
||||
start_time = time.time()
|
||||
result = csi_processor.process_raw_csi(mock_csi_data)
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# Assert
|
||||
assert processing_time < 0.01 # <10ms requirement
|
||||
assert result is not None
|
||||
479
v1/tests/unit/test_csi_processor_tdd.py
Normal file
479
v1/tests/unit/test_csi_processor_tdd.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""TDD tests for CSI processor following London School approach."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
from typing import Dict, List, Any
|
||||
|
||||
# Import the CSI processor module directly
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'csi_processor',
|
||||
'/workspaces/wifi-densepose/src/core/csi_processor.py'
|
||||
)
|
||||
csi_processor_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# Import CSI extractor for dependencies
|
||||
csi_spec = importlib.util.spec_from_file_location(
|
||||
'csi_extractor',
|
||||
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
|
||||
)
|
||||
csi_module = importlib.util.module_from_spec(csi_spec)
|
||||
csi_spec.loader.exec_module(csi_module)
|
||||
|
||||
# Make dependencies available and load the processor
|
||||
csi_processor_module.CSIData = csi_module.CSIData
|
||||
spec.loader.exec_module(csi_processor_module)
|
||||
|
||||
# Get classes from modules
|
||||
CSIProcessor = csi_processor_module.CSIProcessor
|
||||
CSIProcessingError = csi_processor_module.CSIProcessingError
|
||||
HumanDetectionResult = csi_processor_module.HumanDetectionResult
|
||||
CSIFeatures = csi_processor_module.CSIFeatures
|
||||
CSIData = csi_module.CSIData
|
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIProcessor:
    """Test CSI processor using London School TDD."""

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def processor_config(self):
        """CSI processor configuration for testing."""
        return {
            'sampling_rate': 100,
            'window_size': 256,
            'overlap': 0.5,
            'noise_threshold': -60.0,
            'human_detection_threshold': 0.7,
            'smoothing_factor': 0.8,
            'max_history_size': 1000,
            'enable_preprocessing': True,
            'enable_feature_extraction': True,
            'enable_human_detection': True
        }

    @pytest.fixture
    def csi_processor(self, processor_config, mock_logger):
        """Create CSI processor for testing."""
        return CSIProcessor(config=processor_config, logger=mock_logger)

    @pytest.fixture
    def sample_csi_data(self):
        """Sample CSI data for testing."""
        return CSIData(
            timestamp=datetime.now(timezone.utc),
            amplitude=np.random.rand(3, 56) + 1.0,  # Ensure positive amplitude
            phase=np.random.uniform(-np.pi, np.pi, (3, 56)),
            frequency=2.4e9,
            bandwidth=20e6,
            num_subcarriers=56,
            num_antennas=3,
            snr=15.5,
            metadata={'source': 'test'}
        )

    @pytest.fixture
    def sample_features(self):
        """Sample CSI features for testing."""
        return CSIFeatures(
            amplitude_mean=np.random.rand(56),
            amplitude_variance=np.random.rand(56),
            phase_difference=np.random.rand(56),
            correlation_matrix=np.random.rand(3, 3),
            doppler_shift=np.random.rand(10),
            power_spectral_density=np.random.rand(128),
            timestamp=datetime.now(timezone.utc),
            metadata={'processing_params': {}}
        )

    # Initialization tests

    def test_should_initialize_with_valid_config(self, processor_config, mock_logger):
        """Should initialize CSI processor with valid configuration."""
        processor = CSIProcessor(config=processor_config, logger=mock_logger)

        assert processor.config == processor_config
        assert processor.logger == mock_logger
        assert processor.sampling_rate == 100
        assert processor.window_size == 256
        assert processor.overlap == 0.5
        assert processor.noise_threshold == -60.0
        assert processor.human_detection_threshold == 0.7
        assert processor.smoothing_factor == 0.8
        assert processor.max_history_size == 1000
        assert len(processor.csi_history) == 0

    def test_should_raise_error_with_invalid_config(self, mock_logger):
        """Should raise error when initialized with invalid configuration."""
        invalid_config = {'invalid': 'config'}

        with pytest.raises(ValueError, match="Missing required configuration"):
            CSIProcessor(config=invalid_config, logger=mock_logger)

    def test_should_validate_required_fields(self, mock_logger):
        """Should validate all required configuration fields."""
        required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
        base_config = {
            'sampling_rate': 100,
            'window_size': 256,
            'overlap': 0.5,
            'noise_threshold': -60.0
        }

        # Drop each required key in turn and expect a validation failure.
        for field in required_fields:
            config = base_config.copy()
            del config[field]

            with pytest.raises(ValueError, match="Missing required configuration"):
                CSIProcessor(config=config, logger=mock_logger)

    def test_should_use_default_values(self, mock_logger):
        """Should use default values for optional parameters."""
        minimal_config = {
            'sampling_rate': 100,
            'window_size': 256,
            'overlap': 0.5,
            'noise_threshold': -60.0
        }

        processor = CSIProcessor(config=minimal_config, logger=mock_logger)

        assert processor.human_detection_threshold == 0.8  # default
        assert processor.smoothing_factor == 0.9  # default
        assert processor.max_history_size == 500  # default

    def test_should_initialize_without_logger(self, processor_config):
        """Should initialize without logger provided."""
        processor = CSIProcessor(config=processor_config)

        assert processor.logger is not None  # Should create default logger

    # Preprocessing tests

    def test_should_preprocess_csi_data_successfully(self, csi_processor, sample_csi_data):
        """Should preprocess CSI data successfully."""
        with patch.object(csi_processor, '_remove_noise') as mock_noise, \
             patch.object(csi_processor, '_apply_windowing') as mock_window, \
             patch.object(csi_processor, '_normalize_amplitude') as mock_normalize:
            mock_noise.return_value = sample_csi_data
            mock_window.return_value = sample_csi_data
            mock_normalize.return_value = sample_csi_data

            result = csi_processor.preprocess_csi_data(sample_csi_data)

            assert result == sample_csi_data
            mock_noise.assert_called_once_with(sample_csi_data)
            mock_window.assert_called_once()
            mock_normalize.assert_called_once()

    def test_should_skip_preprocessing_when_disabled(self, processor_config, mock_logger, sample_csi_data):
        """Should skip preprocessing when disabled."""
        processor_config['enable_preprocessing'] = False
        processor = CSIProcessor(config=processor_config, logger=mock_logger)

        result = processor.preprocess_csi_data(sample_csi_data)

        # Data must pass through untouched when the stage is disabled.
        assert result == sample_csi_data

    def test_should_handle_preprocessing_error(self, csi_processor, sample_csi_data):
        """Should handle preprocessing errors gracefully."""
        with patch.object(csi_processor, '_remove_noise') as mock_noise:
            mock_noise.side_effect = Exception("Preprocessing error")

            with pytest.raises(CSIProcessingError, match="Failed to preprocess CSI data"):
                csi_processor.preprocess_csi_data(sample_csi_data)

    # Feature extraction tests

    def test_should_extract_features_successfully(self, csi_processor, sample_csi_data, sample_features):
        """Should extract features from CSI data successfully."""
        with patch.object(csi_processor, '_extract_amplitude_features') as mock_amp, \
             patch.object(csi_processor, '_extract_phase_features') as mock_phase, \
             patch.object(csi_processor, '_extract_correlation_features') as mock_corr, \
             patch.object(csi_processor, '_extract_doppler_features') as mock_doppler:
            mock_amp.return_value = (sample_features.amplitude_mean, sample_features.amplitude_variance)
            mock_phase.return_value = sample_features.phase_difference
            mock_corr.return_value = sample_features.correlation_matrix
            mock_doppler.return_value = (sample_features.doppler_shift, sample_features.power_spectral_density)

            result = csi_processor.extract_features(sample_csi_data)

            assert isinstance(result, CSIFeatures)
            assert np.array_equal(result.amplitude_mean, sample_features.amplitude_mean)
            assert np.array_equal(result.amplitude_variance, sample_features.amplitude_variance)
            mock_amp.assert_called_once()
            mock_phase.assert_called_once()
            mock_corr.assert_called_once()
            mock_doppler.assert_called_once()

    def test_should_skip_feature_extraction_when_disabled(self, processor_config, mock_logger, sample_csi_data):
        """Should skip feature extraction when disabled."""
        processor_config['enable_feature_extraction'] = False
        processor = CSIProcessor(config=processor_config, logger=mock_logger)

        result = processor.extract_features(sample_csi_data)

        assert result is None

    def test_should_handle_feature_extraction_error(self, csi_processor, sample_csi_data):
        """Should handle feature extraction errors gracefully."""
        with patch.object(csi_processor, '_extract_amplitude_features') as mock_amp:
            mock_amp.side_effect = Exception("Feature extraction error")

            with pytest.raises(CSIProcessingError, match="Failed to extract features"):
                csi_processor.extract_features(sample_csi_data)

    # Human detection tests

    def test_should_detect_human_presence_successfully(self, csi_processor, sample_features):
        """Should detect human presence successfully."""
        with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion, \
             patch.object(csi_processor, '_calculate_detection_confidence') as mock_confidence, \
             patch.object(csi_processor, '_apply_temporal_smoothing') as mock_smooth:
            mock_motion.return_value = 0.9
            mock_confidence.return_value = 0.85
            mock_smooth.return_value = 0.88

            result = csi_processor.detect_human_presence(sample_features)

            assert isinstance(result, HumanDetectionResult)
            # Smoothed confidence 0.88 exceeds the 0.7 threshold from the fixture.
            assert result.human_detected is True
            assert result.confidence == 0.88
            assert result.motion_score == 0.9
            mock_motion.assert_called_once()
            mock_confidence.assert_called_once()
            mock_smooth.assert_called_once()

    def test_should_detect_no_human_presence(self, csi_processor, sample_features):
        """Should detect no human presence when confidence is low."""
        with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion, \
             patch.object(csi_processor, '_calculate_detection_confidence') as mock_confidence, \
             patch.object(csi_processor, '_apply_temporal_smoothing') as mock_smooth:
            mock_motion.return_value = 0.3
            mock_confidence.return_value = 0.2
            mock_smooth.return_value = 0.25

            result = csi_processor.detect_human_presence(sample_features)

            assert result.human_detected is False
            assert result.confidence == 0.25
            assert result.motion_score == 0.3

    def test_should_skip_human_detection_when_disabled(self, processor_config, mock_logger, sample_features):
        """Should skip human detection when disabled."""
        processor_config['enable_human_detection'] = False
        processor = CSIProcessor(config=processor_config, logger=mock_logger)

        result = processor.detect_human_presence(sample_features)

        assert result is None

    def test_should_handle_human_detection_error(self, csi_processor, sample_features):
        """Should handle human detection errors gracefully."""
        with patch.object(csi_processor, '_analyze_motion_patterns') as mock_motion:
            mock_motion.side_effect = Exception("Detection error")

            with pytest.raises(CSIProcessingError, match="Failed to detect human presence"):
                csi_processor.detect_human_presence(sample_features)

    # Processing pipeline tests

    @pytest.mark.asyncio
    async def test_should_process_csi_data_pipeline_successfully(self, csi_processor, sample_csi_data, sample_features):
        """Should process CSI data through full pipeline successfully."""
        expected_detection = HumanDetectionResult(
            human_detected=True,
            confidence=0.85,
            motion_score=0.9,
            timestamp=datetime.now(timezone.utc),
            features=sample_features,
            metadata={}
        )

        with patch.object(csi_processor, 'preprocess_csi_data', return_value=sample_csi_data) as mock_preprocess, \
             patch.object(csi_processor, 'extract_features', return_value=sample_features) as mock_features, \
             patch.object(csi_processor, 'detect_human_presence', return_value=expected_detection) as mock_detect:

            result = await csi_processor.process_csi_data(sample_csi_data)

            assert result == expected_detection
            mock_preprocess.assert_called_once_with(sample_csi_data)
            mock_features.assert_called_once_with(sample_csi_data)
            mock_detect.assert_called_once_with(sample_features)

    @pytest.mark.asyncio
    async def test_should_handle_pipeline_processing_error(self, csi_processor, sample_csi_data):
        """Should handle pipeline processing errors gracefully."""
        with patch.object(csi_processor, 'preprocess_csi_data') as mock_preprocess:
            mock_preprocess.side_effect = CSIProcessingError("Pipeline error")

            with pytest.raises(CSIProcessingError):
                await csi_processor.process_csi_data(sample_csi_data)

    # History management tests

    def test_should_add_csi_data_to_history(self, csi_processor, sample_csi_data):
        """Should add CSI data to history successfully."""
        csi_processor.add_to_history(sample_csi_data)

        assert len(csi_processor.csi_history) == 1
        assert csi_processor.csi_history[0] == sample_csi_data

    def test_should_maintain_history_size_limit(self, processor_config, mock_logger):
        """Should maintain history size within limits."""
        processor_config['max_history_size'] = 2
        processor = CSIProcessor(config=processor_config, logger=mock_logger)

        # Add 3 items to history of size 2
        for i in range(3):
            csi_data = CSIData(
                timestamp=datetime.now(timezone.utc),
                amplitude=np.random.rand(3, 56),
                phase=np.random.rand(3, 56),
                frequency=2.4e9,
                bandwidth=20e6,
                num_subcarriers=56,
                num_antennas=3,
                snr=15.5,
                metadata={'index': i}
            )
            processor.add_to_history(csi_data)

        assert len(processor.csi_history) == 2
        assert processor.csi_history[0].metadata['index'] == 1  # First item removed
        assert processor.csi_history[1].metadata['index'] == 2

    def test_should_clear_history(self, csi_processor, sample_csi_data):
        """Should clear history successfully."""
        csi_processor.add_to_history(sample_csi_data)
        assert len(csi_processor.csi_history) > 0

        csi_processor.clear_history()

        assert len(csi_processor.csi_history) == 0

    def test_should_get_recent_history(self, csi_processor):
        """Should get recent history entries."""
        # Add 5 items to history
        for i in range(5):
            csi_data = CSIData(
                timestamp=datetime.now(timezone.utc),
                amplitude=np.random.rand(3, 56),
                phase=np.random.rand(3, 56),
                frequency=2.4e9,
                bandwidth=20e6,
                num_subcarriers=56,
                num_antennas=3,
                snr=15.5,
                metadata={'index': i}
            )
            csi_processor.add_to_history(csi_data)

        recent = csi_processor.get_recent_history(3)

        assert len(recent) == 3
        assert recent[0].metadata['index'] == 2  # Most recent first
        assert recent[1].metadata['index'] == 3
        assert recent[2].metadata['index'] == 4

    # Statistics and monitoring tests

    def test_should_get_processing_statistics(self, csi_processor):
        """Should get processing statistics."""
        # Simulate some processing
        csi_processor._total_processed = 100
        csi_processor._processing_errors = 5
        csi_processor._human_detections = 25

        stats = csi_processor.get_processing_statistics()

        assert isinstance(stats, dict)
        assert stats['total_processed'] == 100
        assert stats['processing_errors'] == 5
        assert stats['human_detections'] == 25
        assert stats['error_rate'] == 0.05
        assert stats['detection_rate'] == 0.25

    def test_should_reset_statistics(self, csi_processor):
        """Should reset processing statistics."""
        csi_processor._total_processed = 100
        csi_processor._processing_errors = 5
        csi_processor._human_detections = 25

        csi_processor.reset_statistics()

        assert csi_processor._total_processed == 0
        assert csi_processor._processing_errors == 0
        assert csi_processor._human_detections == 0
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestCSIFeatures:
    """Test CSI features data structure."""

    def test_should_create_csi_features(self):
        """Should create CSI features successfully."""
        features = CSIFeatures(
            amplitude_mean=np.random.rand(56),
            amplitude_variance=np.random.rand(56),
            phase_difference=np.random.rand(56),
            correlation_matrix=np.random.rand(3, 3),
            doppler_shift=np.random.rand(10),
            power_spectral_density=np.random.rand(128),
            timestamp=datetime.now(timezone.utc),
            metadata={'test': 'data'}
        )

        # Every array attribute must keep the shape it was constructed with.
        expected_shapes = {
            'amplitude_mean': (56,),
            'amplitude_variance': (56,),
            'phase_difference': (56,),
            'correlation_matrix': (3, 3),
            'doppler_shift': (10,),
            'power_spectral_density': (128,),
        }
        for attribute, shape in expected_shapes.items():
            assert getattr(features, attribute).shape == shape

        assert isinstance(features.timestamp, datetime)
        assert features.metadata['test'] == 'data'
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestHumanDetectionResult:
    """Test human detection result data structure."""

    @pytest.fixture
    def sample_features(self):
        """Sample features for testing."""
        return CSIFeatures(
            amplitude_mean=np.random.rand(56),
            amplitude_variance=np.random.rand(56),
            phase_difference=np.random.rand(56),
            correlation_matrix=np.random.rand(3, 3),
            doppler_shift=np.random.rand(10),
            power_spectral_density=np.random.rand(128),
            timestamp=datetime.now(timezone.utc),
            metadata={}
        )

    def test_should_create_detection_result(self, sample_features):
        """Should create human detection result successfully."""
        result = HumanDetectionResult(
            human_detected=True,
            confidence=0.85,
            motion_score=0.92,
            timestamp=datetime.now(timezone.utc),
            features=sample_features,
            metadata={'test': 'data'}
        )

        # Constructor arguments must round-trip unchanged through the
        # data structure.
        assert result.human_detected is True
        assert result.confidence == 0.85
        assert result.motion_score == 0.92
        assert isinstance(result.timestamp, datetime)
        assert result.features == sample_features
        assert result.metadata['test'] == 'data'
599
v1/tests/unit/test_csi_standalone.py
Normal file
599
v1/tests/unit/test_csi_standalone.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""Standalone tests for CSI extractor module."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
|
||||
# Import the module directly to avoid circular imports
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'csi_extractor',
|
||||
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
|
||||
)
|
||||
csi_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(csi_module)
|
||||
|
||||
# Get classes from the module
|
||||
CSIExtractor = csi_module.CSIExtractor
|
||||
CSIParseError = csi_module.CSIParseError
|
||||
CSIData = csi_module.CSIData
|
||||
ESP32CSIParser = csi_module.ESP32CSIParser
|
||||
RouterCSIParser = csi_module.RouterCSIParser
|
||||
CSIValidationError = csi_module.CSIValidationError
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.tdd
|
||||
@pytest.mark.london
|
||||
class TestCSIExtractorStandalone:
|
||||
"""Standalone tests for CSI extractor with 100% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(self):
|
||||
"""Mock logger for testing."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def esp32_config(self):
|
||||
"""ESP32 configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0,
|
||||
'validation_enabled': True,
|
||||
'retry_attempts': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_config(self):
|
||||
"""Router configuration for testing."""
|
||||
return {
|
||||
'hardware_type': 'router',
|
||||
'sampling_rate': 50,
|
||||
'buffer_size': 512,
|
||||
'timeout': 10.0,
|
||||
'validation_enabled': False,
|
||||
'retry_attempts': 1
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Sample CSI data for testing."""
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={'source': 'esp32', 'channel': 6}
|
||||
)
|
||||
|
||||
# Test all initialization paths
|
||||
def test_init_esp32_config(self, esp32_config, mock_logger):
|
||||
"""Should initialize with ESP32 configuration."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
assert extractor.config == esp32_config
|
||||
assert extractor.logger == mock_logger
|
||||
assert extractor.is_connected == False
|
||||
assert extractor.hardware_type == 'esp32'
|
||||
assert isinstance(extractor.parser, ESP32CSIParser)
|
||||
|
||||
def test_init_router_config(self, router_config, mock_logger):
|
||||
"""Should initialize with router configuration."""
|
||||
extractor = CSIExtractor(config=router_config, logger=mock_logger)
|
||||
|
||||
assert isinstance(extractor.parser, RouterCSIParser)
|
||||
assert extractor.hardware_type == 'router'
|
||||
|
||||
def test_init_unsupported_hardware(self, mock_logger):
|
||||
"""Should raise error for unsupported hardware type."""
|
||||
invalid_config = {
|
||||
'hardware_type': 'unsupported',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported hardware type: unsupported"):
|
||||
CSIExtractor(config=invalid_config, logger=mock_logger)
|
||||
|
||||
def test_init_without_logger(self, esp32_config):
|
||||
"""Should initialize without logger."""
|
||||
extractor = CSIExtractor(config=esp32_config)
|
||||
|
||||
assert extractor.logger is not None # Should create default logger
|
||||
|
||||
# Test all validation paths
|
||||
def test_validation_missing_fields(self, mock_logger):
|
||||
"""Should validate missing required fields."""
|
||||
for missing_field in ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']:
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
del config[missing_field]
|
||||
|
||||
with pytest.raises(ValueError, match="Missing required configuration"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
def test_validation_negative_sampling_rate(self, mock_logger):
|
||||
"""Should validate sampling_rate is positive."""
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': -1,
|
||||
'buffer_size': 1024,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="sampling_rate must be positive"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
def test_validation_zero_buffer_size(self, mock_logger):
|
||||
"""Should validate buffer_size is positive."""
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 0,
|
||||
'timeout': 5.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="buffer_size must be positive"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
def test_validation_negative_timeout(self, mock_logger):
|
||||
"""Should validate timeout is positive."""
|
||||
config = {
|
||||
'hardware_type': 'esp32',
|
||||
'sampling_rate': 100,
|
||||
'buffer_size': 1024,
|
||||
'timeout': -1.0
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="timeout must be positive"):
|
||||
CSIExtractor(config=config, logger=mock_logger)
|
||||
|
||||
# Test connection management
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_success(self, esp32_config, mock_logger):
|
||||
"""Should connect successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_conn:
|
||||
mock_conn.return_value = True
|
||||
|
||||
result = await extractor.connect()
|
||||
|
||||
assert result == True
|
||||
assert extractor.is_connected == True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_failure(self, esp32_config, mock_logger):
|
||||
"""Should handle connection failure."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
with patch.object(extractor, '_establish_hardware_connection', new_callable=AsyncMock) as mock_conn:
|
||||
mock_conn.side_effect = ConnectionError("Failed")
|
||||
|
||||
result = await extractor.connect()
|
||||
|
||||
assert result == False
|
||||
assert extractor.is_connected == False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_connected(self, esp32_config, mock_logger):
|
||||
"""Should disconnect when connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
assert extractor.is_connected == False
|
||||
mock_close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_when_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should not disconnect when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with patch.object(extractor, '_close_hardware_connection', new_callable=AsyncMock) as mock_close:
|
||||
await extractor.disconnect()
|
||||
|
||||
mock_close.assert_not_called()
|
||||
|
||||
# Test extraction
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_not_connected(self, esp32_config, mock_logger):
|
||||
"""Should raise error when not connected."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = False
|
||||
|
||||
with pytest.raises(CSIParseError, match="Not connected to hardware"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_success_with_validation(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should extract successfully with validation."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
|
||||
with patch.object(extractor, 'validate_csi_data', return_value=True) as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_success_without_validation(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should extract successfully without validation."""
|
||||
esp32_config['validation_enabled'] = False
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
|
||||
with patch.object(extractor, 'validate_csi_data') as mock_validate:
|
||||
mock_read.return_value = b"raw_data"
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
mock_validate.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_retry_success(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should retry and succeed."""
|
||||
esp32_config['retry_attempts'] = 3
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
with patch.object(extractor.parser, 'parse', return_value=sample_csi_data):
|
||||
# Fail first two attempts, succeed on third
|
||||
mock_read.side_effect = [ConnectionError(), ConnectionError(), b"raw_data"]
|
||||
|
||||
result = await extractor.extract_csi()
|
||||
|
||||
assert result == sample_csi_data
|
||||
assert mock_read.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_retry_failure(self, esp32_config, mock_logger):
|
||||
"""Should fail after max retries."""
|
||||
esp32_config['retry_attempts'] = 2
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
extractor.is_connected = True
|
||||
|
||||
with patch.object(extractor, '_read_raw_data', new_callable=AsyncMock) as mock_read:
|
||||
mock_read.side_effect = ConnectionError("Failed")
|
||||
|
||||
with pytest.raises(CSIParseError, match="Extraction failed after 2 attempts"):
|
||||
await extractor.extract_csi()
|
||||
|
||||
# Test validation
|
||||
def test_validate_success(self, esp32_config, mock_logger, sample_csi_data):
|
||||
"""Should validate successfully."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
result = extractor.validate_csi_data(sample_csi_data)
|
||||
|
||||
assert result == True
|
||||
|
||||
def test_validate_empty_amplitude(self, esp32_config, mock_logger):
|
||||
"""Should reject empty amplitude."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.array([]),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty amplitude data"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_empty_phase(self, esp32_config, mock_logger):
|
||||
"""Should reject empty phase."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.array([]),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Empty phase data"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_invalid_frequency(self, esp32_config, mock_logger):
|
||||
"""Should reject invalid frequency."""
|
||||
extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
|
||||
|
||||
data = CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=0,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.5,
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with pytest.raises(CSIValidationError, match="Invalid frequency"):
|
||||
extractor.validate_csi_data(data)
|
||||
|
||||
def test_validate_invalid_bandwidth(self, esp32_config, mock_logger):
    """Validation must raise when the bandwidth is non-positive."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    bad_sample = CSIData(
        timestamp=datetime.now(timezone.utc),
        amplitude=np.random.rand(3, 56),
        phase=np.random.rand(3, 56),
        frequency=2.4e9,
        bandwidth=0,  # invalid: bandwidth must be positive
        num_subcarriers=56,
        num_antennas=3,
        snr=15.5,
        metadata={},
    )

    with pytest.raises(CSIValidationError, match="Invalid bandwidth"):
        extractor.validate_csi_data(bad_sample)
||||
def test_validate_invalid_subcarriers(self, esp32_config, mock_logger):
    """Validation must raise when the subcarrier count is non-positive."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    bad_sample = CSIData(
        timestamp=datetime.now(timezone.utc),
        amplitude=np.random.rand(3, 56),
        phase=np.random.rand(3, 56),
        frequency=2.4e9,
        bandwidth=20e6,
        num_subcarriers=0,  # invalid: must have at least one subcarrier
        num_antennas=3,
        snr=15.5,
        metadata={},
    )

    with pytest.raises(CSIValidationError, match="Invalid number of subcarriers"):
        extractor.validate_csi_data(bad_sample)
||||
def test_validate_invalid_antennas(self, esp32_config, mock_logger):
    """Validation must raise when the antenna count is non-positive."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    bad_sample = CSIData(
        timestamp=datetime.now(timezone.utc),
        amplitude=np.random.rand(3, 56),
        phase=np.random.rand(3, 56),
        frequency=2.4e9,
        bandwidth=20e6,
        num_subcarriers=56,
        num_antennas=0,  # invalid: must have at least one antenna
        snr=15.5,
        metadata={},
    )

    with pytest.raises(CSIValidationError, match="Invalid number of antennas"):
        extractor.validate_csi_data(bad_sample)
||||
def test_validate_snr_too_low(self, esp32_config, mock_logger):
    """Validation must raise when the SNR is below the plausible range."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    bad_sample = CSIData(
        timestamp=datetime.now(timezone.utc),
        amplitude=np.random.rand(3, 56),
        phase=np.random.rand(3, 56),
        frequency=2.4e9,
        bandwidth=20e6,
        num_subcarriers=56,
        num_antennas=3,
        snr=-100,  # invalid: far below any realistic SNR
        metadata={},
    )

    with pytest.raises(CSIValidationError, match="Invalid SNR value"):
        extractor.validate_csi_data(bad_sample)
||||
def test_validate_snr_too_high(self, esp32_config, mock_logger):
    """Validation must raise when the SNR is above the plausible range."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    bad_sample = CSIData(
        timestamp=datetime.now(timezone.utc),
        amplitude=np.random.rand(3, 56),
        phase=np.random.rand(3, 56),
        frequency=2.4e9,
        bandwidth=20e6,
        num_subcarriers=56,
        num_antennas=3,
        snr=100,  # invalid: far above any realistic SNR
        metadata={},
    )

    with pytest.raises(CSIValidationError, match="Invalid SNR value"):
        extractor.validate_csi_data(bad_sample)
||||
# Test streaming
@pytest.mark.asyncio
async def test_streaming_success(self, esp32_config, mock_logger, sample_csi_data):
    """Streaming delivers extracted CSI frames to the caller's callback."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
    extractor.is_connected = True
    on_frame = Mock()

    with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
        mock_extract.return_value = sample_csi_data

        # Run the streaming loop briefly, then request a clean shutdown.
        stream_task = asyncio.create_task(extractor.start_streaming(on_frame))
        await asyncio.sleep(0.1)
        extractor.stop_streaming()
        await stream_task

        on_frame.assert_called()
|
||||
@pytest.mark.asyncio
async def test_streaming_exception(self, esp32_config, mock_logger):
    """An extraction failure must terminate the streaming loop cleanly."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
    extractor.is_connected = True
    callback = Mock()

    with patch.object(extractor, 'extract_csi', new_callable=AsyncMock) as mock_extract:
        mock_extract.side_effect = Exception("Test error")

        # The task should finish on its own once extract_csi raises.
        task = asyncio.create_task(extractor.start_streaming(callback))
        await task

        # PEP 8: compare to the boolean singleton with `is`, not `==`.
        assert extractor.is_streaming is False
|
||||
def test_stop_streaming(self, esp32_config, mock_logger):
    """stop_streaming() clears the streaming flag."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)
    extractor.is_streaming = True

    extractor.stop_streaming()

    # PEP 8: identity check against the False singleton is the precise assertion.
    assert extractor.is_streaming is False
|
||||
# Test placeholder implementations for 100% coverage
@pytest.mark.asyncio
async def test_establish_hardware_connection_placeholder(self, esp32_config, mock_logger):
    """The placeholder hardware-connection hook reports success."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    result = await extractor._establish_hardware_connection()

    # PEP 8: `is True` pins the exact boolean singleton, not mere truthiness.
    assert result is True
|
||||
@pytest.mark.asyncio
async def test_close_hardware_connection_placeholder(self, esp32_config, mock_logger):
    """The placeholder disconnect hook completes without raising."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    # Success criterion is simply that no exception escapes.
    await extractor._close_hardware_connection()
|
||||
@pytest.mark.asyncio
async def test_read_raw_data_placeholder(self, esp32_config, mock_logger):
    """The placeholder reader returns a fixed, canned ESP32 CSI frame."""
    extractor = CSIExtractor(config=esp32_config, logger=mock_logger)

    frame = await extractor._read_raw_data()

    expected = b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
    assert frame == expected
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
class TestESP32CSIParserStandalone:
    """Standalone tests for the ESP32 CSI line parser."""

    @pytest.fixture
    def parser(self):
        """Provide a fresh ESP32 parser for each test."""
        return ESP32CSIParser()

    def test_parse_valid_data(self, parser):
        """A well-formed CSI_DATA line yields a populated CSIData record."""
        raw = b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"

        parsed = parser.parse(raw)

        assert isinstance(parsed, CSIData)
        assert parsed.num_antennas == 3
        assert parsed.num_subcarriers == 56
        # Wire values 2400 / 20 come back scaled to Hz.
        assert parsed.frequency == 2400000000
        assert parsed.bandwidth == 20000000
        assert parsed.snr == 15.5

    def test_parse_empty_data(self, parser):
        """An empty payload is rejected outright."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

    def test_parse_invalid_format(self, parser):
        """Payloads without the CSI_DATA prefix are rejected."""
        with pytest.raises(CSIParseError, match="Invalid ESP32 CSI data format"):
            parser.parse(b"INVALID_DATA")

    def test_parse_value_error(self, parser):
        """A non-numeric field surfaces as a CSIParseError."""
        raw = b"CSI_DATA:invalid_number,3,56,2400,20,15.5"

        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(raw)

    def test_parse_index_error(self, parser):
        """A truncated record (missing fields) surfaces as a CSIParseError."""
        raw = b"CSI_DATA:1234567890"  # Missing fields

        with pytest.raises(CSIParseError, match="Failed to parse ESP32 data"):
            parser.parse(raw)
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
class TestRouterCSIParserStandalone:
    """Standalone tests for the router CSI parser."""

    @pytest.fixture
    def parser(self):
        """Provide a fresh router parser for each test."""
        return RouterCSIParser()

    def test_parse_empty_data(self, parser):
        """An empty payload is rejected outright."""
        with pytest.raises(CSIParseError, match="Empty data received"):
            parser.parse(b"")

    def test_parse_atheros_format(self, parser):
        """An Atheros-prefixed payload parses into a CSIData record."""
        raw = b"ATHEROS_CSI:mock_data"

        parsed = parser.parse(raw)

        assert isinstance(parsed, CSIData)
        assert parsed.metadata['source'] == 'atheros_router'

    def test_parse_unknown_format(self, parser):
        """A payload with an unrecognized prefix is rejected."""
        raw = b"UNKNOWN_FORMAT:data"

        with pytest.raises(CSIParseError, match="Unknown router CSI format"):
            parser.parse(raw)
||||
367
v1/tests/unit/test_densepose_head.py
Normal file
367
v1/tests/unit/test_densepose_head.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.densepose_head import DensePoseHead, DensePoseError
|
||||
|
||||
|
||||
class TestDensePoseHead:
    """Test suite for the DensePose head (London School TDD style)."""

    # ------------------------------------------------------------------ #
    # Fixtures
    # ------------------------------------------------------------------ #

    @pytest.fixture
    def mock_config(self):
        """Baseline configuration for the DensePose head."""
        return {
            'input_channels': 256,
            'num_body_parts': 24,
            'num_uv_coordinates': 2,
            'hidden_channels': [128, 64],
            'kernel_size': 3,
            'padding': 1,
            'dropout_rate': 0.1,
            'use_deformable_conv': False,
            'use_fpn': True,
            'fpn_levels': [2, 3, 4, 5],
            'output_stride': 4,
        }

    @pytest.fixture
    def densepose_head(self, mock_config):
        """Head instance under test."""
        return DensePoseHead(mock_config)

    @pytest.fixture
    def mock_feature_input(self):
        """Backbone features: (batch=2, 256 channels, 56x56)."""
        return torch.randn(2, 256, 56, 56)

    @pytest.fixture
    def mock_target_masks(self):
        """Integer part labels in [0, 24] at full resolution (2, 224, 224)."""
        return torch.randint(0, 24 + 1, (2, 224, 224))

    @pytest.fixture
    def mock_target_uv(self):
        """Continuous UV targets: (batch=2, 2 coords, 224, 224)."""
        return torch.randn(2, 2, 224, 224)

    # ------------------------------------------------------------------ #
    # Test-local helpers (not part of the production interface)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _resize_mask(mask, size):
        """Nearest-neighbour resize of an integer mask to spatial `size`."""
        resized = torch.nn.functional.interpolate(
            mask.float().unsqueeze(1), size=size, mode='nearest'
        )
        return resized.squeeze(1).long()

    @staticmethod
    def _resize_uv(uv, size):
        """Bilinear resize of a UV coordinate map to spatial `size`."""
        return torch.nn.functional.interpolate(
            uv, size=size, mode='bilinear', align_corners=False
        )

    # ------------------------------------------------------------------ #
    # Tests
    # ------------------------------------------------------------------ #

    def test_head_initialization_creates_correct_architecture(self, mock_config):
        """Construction exposes the configured dimensions and sub-heads."""
        head = DensePoseHead(mock_config)

        assert head is not None
        assert isinstance(head, nn.Module)
        assert head.input_channels == mock_config['input_channels']
        assert head.num_body_parts == mock_config['num_body_parts']
        assert head.num_uv_coordinates == mock_config['num_uv_coordinates']
        assert head.use_fpn == mock_config['use_fpn']
        assert hasattr(head, 'segmentation_head')
        assert hasattr(head, 'uv_regression_head')
        if mock_config['use_fpn']:
            assert hasattr(head, 'fpn')

    def test_forward_pass_produces_correct_output_format(self, densepose_head, mock_feature_input):
        """Forward returns a dict holding segmentation and UV tensors."""
        preds = densepose_head(mock_feature_input)

        assert preds is not None
        assert isinstance(preds, dict)
        assert 'segmentation' in preds
        assert 'uv_coordinates' in preds

        seg, uv = preds['segmentation'], preds['uv_coordinates']
        assert isinstance(seg, torch.Tensor)
        assert isinstance(uv, torch.Tensor)
        # The batch dimension must survive the head.
        assert seg.shape[0] == mock_feature_input.shape[0]
        assert uv.shape[0] == mock_feature_input.shape[0]

    def test_segmentation_head_produces_correct_shape(self, densepose_head, mock_feature_input):
        """Segmentation logits carry parts+background channels, upsampled."""
        seg = densepose_head(mock_feature_input)['segmentation']

        assert seg.shape[1] == densepose_head.num_body_parts + 1  # +1 for background
        assert seg.shape[2] >= mock_feature_input.shape[2]  # height upsampled
        assert seg.shape[3] >= mock_feature_input.shape[3]  # width upsampled

    def test_uv_regression_head_produces_correct_shape(self, densepose_head, mock_feature_input):
        """UV output carries the configured coordinate channels, upsampled."""
        uv = densepose_head(mock_feature_input)['uv_coordinates']

        assert uv.shape[1] == densepose_head.num_uv_coordinates
        assert uv.shape[2] >= mock_feature_input.shape[2]  # height upsampled
        assert uv.shape[3] >= mock_feature_input.shape[3]  # width upsampled

    def test_compute_segmentation_loss_measures_pixel_classification(self, densepose_head, mock_feature_input, mock_target_masks):
        """Segmentation loss is a non-negative scalar tensor."""
        seg_logits = densepose_head(mock_feature_input)['segmentation']
        target = self._resize_mask(mock_target_masks, seg_logits.shape[2:])

        loss = densepose_head.compute_segmentation_loss(seg_logits, target)

        assert loss is not None
        assert isinstance(loss, torch.Tensor)
        assert loss.dim() == 0  # scalar
        assert loss.item() >= 0

    def test_compute_uv_loss_measures_coordinate_regression(self, densepose_head, mock_feature_input, mock_target_uv):
        """UV regression loss is a non-negative scalar tensor."""
        uv_pred = densepose_head(mock_feature_input)['uv_coordinates']
        target = self._resize_uv(mock_target_uv, uv_pred.shape[2:])

        loss = densepose_head.compute_uv_loss(uv_pred, target)

        assert loss is not None
        assert isinstance(loss, torch.Tensor)
        assert loss.dim() == 0  # scalar
        assert loss.item() >= 0

    def test_compute_total_loss_combines_segmentation_and_uv_losses(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
        """Total loss equals segmentation loss plus UV loss."""
        preds = densepose_head(mock_feature_input)
        seg_target = self._resize_mask(mock_target_masks, preds['segmentation'].shape[2:])
        uv_target = self._resize_uv(mock_target_uv, preds['uv_coordinates'].shape[2:])

        total = densepose_head.compute_total_loss(preds, seg_target, uv_target)
        seg_loss = densepose_head.compute_segmentation_loss(preds['segmentation'], seg_target)
        uv_loss = densepose_head.compute_uv_loss(preds['uv_coordinates'], uv_target)

        assert total is not None
        assert isinstance(total, torch.Tensor)
        assert total.item() > 0
        # The composite loss should be exactly the sum of its parts.
        assert torch.allclose(total, seg_loss + uv_loss, atol=1e-6)

    def test_fpn_integration_enhances_multi_scale_features(self, mock_config, mock_feature_input):
        """Enabling FPN changes the output values but not their shapes."""
        head_with_fpn = DensePoseHead(dict(mock_config, use_fpn=True))
        head_without_fpn = DensePoseHead(dict(mock_config, use_fpn=False))

        preds_fpn = head_with_fpn(mock_feature_input)
        preds_plain = head_without_fpn(mock_feature_input)

        assert preds_fpn['segmentation'].shape == preds_plain['segmentation'].shape
        assert preds_fpn['uv_coordinates'].shape == preds_plain['uv_coordinates'].shape
        # Different compute paths should produce different values.
        assert not torch.allclose(preds_fpn['segmentation'], preds_plain['segmentation'], atol=1e-6)

    def test_get_prediction_confidence_provides_uncertainty_estimates(self, densepose_head, mock_feature_input):
        """Confidence dict reports per-sample segmentation and UV confidence."""
        preds = densepose_head(mock_feature_input)

        confidence = densepose_head.get_prediction_confidence(preds)

        assert confidence is not None
        assert isinstance(confidence, dict)
        assert 'segmentation_confidence' in confidence
        assert 'uv_confidence' in confidence

        seg_conf = confidence['segmentation_confidence']
        uv_conf = confidence['uv_confidence']
        assert isinstance(seg_conf, torch.Tensor)
        assert isinstance(uv_conf, torch.Tensor)
        assert seg_conf.shape[0] == mock_feature_input.shape[0]
        assert uv_conf.shape[0] == mock_feature_input.shape[0]

    def test_post_process_predictions_formats_output(self, densepose_head, mock_feature_input):
        """Post-processing yields body parts, UV maps and confidence scores."""
        raw = densepose_head(mock_feature_input)

        processed = densepose_head.post_process_predictions(raw)

        assert processed is not None
        assert isinstance(processed, dict)
        for key in ('body_parts', 'uv_coordinates', 'confidence_scores'):
            assert key in processed

    def test_training_mode_enables_dropout(self, densepose_head, mock_feature_input):
        """In train mode, dropout makes repeated forward passes differ."""
        densepose_head.train()

        first = densepose_head(mock_feature_input)
        second = densepose_head(mock_feature_input)

        assert not torch.allclose(first['segmentation'], second['segmentation'], atol=1e-6)
        assert not torch.allclose(first['uv_coordinates'], second['uv_coordinates'], atol=1e-6)

    def test_evaluation_mode_disables_dropout(self, densepose_head, mock_feature_input):
        """In eval mode, repeated forward passes are deterministic."""
        densepose_head.eval()

        first = densepose_head(mock_feature_input)
        second = densepose_head(mock_feature_input)

        assert torch.allclose(first['segmentation'], second['segmentation'], atol=1e-6)
        assert torch.allclose(first['uv_coordinates'], second['uv_coordinates'], atol=1e-6)

    def test_head_validates_input_dimensions(self, densepose_head):
        """A wrong channel count is rejected with DensePoseError."""
        wrong_channels = torch.randn(2, 128, 56, 56)  # head expects 256 channels

        with pytest.raises(DensePoseError):
            densepose_head(wrong_channels)

    def test_head_handles_different_input_sizes(self, densepose_head):
        """Spatial output size tracks the spatial input size."""
        small = densepose_head(torch.randn(1, 256, 28, 28))
        large = densepose_head(torch.randn(1, 256, 112, 112))

        assert small['segmentation'].shape[2:] != large['segmentation'].shape[2:]
        assert small['uv_coordinates'].shape[2:] != large['uv_coordinates'].shape[2:]

    def test_head_supports_gradient_computation(self, densepose_head, mock_feature_input, mock_target_masks, mock_target_uv):
        """Backprop through the total loss populates every trainable grad."""
        densepose_head.train()
        optimizer = torch.optim.Adam(densepose_head.parameters(), lr=0.001)

        preds = densepose_head(mock_feature_input)
        seg_target = self._resize_mask(mock_target_masks, preds['segmentation'].shape[2:])
        uv_target = self._resize_uv(mock_target_uv, preds['uv_coordinates'].shape[2:])

        loss = densepose_head.compute_total_loss(preds, seg_target, uv_target)
        optimizer.zero_grad()
        loss.backward()

        for param in densepose_head.parameters():
            if param.requires_grad:
                assert param.grad is not None
                assert not torch.allclose(param.grad, torch.zeros_like(param.grad))

    def test_head_configuration_validation(self):
        """Non-positive channel/part counts are rejected at construction."""
        invalid_config = {
            'input_channels': 0,   # invalid
            'num_body_parts': -1,  # invalid
            'num_uv_coordinates': 2,
        }

        with pytest.raises(ValueError):
            DensePoseHead(invalid_config)

    def test_save_and_load_model_state(self, densepose_head, mock_feature_input):
        """A state_dict round-trip reproduces the original outputs."""
        # NOTE(review): the head is in train() mode here by default, so dropout
        # could make these forward passes stochastic — consider eval(); verify.
        original = densepose_head(mock_feature_input)

        clone = DensePoseHead(densepose_head.config)
        clone.load_state_dict(densepose_head.state_dict())
        restored = clone(mock_feature_input)

        assert torch.allclose(original['segmentation'], restored['segmentation'], atol=1e-6)
        assert torch.allclose(original['uv_coordinates'], restored['uv_coordinates'], atol=1e-6)
||||
293
v1/tests/unit/test_modality_translation.py
Normal file
293
v1/tests/unit/test_modality_translation.py
Normal file
@@ -0,0 +1,293 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.models.modality_translation import ModalityTranslationNetwork, ModalityTranslationError
|
||||
|
||||
|
||||
class TestModalityTranslationNetwork:
|
||||
"""Test suite for Modality Translation Network following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
def mock_config(self):
    """Baseline configuration for the modality translation network."""
    return {
        'input_channels': 6,  # Real + imaginary parts for 3 antennas
        'hidden_channels': [64, 128, 256],
        'output_channels': 256,
        'kernel_size': 3,
        'stride': 1,
        'padding': 1,
        'dropout_rate': 0.1,
        'activation': 'relu',
        'normalization': 'batch',
        'use_attention': True,
        'attention_heads': 8,
    }
|
||||
|
||||
@pytest.fixture
def translation_network(self, mock_config):
    """Network instance under test, built from the baseline config."""
    return ModalityTranslationNetwork(mock_config)
|
||||
|
||||
@pytest.fixture
def mock_csi_input(self):
    """Random CSI tensor: (batch=4, 6 IQ channels, 56 subcarriers, 100 samples)."""
    return torch.randn(4, 6, 56, 100)
|
||||
|
||||
@pytest.fixture
def mock_target_features(self):
    """Random target features matching the network's output geometry (4, 256, 56, 100)."""
    return torch.randn(4, 256, 56, 100)
|
||||
|
||||
def test_network_initialization_creates_correct_architecture(self, mock_config):
    """Construction wires up encoder, decoder and (optionally) attention."""
    net = ModalityTranslationNetwork(mock_config)

    assert net is not None
    assert isinstance(net, nn.Module)
    assert net.input_channels == mock_config['input_channels']
    assert net.output_channels == mock_config['output_channels']
    assert net.use_attention == mock_config['use_attention']
    assert hasattr(net, 'encoder')
    assert hasattr(net, 'decoder')
    if mock_config['use_attention']:
        assert hasattr(net, 'attention')
|
||||
|
||||
def test_forward_pass_produces_correct_output_shape(self, translation_network, mock_csi_input):
    """Forward preserves batch and spatial dims and maps to output_channels."""
    features = translation_network(mock_csi_input)

    assert features is not None
    assert isinstance(features, torch.Tensor)
    assert features.shape[0] == mock_csi_input.shape[0]  # batch preserved
    assert features.shape[1] == translation_network.output_channels
    assert features.shape[2] == mock_csi_input.shape[2]  # height preserved
    assert features.shape[3] == mock_csi_input.shape[3]  # width preserved
|
||||
|
||||
def test_forward_pass_handles_different_input_sizes(self, translation_network):
    """Arbitrary batch sizes and spatial extents are handled unchanged."""
    small_out = translation_network(torch.randn(2, 6, 28, 50))
    large_out = translation_network(torch.randn(8, 6, 112, 200))

    assert small_out.shape == (2, 256, 28, 50)
    assert large_out.shape == (8, 256, 112, 200)
|
||||
|
||||
def test_encoder_extracts_hierarchical_features(self, translation_network, mock_csi_input):
    """encode() yields one feature map per encoder stage, coarsening with depth."""
    stages = translation_network.encode(mock_csi_input)

    assert stages is not None
    assert isinstance(stages, list)
    assert len(stages) == len(translation_network.encoder)

    # Spatial resolution must be monotonically non-increasing with depth.
    for shallower, deeper in zip(stages, stages[1:]):
        assert deeper.shape[2] <= shallower.shape[2]
        assert deeper.shape[3] <= shallower.shape[3]
|
||||
|
||||
def test_decoder_reconstructs_target_features(self, translation_network, mock_csi_input):
    """decode() restores the input's spatial size with output_channels maps."""
    stages = translation_network.encode(mock_csi_input)

    reconstructed = translation_network.decode(stages)

    assert reconstructed is not None
    assert isinstance(reconstructed, torch.Tensor)
    assert reconstructed.shape[1] == translation_network.output_channels
    assert reconstructed.shape[2:] == mock_csi_input.shape[2:]
|
||||
|
||||
def test_attention_mechanism_enhances_features(self, mock_config, mock_csi_input):
    """Toggling attention changes output values but not output shapes."""
    attn_net = ModalityTranslationNetwork(dict(mock_config, use_attention=True))
    plain_net = ModalityTranslationNetwork(dict(mock_config, use_attention=False))

    attn_out = attn_net(mock_csi_input)
    plain_out = plain_net(mock_csi_input)

    assert attn_out.shape == plain_out.shape
    # Different compute paths should produce different values.
    assert not torch.allclose(attn_out, plain_out, atol=1e-6)
|
||||
|
||||
def test_training_mode_enables_dropout(self, translation_network, mock_csi_input):
    """In train mode, dropout makes repeated forward passes differ."""
    translation_network.train()

    first = translation_network(mock_csi_input)
    second = translation_network(mock_csi_input)

    assert not torch.allclose(first, second, atol=1e-6)
|
||||
|
||||
def test_evaluation_mode_disables_dropout(self, translation_network, mock_csi_input):
    """In eval mode, repeated forward passes are deterministic."""
    translation_network.eval()

    first = translation_network(mock_csi_input)
    second = translation_network(mock_csi_input)

    assert torch.allclose(first, second, atol=1e-6)
|
||||
|
||||
def test_compute_translation_loss_measures_feature_alignment(self, translation_network, mock_csi_input, mock_target_features):
    """Translation loss is a non-negative scalar tensor."""
    predicted = translation_network(mock_csi_input)

    loss = translation_network.compute_translation_loss(predicted, mock_target_features)

    assert loss is not None
    assert isinstance(loss, torch.Tensor)
    assert loss.dim() == 0  # scalar
    assert loss.item() >= 0
|
||||
|
||||
def test_compute_translation_loss_handles_different_loss_types(self, translation_network, mock_csi_input, mock_target_features):
    """'mse' and 'l1' loss types yield distinct loss values."""
    predicted = translation_network(mock_csi_input)

    mse = translation_network.compute_translation_loss(predicted, mock_target_features, loss_type='mse')
    l1 = translation_network.compute_translation_loss(predicted, mock_target_features, loss_type='l1')

    assert mse is not None
    assert l1 is not None
    assert mse.item() != l1.item()
|
||||
|
||||
def test_get_feature_statistics_provides_analysis(self, translation_network, mock_csi_input):
|
||||
"""Test that get_feature_statistics provides feature analysis"""
|
||||
# Arrange
|
||||
output = translation_network(mock_csi_input)
|
||||
|
||||
# Act
|
||||
stats = translation_network.get_feature_statistics(output)
|
||||
|
||||
# Assert
|
||||
assert stats is not None
|
||||
assert isinstance(stats, dict)
|
||||
assert 'mean' in stats
|
||||
assert 'std' in stats
|
||||
assert 'min' in stats
|
||||
assert 'max' in stats
|
||||
assert 'sparsity' in stats
|
||||
|
||||
def test_network_supports_gradient_computation(self, translation_network, mock_csi_input, mock_target_features):
|
||||
"""Test that network supports gradient computation for training"""
|
||||
# Arrange
|
||||
translation_network.train()
|
||||
optimizer = torch.optim.Adam(translation_network.parameters(), lr=0.001)
|
||||
|
||||
# Act
|
||||
output = translation_network(mock_csi_input)
|
||||
loss = translation_network.compute_translation_loss(output, mock_target_features)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# Assert
|
||||
for param in translation_network.parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
assert not torch.allclose(param.grad, torch.zeros_like(param.grad))
|
||||
|
||||
def test_network_validates_input_dimensions(self, translation_network):
|
||||
"""Test that network validates input dimensions"""
|
||||
# Arrange
|
||||
invalid_input = torch.randn(4, 3, 56, 100) # Wrong number of channels
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ModalityTranslationError):
|
||||
translation_network(invalid_input)
|
||||
|
||||
def test_network_handles_batch_size_one(self, translation_network):
|
||||
"""Test that network handles single sample inference"""
|
||||
# Arrange
|
||||
single_input = torch.randn(1, 6, 56, 100)
|
||||
|
||||
# Act
|
||||
output = translation_network(single_input)
|
||||
|
||||
# Assert
|
||||
assert output.shape == (1, 256, 56, 100)
|
||||
|
||||
def test_save_and_load_model_state(self, translation_network, mock_csi_input):
|
||||
"""Test that model state can be saved and loaded"""
|
||||
# Arrange
|
||||
original_output = translation_network(mock_csi_input)
|
||||
|
||||
# Act - Save state
|
||||
state_dict = translation_network.state_dict()
|
||||
|
||||
# Create new network and load state
|
||||
new_network = ModalityTranslationNetwork(translation_network.config)
|
||||
new_network.load_state_dict(state_dict)
|
||||
new_output = new_network(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert torch.allclose(original_output, new_output, atol=1e-6)
|
||||
|
||||
def test_network_configuration_validation(self):
|
||||
"""Test that network validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'input_channels': 0, # Invalid
|
||||
'hidden_channels': [], # Invalid
|
||||
'output_channels': 256
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
ModalityTranslationNetwork(invalid_config)
|
||||
|
||||
def test_feature_visualization_support(self, translation_network, mock_csi_input):
|
||||
"""Test that network supports feature visualization"""
|
||||
# Act
|
||||
features = translation_network.get_intermediate_features(mock_csi_input)
|
||||
|
||||
# Assert
|
||||
assert features is not None
|
||||
assert isinstance(features, dict)
|
||||
assert 'encoder_features' in features
|
||||
assert 'decoder_features' in features
|
||||
if translation_network.use_attention:
|
||||
assert 'attention_weights' in features
|
||||
107
v1/tests/unit/test_phase_sanitizer.py
Normal file
107
v1/tests/unit/test_phase_sanitizer.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
|
||||
|
||||
class TestPhaseSanitizer:
    """Test suite for Phase Sanitizer following London School TDD principles"""

    @pytest.fixture
    def mock_phase_data(self):
        """Generate synthetic phase data for testing"""
        # Rows: phase jump at [0, 2]; wrapped phase at [1, 0]; clean data.
        return np.array([
            [0.1, 0.2, 6.0, 0.4, 0.5],
            [-3.0, -0.1, 0.0, 0.1, 0.2],
            [0.0, 0.1, 0.2, 0.3, 0.4],
        ])

    @pytest.fixture
    def phase_sanitizer(self):
        """Create Phase Sanitizer instance for testing"""
        return PhaseSanitizer()

    def test_unwrap_phase_removes_discontinuities(self, phase_sanitizer, mock_phase_data):
        """Test that phase unwrapping removes 2π discontinuities"""
        result = phase_sanitizer.unwrap_phase(mock_phase_data)

        assert result is not None
        assert isinstance(result, np.ndarray)
        assert result.shape == mock_phase_data.shape

        # After unwrapping, adjacent samples should never differ by more than π.
        for row in result:
            assert np.all(np.abs(np.diff(row)) < np.pi)

    def test_remove_outliers_filters_anomalous_values(self, phase_sanitizer, mock_phase_data):
        """Test that outlier removal filters anomalous phase values"""
        outlier_data = mock_phase_data.copy()
        outlier_data[0, 2] = 100.0  # clear outlier

        result = phase_sanitizer.remove_outliers(outlier_data)

        assert result is not None
        assert isinstance(result, np.ndarray)
        assert result.shape == outlier_data.shape
        assert np.abs(result[0, 2]) < 10.0  # outlier should be corrected

    def test_smooth_phase_reduces_noise(self, phase_sanitizer, mock_phase_data):
        """Test that phase smoothing reduces noise while preserving trends"""
        # Seeded RNG: the original drew from the global np.random state, which
        # made the variance comparison below flaky from run to run.
        rng = np.random.default_rng(42)
        noisy_data = mock_phase_data + rng.normal(0, 0.1, mock_phase_data.shape)

        result = phase_sanitizer.smooth_phase(noisy_data)

        assert result is not None
        assert isinstance(result, np.ndarray)
        assert result.shape == noisy_data.shape

        # Smoothed data should have lower (or equal) variance.
        assert np.var(result) <= np.var(noisy_data)

    def test_sanitize_handles_empty_input(self, phase_sanitizer):
        """Test that sanitizer handles empty input gracefully"""
        empty_data = np.array([])

        with pytest.raises(ValueError, match="Phase data cannot be empty"):
            phase_sanitizer.sanitize(empty_data)

    def test_sanitize_full_pipeline_integration(self, phase_sanitizer, mock_phase_data):
        """Test that full sanitization pipeline works correctly"""
        result = phase_sanitizer.sanitize(mock_phase_data)

        assert result is not None
        assert isinstance(result, np.ndarray)
        assert result.shape == mock_phase_data.shape

        # Result should stay within reasonable phase bounds.
        assert np.all(result >= -2 * np.pi)
        assert np.all(result <= 2 * np.pi)

    def test_sanitize_performance_requirement(self, phase_sanitizer, mock_phase_data):
        """Test that phase sanitization meets performance requirements (<5ms)"""
        import time

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # on clock adjustments and has coarse resolution on some platforms.
        start_time = time.perf_counter()
        result = phase_sanitizer.sanitize(mock_phase_data)
        processing_time = time.perf_counter() - start_time

        assert processing_time < 0.005  # <5ms requirement
        assert result is not None
|
||||
407
v1/tests/unit/test_phase_sanitizer_tdd.py
Normal file
407
v1/tests/unit/test_phase_sanitizer_tdd.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""TDD tests for phase sanitizer following London School approach."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
|
||||
# Import the phase sanitizer module directly
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'phase_sanitizer',
|
||||
'/workspaces/wifi-densepose/src/core/phase_sanitizer.py'
|
||||
)
|
||||
phase_sanitizer_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(phase_sanitizer_module)
|
||||
|
||||
# Get classes from the module
|
||||
PhaseSanitizer = phase_sanitizer_module.PhaseSanitizer
|
||||
PhaseSanitizationError = phase_sanitizer_module.PhaseSanitizationError
|
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestPhaseSanitizer:
    """Test phase sanitizer using London School TDD."""

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def sanitizer_config(self):
        """Phase sanitizer configuration for testing."""
        return {
            'unwrapping_method': 'numpy',
            'outlier_threshold': 3.0,
            'smoothing_window': 5,
            'enable_outlier_removal': True,
            'enable_smoothing': True,
            'enable_noise_filtering': True,
            'noise_threshold': 0.1,
            'phase_range': (-np.pi, np.pi)
        }

    @pytest.fixture
    def phase_sanitizer(self, sanitizer_config, mock_logger):
        """Create phase sanitizer for testing."""
        return PhaseSanitizer(config=sanitizer_config, logger=mock_logger)

    @pytest.fixture
    def sample_wrapped_phase(self):
        """Sample wrapped phase data with discontinuities."""
        phase = np.linspace(0, 4 * np.pi, 100)
        wrapped_phase = np.angle(np.exp(1j * phase))  # wrap to [-π, π]
        return wrapped_phase.reshape(1, -1)  # shape: (1, 100)

    @pytest.fixture
    def sample_noisy_phase(self):
        """Sample phase data with noise and outliers (seeded for determinism)."""
        # Seeded RNG: the original used the global np.random state, which made
        # outlier/noise-dependent assertions flaky between runs.
        rng = np.random.default_rng(1234)
        clean_phase = np.linspace(-np.pi, np.pi, 50)
        noisy_phase = clean_phase + rng.normal(0, 0.05, 50)
        outliers = rng.choice(50, 5, replace=False)
        noisy_phase[outliers] += rng.uniform(-2, 2, 5)  # inject outliers
        return noisy_phase.reshape(1, -1)

    # Initialization tests
    def test_should_initialize_with_valid_config(self, sanitizer_config, mock_logger):
        """Should initialize phase sanitizer with valid configuration."""
        sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)

        assert sanitizer.config == sanitizer_config
        assert sanitizer.logger == mock_logger
        assert sanitizer.unwrapping_method == 'numpy'
        assert sanitizer.outlier_threshold == 3.0
        assert sanitizer.smoothing_window == 5
        # Truthiness asserts instead of the original `== True` (E712).
        assert sanitizer.enable_outlier_removal
        assert sanitizer.enable_smoothing
        assert sanitizer.enable_noise_filtering
        assert sanitizer.noise_threshold == 0.1
        assert sanitizer.phase_range == (-np.pi, np.pi)

    def test_should_raise_error_with_invalid_config(self, mock_logger):
        """Should raise error when initialized with invalid configuration."""
        invalid_config = {'invalid': 'config'}

        with pytest.raises(ValueError, match="Missing required configuration"):
            PhaseSanitizer(config=invalid_config, logger=mock_logger)

    def test_should_validate_required_fields(self, mock_logger):
        """Should validate required configuration fields."""
        required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
        base_config = {
            'unwrapping_method': 'numpy',
            'outlier_threshold': 3.0,
            'smoothing_window': 5
        }

        # Dropping any single required field must be rejected.
        for field in required_fields:
            config = base_config.copy()
            del config[field]

            with pytest.raises(ValueError, match="Missing required configuration"):
                PhaseSanitizer(config=config, logger=mock_logger)

    def test_should_use_default_values(self, mock_logger):
        """Should use default values for optional parameters."""
        minimal_config = {
            'unwrapping_method': 'numpy',
            'outlier_threshold': 3.0,
            'smoothing_window': 5
        }

        sanitizer = PhaseSanitizer(config=minimal_config, logger=mock_logger)

        assert sanitizer.enable_outlier_removal  # default
        assert sanitizer.enable_smoothing  # default
        assert not sanitizer.enable_noise_filtering  # default
        assert sanitizer.noise_threshold == 0.05  # default
        assert sanitizer.phase_range == (-np.pi, np.pi)  # default

    def test_should_initialize_without_logger(self, sanitizer_config):
        """Should initialize without logger provided."""
        sanitizer = PhaseSanitizer(config=sanitizer_config)

        assert sanitizer.logger is not None  # should create default logger

    # Phase unwrapping tests
    def test_should_unwrap_phase_successfully(self, phase_sanitizer, sample_wrapped_phase):
        """Should unwrap phase data successfully."""
        result = phase_sanitizer.unwrap_phase(sample_wrapped_phase)

        # Check that result has same shape
        assert result.shape == sample_wrapped_phase.shape

        # Unwrapping should strictly reduce the number of >π jumps.
        jumps_after = np.sum(np.abs(np.diff(result.flatten())) > np.pi)
        jumps_before = np.sum(np.abs(np.diff(sample_wrapped_phase.flatten())) > np.pi)
        assert jumps_after < jumps_before

    def test_should_handle_different_unwrapping_methods(self, sanitizer_config, mock_logger):
        """Should handle different unwrapping methods."""
        for method in ['numpy', 'scipy', 'custom']:
            sanitizer_config['unwrapping_method'] = method
            sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)

            phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))

            # Each method must dispatch to its matching private implementation.
            with patch.object(sanitizer, f'_unwrap_{method}', return_value=phase_data) as mock_unwrap:
                result = sanitizer.unwrap_phase(phase_data)

                assert result.shape == phase_data.shape
                mock_unwrap.assert_called_once()

    def test_should_handle_unwrapping_error(self, phase_sanitizer):
        """Should handle phase unwrapping errors gracefully."""
        invalid_phase = np.array([[]])  # empty array

        with pytest.raises(PhaseSanitizationError, match="Failed to unwrap phase"):
            phase_sanitizer.unwrap_phase(invalid_phase)

    # Outlier removal tests
    def test_should_remove_outliers_successfully(self, phase_sanitizer, sample_noisy_phase):
        """Should remove outliers from phase data successfully."""
        with patch.object(phase_sanitizer, '_detect_outliers') as mock_detect, \
             patch.object(phase_sanitizer, '_interpolate_outliers') as mock_interpolate:
            outlier_mask = np.zeros(sample_noisy_phase.shape, dtype=bool)
            outlier_mask[0, [10, 20, 30]] = True  # mark some outliers
            clean_phase = sample_noisy_phase.copy()

            mock_detect.return_value = outlier_mask
            mock_interpolate.return_value = clean_phase

            result = phase_sanitizer.remove_outliers(sample_noisy_phase)

            assert result.shape == sample_noisy_phase.shape
            mock_detect.assert_called_once_with(sample_noisy_phase)
            mock_interpolate.assert_called_once()

    def test_should_skip_outlier_removal_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
        """Should skip outlier removal when disabled."""
        sanitizer_config['enable_outlier_removal'] = False
        sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)

        result = sanitizer.remove_outliers(sample_noisy_phase)

        assert np.array_equal(result, sample_noisy_phase)

    def test_should_handle_outlier_removal_error(self, phase_sanitizer):
        """Should handle outlier removal errors gracefully."""
        with patch.object(phase_sanitizer, '_detect_outliers') as mock_detect:
            mock_detect.side_effect = Exception("Detection error")

            phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))

            with pytest.raises(PhaseSanitizationError, match="Failed to remove outliers"):
                phase_sanitizer.remove_outliers(phase_data)

    # Smoothing tests
    def test_should_smooth_phase_successfully(self, phase_sanitizer, sample_noisy_phase):
        """Should smooth phase data successfully."""
        with patch.object(phase_sanitizer, '_apply_moving_average') as mock_smooth:
            smoothed_phase = sample_noisy_phase * 0.9  # simulate smoothing
            mock_smooth.return_value = smoothed_phase

            result = phase_sanitizer.smooth_phase(sample_noisy_phase)

            assert result.shape == sample_noisy_phase.shape
            mock_smooth.assert_called_once_with(sample_noisy_phase, phase_sanitizer.smoothing_window)

    def test_should_skip_smoothing_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
        """Should skip smoothing when disabled."""
        sanitizer_config['enable_smoothing'] = False
        sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)

        result = sanitizer.smooth_phase(sample_noisy_phase)

        assert np.array_equal(result, sample_noisy_phase)

    def test_should_handle_smoothing_error(self, phase_sanitizer):
        """Should handle smoothing errors gracefully."""
        with patch.object(phase_sanitizer, '_apply_moving_average') as mock_smooth:
            mock_smooth.side_effect = Exception("Smoothing error")

            phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))

            with pytest.raises(PhaseSanitizationError, match="Failed to smooth phase"):
                phase_sanitizer.smooth_phase(phase_data)

    # Noise filtering tests
    def test_should_filter_noise_successfully(self, phase_sanitizer, sample_noisy_phase):
        """Should filter noise from phase data successfully."""
        with patch.object(phase_sanitizer, '_apply_low_pass_filter') as mock_filter:
            filtered_phase = sample_noisy_phase * 0.95  # simulate filtering
            mock_filter.return_value = filtered_phase

            result = phase_sanitizer.filter_noise(sample_noisy_phase)

            assert result.shape == sample_noisy_phase.shape
            mock_filter.assert_called_once_with(sample_noisy_phase, phase_sanitizer.noise_threshold)

    def test_should_skip_noise_filtering_when_disabled(self, sanitizer_config, mock_logger, sample_noisy_phase):
        """Should skip noise filtering when disabled."""
        sanitizer_config['enable_noise_filtering'] = False
        sanitizer = PhaseSanitizer(config=sanitizer_config, logger=mock_logger)

        result = sanitizer.filter_noise(sample_noisy_phase)

        assert np.array_equal(result, sample_noisy_phase)

    def test_should_handle_noise_filtering_error(self, phase_sanitizer):
        """Should handle noise filtering errors gracefully."""
        with patch.object(phase_sanitizer, '_apply_low_pass_filter') as mock_filter:
            mock_filter.side_effect = Exception("Filtering error")

            phase_data = np.random.uniform(-np.pi, np.pi, (2, 50))

            with pytest.raises(PhaseSanitizationError, match="Failed to filter noise"):
                phase_sanitizer.filter_noise(phase_data)

    # Complete sanitization pipeline tests
    def test_should_sanitize_phase_pipeline_successfully(self, phase_sanitizer, sample_wrapped_phase):
        """Should sanitize phase through complete pipeline successfully."""
        # Flattened the original four-level nested `with` into one statement.
        with patch.object(phase_sanitizer, 'unwrap_phase', return_value=sample_wrapped_phase) as mock_unwrap, \
             patch.object(phase_sanitizer, 'remove_outliers', return_value=sample_wrapped_phase) as mock_outliers, \
             patch.object(phase_sanitizer, 'smooth_phase', return_value=sample_wrapped_phase) as mock_smooth, \
             patch.object(phase_sanitizer, 'filter_noise', return_value=sample_wrapped_phase) as mock_filter:

            result = phase_sanitizer.sanitize_phase(sample_wrapped_phase)

            assert result.shape == sample_wrapped_phase.shape
            mock_unwrap.assert_called_once_with(sample_wrapped_phase)
            mock_outliers.assert_called_once()
            mock_smooth.assert_called_once()
            mock_filter.assert_called_once()

    def test_should_handle_sanitization_pipeline_error(self, phase_sanitizer, sample_wrapped_phase):
        """Should handle sanitization pipeline errors gracefully."""
        with patch.object(phase_sanitizer, 'unwrap_phase') as mock_unwrap:
            mock_unwrap.side_effect = PhaseSanitizationError("Unwrapping failed")

            with pytest.raises(PhaseSanitizationError):
                phase_sanitizer.sanitize_phase(sample_wrapped_phase)

    # Phase validation tests
    def test_should_validate_phase_data_successfully(self, phase_sanitizer):
        """Should validate phase data successfully."""
        valid_phase = np.random.uniform(-np.pi, np.pi, (3, 56))

        result = phase_sanitizer.validate_phase_data(valid_phase)

        assert result  # truthiness instead of `== True` (E712)

    def test_should_reject_invalid_phase_shape(self, phase_sanitizer):
        """Should reject phase data with invalid shape."""
        invalid_phase = np.array([1, 2, 3])  # 1D array

        with pytest.raises(PhaseSanitizationError, match="Phase data must be 2D"):
            phase_sanitizer.validate_phase_data(invalid_phase)

    def test_should_reject_empty_phase_data(self, phase_sanitizer):
        """Should reject empty phase data."""
        empty_phase = np.array([]).reshape(0, 0)

        with pytest.raises(PhaseSanitizationError, match="Phase data cannot be empty"):
            phase_sanitizer.validate_phase_data(empty_phase)

    def test_should_reject_phase_out_of_range(self, phase_sanitizer):
        """Should reject phase data outside valid range."""
        invalid_phase = np.array([[10.0, -10.0, 5.0, -5.0]])  # outside [-π, π]

        with pytest.raises(PhaseSanitizationError, match="Phase values outside valid range"):
            phase_sanitizer.validate_phase_data(invalid_phase)

    # Statistics and monitoring tests
    def test_should_get_sanitization_statistics(self, phase_sanitizer):
        """Should get sanitization statistics."""
        # Simulate some processing
        phase_sanitizer._total_processed = 50
        phase_sanitizer._outliers_removed = 5
        phase_sanitizer._sanitization_errors = 2

        stats = phase_sanitizer.get_sanitization_statistics()

        assert isinstance(stats, dict)
        assert stats['total_processed'] == 50
        assert stats['outliers_removed'] == 5
        assert stats['sanitization_errors'] == 2
        assert stats['outlier_rate'] == 0.1
        assert stats['error_rate'] == 0.04

    def test_should_reset_statistics(self, phase_sanitizer):
        """Should reset sanitization statistics."""
        phase_sanitizer._total_processed = 50
        phase_sanitizer._outliers_removed = 5
        phase_sanitizer._sanitization_errors = 2

        phase_sanitizer.reset_statistics()

        assert phase_sanitizer._total_processed == 0
        assert phase_sanitizer._outliers_removed == 0
        assert phase_sanitizer._sanitization_errors == 0

    # Configuration validation tests
    def test_should_validate_unwrapping_method(self, mock_logger):
        """Should validate unwrapping method."""
        invalid_config = {
            'unwrapping_method': 'invalid_method',
            'outlier_threshold': 3.0,
            'smoothing_window': 5
        }

        with pytest.raises(ValueError, match="Invalid unwrapping method"):
            PhaseSanitizer(config=invalid_config, logger=mock_logger)

    def test_should_validate_outlier_threshold(self, mock_logger):
        """Should validate outlier threshold."""
        invalid_config = {
            'unwrapping_method': 'numpy',
            'outlier_threshold': -1.0,  # negative threshold
            'smoothing_window': 5
        }

        with pytest.raises(ValueError, match="outlier_threshold must be positive"):
            PhaseSanitizer(config=invalid_config, logger=mock_logger)

    def test_should_validate_smoothing_window(self, mock_logger):
        """Should validate smoothing window."""
        invalid_config = {
            'unwrapping_method': 'numpy',
            'outlier_threshold': 3.0,
            'smoothing_window': 0  # invalid window size
        }

        with pytest.raises(ValueError, match="smoothing_window must be positive"):
            PhaseSanitizer(config=invalid_config, logger=mock_logger)

    # Edge case tests
    def test_should_handle_single_antenna_data(self, phase_sanitizer):
        """Should handle single antenna phase data."""
        single_antenna_phase = np.random.uniform(-np.pi, np.pi, (1, 56))

        result = phase_sanitizer.sanitize_phase(single_antenna_phase)

        assert result.shape == single_antenna_phase.shape

    def test_should_handle_small_phase_arrays(self, phase_sanitizer):
        """Should handle small phase arrays."""
        small_phase = np.random.uniform(-np.pi, np.pi, (2, 5))

        result = phase_sanitizer.sanitize_phase(small_phase)

        assert result.shape == small_phase.shape

    def test_should_handle_constant_phase_data(self, phase_sanitizer):
        """Should handle constant phase data."""
        constant_phase = np.full((3, 20), 0.5)

        result = phase_sanitizer.sanitize_phase(constant_phase)

        assert result.shape == constant_phase.shape
|
||||
244
v1/tests/unit/test_router_interface.py
Normal file
244
v1/tests/unit/test_router_interface.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.hardware.router_interface import RouterInterface, RouterConnectionError
|
||||
|
||||
|
||||
class TestRouterInterface:
|
||||
"""Test suite for Router Interface following London School TDD principles"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Configuration for router interface"""
|
||||
return {
|
||||
'router_ip': '192.168.1.1',
|
||||
'username': 'admin',
|
||||
'password': 'password',
|
||||
'ssh_port': 22,
|
||||
'timeout': 30,
|
||||
'max_retries': 3
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def router_interface(self, mock_config):
|
||||
"""Create router interface instance for testing"""
|
||||
return RouterInterface(mock_config)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ssh_client(self):
|
||||
"""Mock SSH client for testing"""
|
||||
mock_client = Mock()
|
||||
mock_client.connect = Mock()
|
||||
mock_client.exec_command = Mock()
|
||||
mock_client.close = Mock()
|
||||
return mock_client
|
||||
|
||||
def test_interface_initialization_creates_correct_configuration(self, mock_config):
|
||||
"""Test that router interface initializes with correct configuration"""
|
||||
# Act
|
||||
interface = RouterInterface(mock_config)
|
||||
|
||||
# Assert
|
||||
assert interface is not None
|
||||
assert interface.router_ip == mock_config['router_ip']
|
||||
assert interface.username == mock_config['username']
|
||||
assert interface.password == mock_config['password']
|
||||
assert interface.ssh_port == mock_config['ssh_port']
|
||||
assert interface.timeout == mock_config['timeout']
|
||||
assert interface.max_retries == mock_config['max_retries']
|
||||
assert not interface.is_connected
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_connect_establishes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that connect method establishes SSH connection"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
|
||||
# Act
|
||||
result = router_interface.connect()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert router_interface.is_connected is True
|
||||
mock_ssh_client.set_missing_host_key_policy.assert_called_once()
|
||||
mock_ssh_client.connect.assert_called_once_with(
|
||||
hostname=router_interface.router_ip,
|
||||
port=router_interface.ssh_port,
|
||||
username=router_interface.username,
|
||||
password=router_interface.password,
|
||||
timeout=router_interface.timeout
|
||||
)
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_connect_handles_connection_failure(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that connect method handles connection failures gracefully"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_ssh_client.connect.side_effect = Exception("Connection failed")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.connect()
|
||||
|
||||
assert router_interface.is_connected is False
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_disconnect_closes_ssh_connection(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that disconnect method closes SSH connection"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
router_interface.disconnect()
|
||||
|
||||
# Assert
|
||||
assert router_interface.is_connected is False
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_execute_command_runs_ssh_command(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that execute_command runs SSH commands correctly"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"command output"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.execute_command("test command")
|
||||
|
||||
# Assert
|
||||
assert result == "command output"
|
||||
mock_ssh_client.exec_command.assert_called_with("test command")
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_execute_command_handles_command_errors(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that execute_command handles command errors"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b""
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b"command error"
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.execute_command("failing command")
|
||||
|
||||
def test_execute_command_requires_connection(self, router_interface):
|
||||
"""Test that execute_command requires active connection"""
|
||||
# Act & Assert
|
||||
with pytest.raises(RouterConnectionError):
|
||||
router_interface.execute_command("test command")
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_get_router_info_retrieves_system_information(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that get_router_info retrieves router system information"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Router Model: AC1900\nFirmware: 1.2.3"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
info = router_interface.get_router_info()
|
||||
|
||||
# Assert
|
||||
assert info is not None
|
||||
assert isinstance(info, dict)
|
||||
assert 'model' in info
|
||||
assert 'firmware' in info
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_enable_monitor_mode_configures_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that enable_monitor_mode configures WiFi monitoring"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Monitor mode enabled"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.enable_monitor_mode("wlan0")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_ssh_client.exec_command.assert_called()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_disable_monitor_mode_disables_wifi_monitoring(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that disable_monitor_mode disables WiFi monitoring"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_stdout = Mock()
|
||||
mock_stdout.read.return_value = b"Monitor mode disabled"
|
||||
mock_stderr = Mock()
|
||||
mock_stderr.read.return_value = b""
|
||||
mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr)
|
||||
|
||||
router_interface.connect()
|
||||
|
||||
# Act
|
||||
result = router_interface.disable_monitor_mode("wlan0")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_ssh_client.exec_command.assert_called()
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_interface_supports_context_manager(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that router interface supports context manager protocol"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
|
||||
# Act
|
||||
with router_interface as interface:
|
||||
# Assert
|
||||
assert interface.is_connected is True
|
||||
|
||||
# Assert - connection should be closed after context
|
||||
assert router_interface.is_connected is False
|
||||
mock_ssh_client.close.assert_called_once()
|
||||
|
||||
def test_interface_validates_configuration(self):
|
||||
"""Test that router interface validates configuration parameters"""
|
||||
# Arrange
|
||||
invalid_config = {
|
||||
'router_ip': '', # Invalid IP
|
||||
'username': 'admin',
|
||||
'password': 'password'
|
||||
}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
RouterInterface(invalid_config)
|
||||
|
||||
@patch('paramiko.SSHClient')
|
||||
def test_interface_implements_retry_logic(self, mock_ssh_class, router_interface, mock_ssh_client):
|
||||
"""Test that interface implements retry logic for failed operations"""
|
||||
# Arrange
|
||||
mock_ssh_class.return_value = mock_ssh_client
|
||||
mock_ssh_client.connect.side_effect = [Exception("Temp failure"), None] # Fail once, then succeed
|
||||
|
||||
# Act
|
||||
result = router_interface.connect()
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert mock_ssh_client.connect.call_count == 2 # Should retry once
|
||||
410
v1/tests/unit/test_router_interface_tdd.py
Normal file
410
v1/tests/unit/test_router_interface_tdd.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""TDD tests for router interface following London School approach."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
|
||||
# Import the router interface module directly
|
||||
import unittest.mock
|
||||
|
||||
# Mock asyncssh before importing
|
||||
with unittest.mock.patch.dict('sys.modules', {'asyncssh': unittest.mock.MagicMock()}):
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'router_interface',
|
||||
'/workspaces/wifi-densepose/src/hardware/router_interface.py'
|
||||
)
|
||||
router_module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# Import CSI extractor for dependency
|
||||
csi_spec = importlib.util.spec_from_file_location(
|
||||
'csi_extractor',
|
||||
'/workspaces/wifi-densepose/src/hardware/csi_extractor.py'
|
||||
)
|
||||
csi_module = importlib.util.module_from_spec(csi_spec)
|
||||
csi_spec.loader.exec_module(csi_module)
|
||||
|
||||
# Now load the router interface
|
||||
router_module.CSIData = csi_module.CSIData # Make CSIData available
|
||||
spec.loader.exec_module(router_module)
|
||||
|
||||
# Get classes from modules
|
||||
RouterInterface = router_module.RouterInterface
|
||||
RouterConnectionError = router_module.RouterConnectionError
|
||||
CSIData = csi_module.CSIData
|
||||
|
||||
|
||||
@pytest.mark.unit
@pytest.mark.tdd
@pytest.mark.london
class TestRouterInterface:
    """Test router interface using London School TDD.

    Every collaborator (SSH transport, response parsers) is mocked; each test
    verifies the interaction contract of RouterInterface, not real I/O.
    """

    @pytest.fixture
    def mock_logger(self):
        """Mock logger for testing."""
        return Mock()

    @pytest.fixture
    def router_config(self):
        """Router configuration for testing."""
        return {
            'host': '192.168.1.1',
            'port': 22,
            'username': 'admin',
            'password': 'password',
            'command_timeout': 30,
            'connection_timeout': 10,
            'max_retries': 3,
            'retry_delay': 1.0
        }

    @pytest.fixture
    def router_interface(self, router_config, mock_logger):
        """Create router interface for testing."""
        return RouterInterface(config=router_config, logger=mock_logger)

    # Initialization tests
    def test_should_initialize_with_valid_config(self, router_config, mock_logger):
        """Should initialize router interface with valid configuration."""
        interface = RouterInterface(config=router_config, logger=mock_logger)

        assert interface.host == '192.168.1.1'
        assert interface.port == 22
        assert interface.username == 'admin'
        assert interface.password == 'password'
        assert interface.command_timeout == 30
        assert interface.connection_timeout == 10
        assert interface.max_retries == 3
        assert interface.retry_delay == 1.0
        assert interface.is_connected == False
        assert interface.logger == mock_logger

    def test_should_raise_error_with_invalid_config(self, mock_logger):
        """Should raise error when initialized with invalid configuration."""
        invalid_config = {'invalid': 'config'}

        with pytest.raises(ValueError, match="Missing required configuration"):
            RouterInterface(config=invalid_config, logger=mock_logger)

    def test_should_validate_required_fields(self, mock_logger):
        """Should validate all required configuration fields."""
        required_fields = ['host', 'port', 'username', 'password']
        base_config = {
            'host': '192.168.1.1',
            'port': 22,
            'username': 'admin',
            'password': 'password'
        }

        # Removing any one required key must fail construction.
        for field in required_fields:
            config = base_config.copy()
            del config[field]

            with pytest.raises(ValueError, match="Missing required configuration"):
                RouterInterface(config=config, logger=mock_logger)

    def test_should_use_default_values(self, mock_logger):
        """Should use default values for optional parameters."""
        minimal_config = {
            'host': '192.168.1.1',
            'port': 22,
            'username': 'admin',
            'password': 'password'
        }

        interface = RouterInterface(config=minimal_config, logger=mock_logger)

        assert interface.command_timeout == 30  # default
        assert interface.connection_timeout == 10  # default
        assert interface.max_retries == 3  # default
        assert interface.retry_delay == 1.0  # default

    def test_should_initialize_without_logger(self, router_config):
        """Should initialize without logger provided."""
        interface = RouterInterface(config=router_config)

        assert interface.logger is not None  # Should create default logger

    # Connection tests
    # NOTE(review): these patch 'src.hardware.router_interface.asyncssh...',
    # but the module was loaded above under the name 'router_interface' via
    # importlib — verify the patch target resolves to the same module object.
    @pytest.mark.asyncio
    async def test_should_connect_successfully(self, router_interface):
        """Should establish SSH connection successfully."""
        mock_ssh_client = Mock()

        with patch('src.hardware.router_interface.asyncssh.connect', new_callable=AsyncMock) as mock_connect:
            mock_connect.return_value = mock_ssh_client

            result = await router_interface.connect()

            assert result == True
            assert router_interface.is_connected == True
            assert router_interface.ssh_client == mock_ssh_client
            mock_connect.assert_called_once_with(
                '192.168.1.1',
                port=22,
                username='admin',
                password='password',
                connect_timeout=10
            )

    @pytest.mark.asyncio
    async def test_should_handle_connection_failure(self, router_interface):
        """Should handle SSH connection failure gracefully."""
        with patch('src.hardware.router_interface.asyncssh.connect', new_callable=AsyncMock) as mock_connect:
            mock_connect.side_effect = ConnectionError("Connection failed")

            result = await router_interface.connect()

            # Failure is reported via the return value, not an exception,
            # and the error is logged.
            assert result == False
            assert router_interface.is_connected == False
            assert router_interface.ssh_client is None
            router_interface.logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_should_disconnect_when_connected(self, router_interface):
        """Should disconnect SSH connection when connected."""
        mock_ssh_client = Mock()
        router_interface.is_connected = True
        router_interface.ssh_client = mock_ssh_client

        await router_interface.disconnect()

        assert router_interface.is_connected == False
        assert router_interface.ssh_client is None
        mock_ssh_client.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_should_handle_disconnect_when_not_connected(self, router_interface):
        """Should handle disconnect when not connected."""
        router_interface.is_connected = False
        router_interface.ssh_client = None

        await router_interface.disconnect()

        # Should not raise any exception
        assert router_interface.is_connected == False

    # Command execution tests
    @pytest.mark.asyncio
    async def test_should_execute_command_successfully(self, router_interface):
        """Should execute SSH command successfully."""
        mock_ssh_client = Mock()
        mock_result = Mock()
        mock_result.stdout = "command output"
        mock_result.stderr = ""
        mock_result.returncode = 0

        router_interface.is_connected = True
        router_interface.ssh_client = mock_ssh_client

        with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
            mock_run.return_value = mock_result

            result = await router_interface.execute_command("test command")

            # Command is run with the configured command_timeout (30s).
            assert result == "command output"
            mock_run.assert_called_once_with("test command", timeout=30)

    @pytest.mark.asyncio
    async def test_should_handle_command_execution_when_not_connected(self, router_interface):
        """Should handle command execution when not connected."""
        router_interface.is_connected = False

        with pytest.raises(RouterConnectionError, match="Not connected to router"):
            await router_interface.execute_command("test command")

    @pytest.mark.asyncio
    async def test_should_handle_command_execution_error(self, router_interface):
        """Should handle command execution errors."""
        mock_ssh_client = Mock()
        mock_result = Mock()
        mock_result.stdout = ""
        mock_result.stderr = "command error"
        mock_result.returncode = 1

        router_interface.is_connected = True
        router_interface.ssh_client = mock_ssh_client

        with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
            mock_run.return_value = mock_result

            # A non-zero return code is surfaced as RouterConnectionError.
            with pytest.raises(RouterConnectionError, match="Command failed"):
                await router_interface.execute_command("test command")

    @pytest.mark.asyncio
    async def test_should_retry_command_execution_on_failure(self, router_interface):
        """Should retry command execution on temporary failure."""
        mock_ssh_client = Mock()
        mock_success_result = Mock()
        mock_success_result.stdout = "success output"
        mock_success_result.stderr = ""
        mock_success_result.returncode = 0

        router_interface.is_connected = True
        router_interface.ssh_client = mock_ssh_client

        with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
            # First two calls fail, third succeeds
            mock_run.side_effect = [
                ConnectionError("Network error"),
                ConnectionError("Network error"),
                mock_success_result
            ]

            result = await router_interface.execute_command("test command")

            assert result == "success output"
            assert mock_run.call_count == 3

    @pytest.mark.asyncio
    async def test_should_fail_after_max_retries(self, router_interface):
        """Should fail after maximum retries exceeded."""
        mock_ssh_client = Mock()

        router_interface.is_connected = True
        router_interface.ssh_client = mock_ssh_client

        with patch.object(mock_ssh_client, 'run', new_callable=AsyncMock) as mock_run:
            mock_run.side_effect = ConnectionError("Network error")

            # max_retries is 3 (from the fixture config).
            with pytest.raises(RouterConnectionError, match="Command execution failed after 3 retries"):
                await router_interface.execute_command("test command")

            assert mock_run.call_count == 3

    # CSI data retrieval tests
    @pytest.mark.asyncio
    async def test_should_get_csi_data_successfully(self, router_interface):
        """Should retrieve CSI data successfully."""
        expected_csi_data = Mock(spec=CSIData)

        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            with patch.object(router_interface, '_parse_csi_response', return_value=expected_csi_data) as mock_parse:
                mock_execute.return_value = "csi data response"

                result = await router_interface.get_csi_data()

                assert result == expected_csi_data
                mock_execute.assert_called_once_with("iwlist scan | grep CSI")
                mock_parse.assert_called_once_with("csi data response")

    @pytest.mark.asyncio
    async def test_should_handle_csi_data_retrieval_failure(self, router_interface):
        """Should handle CSI data retrieval failure."""
        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            mock_execute.side_effect = RouterConnectionError("Command failed")

            with pytest.raises(RouterConnectionError):
                await router_interface.get_csi_data()

    # Router status tests
    @pytest.mark.asyncio
    async def test_should_get_router_status_successfully(self, router_interface):
        """Should get router status successfully."""
        expected_status = {
            'cpu_usage': 25.5,
            'memory_usage': 60.2,
            'wifi_status': 'active',
            'uptime': '5 days, 3 hours'
        }

        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            with patch.object(router_interface, '_parse_status_response', return_value=expected_status) as mock_parse:
                mock_execute.return_value = "status response"

                result = await router_interface.get_router_status()

                assert result == expected_status
                mock_execute.assert_called_once_with("cat /proc/stat && free && iwconfig")
                mock_parse.assert_called_once_with("status response")

    # Configuration tests
    @pytest.mark.asyncio
    async def test_should_configure_csi_monitoring_successfully(self, router_interface):
        """Should configure CSI monitoring successfully."""
        config = {
            'channel': 6,
            'bandwidth': 20,
            'sample_rate': 100
        }

        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            mock_execute.return_value = "Configuration applied"

            result = await router_interface.configure_csi_monitoring(config)

            assert result == True
            mock_execute.assert_called_once_with(
                "iwconfig wlan0 channel 6 && echo 'CSI monitoring configured'"
            )

    @pytest.mark.asyncio
    async def test_should_handle_csi_monitoring_configuration_failure(self, router_interface):
        """Should handle CSI monitoring configuration failure."""
        config = {
            'channel': 6,
            'bandwidth': 20,
            'sample_rate': 100
        }

        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            mock_execute.side_effect = RouterConnectionError("Command failed")

            # Configuration failure is reported via the return value.
            result = await router_interface.configure_csi_monitoring(config)

            assert result == False

    # Health check tests
    @pytest.mark.asyncio
    async def test_should_perform_health_check_successfully(self, router_interface):
        """Should perform health check successfully."""
        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            mock_execute.return_value = "pong"

            result = await router_interface.health_check()

            assert result == True
            mock_execute.assert_called_once_with("echo 'ping' && echo 'pong'")

    @pytest.mark.asyncio
    async def test_should_handle_health_check_failure(self, router_interface):
        """Should handle health check failure."""
        with patch.object(router_interface, 'execute_command', new_callable=AsyncMock) as mock_execute:
            mock_execute.side_effect = RouterConnectionError("Command failed")

            result = await router_interface.health_check()

            assert result == False

    # Parsing method tests
    def test_should_parse_csi_response(self, router_interface):
        """Should parse CSI response data."""
        mock_response = "CSI_DATA:timestamp,antennas,subcarriers,frequency,bandwidth"

        with patch('src.hardware.router_interface.CSIData') as mock_csi_data:
            expected_data = Mock(spec=CSIData)
            mock_csi_data.return_value = expected_data

            result = router_interface._parse_csi_response(mock_response)

            assert result == expected_data

    def test_should_parse_status_response(self, router_interface):
        """Should parse router status response."""
        # Minimal blend of /proc/stat, free, and iwconfig output.
        mock_response = """
        cpu 123456 0 78901 234567 0 0 0 0 0 0
        MemTotal: 1024000 kB
        MemFree: 512000 kB
        wlan0 IEEE 802.11 ESSID:"TestNetwork"
        """

        result = router_interface._parse_status_response(mock_response)

        assert isinstance(result, dict)
        assert 'cpu_usage' in result
        assert 'memory_usage' in result
        assert 'wifi_status' in result
||||
Reference in New Issue
Block a user