updates
This commit is contained in:
736
tests/e2e/test_healthcare_scenario.py
Normal file
736
tests/e2e/test_healthcare_scenario.py
Normal file
@@ -0,0 +1,736 @@
|
||||
"""
|
||||
End-to-end tests for healthcare fall detection scenario.
|
||||
|
||||
Tests complete workflow from CSI data collection to fall alert generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AlertSeverity(Enum):
    """Alert severity levels.

    Ordered from least to most urgent. The string values are what gets
    recorded in notification records (via ``.value``).
    """
    LOW = "low"            # informational only
    MEDIUM = "medium"      # needs attention (e.g. prolonged inactivity)
    HIGH = "high"          # urgent (e.g. movement instability)
    CRITICAL = "critical"  # emergency (e.g. fall detected)
|
||||
|
||||
|
||||
@dataclass
class HealthcareAlert:
    """Healthcare alert data structure.

    One record per detected monitoring event for a single patient.
    """
    alert_id: str             # unique id, e.g. "fall_<patient_id>_<epoch seconds>"
    timestamp: datetime       # creation time (naive UTC via datetime.utcnow())
    alert_type: str           # "fall_detected" | "prolonged_inactivity" | "movement_instability"
    severity: AlertSeverity
    patient_id: str
    location: str             # room id the patient is monitored in
    confidence: float         # detection confidence carried over from the pose data
    description: str          # human-readable summary
    metadata: Dict[str, Any]  # alert-type specific details plus the raw pose payload
|
||||
|
||||
|
||||
class MockPatientMonitor:
    """Mock patient monitoring system.

    Consumes pose-estimation frames for a single patient and emits
    ``HealthcareAlert`` objects for falls, prolonged inactivity and
    movement instability.
    """

    # Number of most recent activity samples kept for trend analysis.
    _HISTORY_LIMIT = 100

    def __init__(self, patient_id: str, room_id: str):
        self.patient_id = patient_id
        self.room_id = room_id
        self.is_monitoring = False
        self.baseline_activity = None
        self.activity_history = []       # rolling list of per-frame metrics
        self.alerts_generated = []       # every alert this monitor produced
        self.fall_detection_enabled = True
        self.sensitivity_level = "medium"

    async def start_monitoring(self) -> bool:
        """Start patient monitoring.

        Returns:
            True if monitoring was started, False if it was already running.
        """
        if self.is_monitoring:
            return False
        self.is_monitoring = True
        return True

    async def stop_monitoring(self) -> bool:
        """Stop patient monitoring.

        Returns:
            True if monitoring was stopped, False if it was not running.
        """
        if not self.is_monitoring:
            return False
        self.is_monitoring = False
        return True

    async def process_pose_data(self, pose_data: Dict[str, Any]) -> Optional["HealthcareAlert"]:
        """Process one pose frame and detect potential issues.

        Args:
            pose_data: Pose payload with a "persons" list; each entry may
                carry "activity", "confidence" and "bounding_box".

        Returns:
            The generated alert, or None when monitoring is off or the
            frame looks normal.
        """
        if not self.is_monitoring:
            return None

        # Extract activity metrics and keep a bounded rolling history.
        activity_metrics = self._extract_activity_metrics(pose_data)
        self.activity_history.append(activity_metrics)
        if len(self.activity_history) > self._HISTORY_LIMIT:
            self.activity_history = self.activity_history[-self._HISTORY_LIMIT:]

        # Detect anomalies against the current frame and recent history.
        alert = await self._detect_anomalies(activity_metrics, pose_data)
        if alert:
            self.alerts_generated.append(alert)
        return alert

    def _extract_activity_metrics(self, pose_data: Dict[str, Any]) -> Dict[str, Any]:
        """Extract activity metrics from pose data.

        The first detected person is treated as the primary patient.
        """
        persons = pose_data.get("persons", [])

        if not persons:
            # BUG FIX: the empty-room metrics previously omitted the
            # "confidence" key, which made _generate_inactivity_alert raise
            # KeyError once ten consecutive empty frames had accumulated.
            return {
                "person_count": 0,
                "activity_level": 0.0,
                "posture": "unknown",
                "movement_speed": 0.0,
                "stability_score": 1.0,
                "confidence": 0.0,
            }

        # Analyze first person (primary patient).
        person = persons[0]
        posture = person.get("activity", "standing")

        # If no explicit activity was provided, infer a fall from the
        # bounding-box aspect ratio: clearly wider than tall suggests a
        # horizontal (fallen) body.
        if posture == "standing" and "bounding_box" in person:
            bbox = person["bounding_box"]
            width = bbox.get("width", 80)
            height = bbox.get("height", 180)
            if width > height * 1.5:
                posture = "fallen"

        # Map posture to coarse (activity, speed, stability) metrics.
        if posture == "fallen":
            activity_level, movement_speed, stability_score = 0.1, 0.0, 0.2
        elif posture == "walking":
            activity_level, movement_speed, stability_score = 0.8, 1.5, 0.7
        elif posture == "sitting":
            activity_level, movement_speed, stability_score = 0.3, 0.1, 0.9
        else:  # standing or other
            activity_level, movement_speed, stability_score = 0.5, 0.2, 0.8

        return {
            "person_count": len(persons),
            "activity_level": activity_level,
            "posture": posture,
            "movement_speed": movement_speed,
            "stability_score": stability_score,
            "confidence": person.get("confidence", 0.0),
        }

    async def _detect_anomalies(self, current_metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> Optional["HealthcareAlert"]:
        """Detect health-related anomalies, most severe condition first."""
        # Fall detection takes precedence over everything else.
        if current_metrics["posture"] == "fallen":
            return await self._generate_fall_alert(current_metrics, pose_data)

        # Prolonged inactivity: mean activity over the last 10 samples.
        if len(self.activity_history) >= 10:
            recent_activity = [m["activity_level"] for m in self.activity_history[-10:]]
            if np.mean(recent_activity) < 0.1:  # very low activity
                return await self._generate_inactivity_alert(current_metrics, pose_data)

        # Unusual movement patterns.
        if current_metrics["stability_score"] < 0.4:
            return await self._generate_instability_alert(current_metrics, pose_data)

        return None

    def _build_alert(
        self,
        prefix: str,
        alert_type: str,
        severity: "AlertSeverity",
        description: str,
        metrics: Dict[str, Any],
        metadata: Dict[str, Any],
    ) -> "HealthcareAlert":
        """Assemble a HealthcareAlert with the shared id/timestamp fields."""
        now = datetime.utcnow()
        return HealthcareAlert(
            alert_id=f"{prefix}_{self.patient_id}_{int(now.timestamp())}",
            timestamp=now,
            alert_type=alert_type,
            severity=severity,
            patient_id=self.patient_id,
            location=self.room_id,
            confidence=metrics["confidence"],
            description=description,
            metadata=metadata,
        )

    async def _generate_fall_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> "HealthcareAlert":
        """Generate a CRITICAL fall detection alert."""
        return self._build_alert(
            prefix="fall",
            alert_type="fall_detected",
            severity=AlertSeverity.CRITICAL,
            description=f"Fall detected for patient {self.patient_id} in {self.room_id}",
            metrics=metrics,
            metadata={
                "posture": metrics["posture"],
                "stability_score": metrics["stability_score"],
                "pose_data": pose_data,
            },
        )

    async def _generate_inactivity_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> "HealthcareAlert":
        """Generate a MEDIUM prolonged-inactivity alert."""
        return self._build_alert(
            prefix="inactivity",
            alert_type="prolonged_inactivity",
            severity=AlertSeverity.MEDIUM,
            description=f"Prolonged inactivity detected for patient {self.patient_id}",
            metrics=metrics,
            metadata={
                "activity_level": metrics["activity_level"],
                "duration_minutes": 10,
                "pose_data": pose_data,
            },
        )

    async def _generate_instability_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> "HealthcareAlert":
        """Generate a HIGH movement-instability alert."""
        return self._build_alert(
            prefix="instability",
            alert_type="movement_instability",
            severity=AlertSeverity.HIGH,
            description=f"Movement instability detected for patient {self.patient_id}",
            metrics=metrics,
            metadata={
                "stability_score": metrics["stability_score"],
                "movement_speed": metrics["movement_speed"],
                "pose_data": pose_data,
            },
        )

    def get_monitoring_stats(self) -> Dict[str, Any]:
        """Get monitoring statistics for this patient."""
        # Count alerts per type in a single pass (the original dict
        # comprehension re-scanned the full alert list for every alert).
        alert_types: Dict[str, int] = {}
        for alert in self.alerts_generated:
            alert_types[alert.alert_type] = alert_types.get(alert.alert_type, 0) + 1

        return {
            "patient_id": self.patient_id,
            "room_id": self.room_id,
            "is_monitoring": self.is_monitoring,
            "total_alerts": len(self.alerts_generated),
            "alert_types": alert_types,
            "activity_samples": len(self.activity_history),
            "fall_detection_enabled": self.fall_detection_enabled,
        }
|
||||
|
||||
|
||||
class MockHealthcareNotificationSystem:
    """Mock healthcare notification system.

    Routes alerts to delivery channels according to severity-based
    escalation rules and records every successful delivery.
    """

    def __init__(self):
        self.notifications_sent = []
        # Which delivery channels are currently enabled.
        self.notification_channels = {
            "nurse_station": True,
            "mobile_app": True,
            "email": True,
            "sms": False,
        }
        # Channels to attempt per severity; higher severity fans out wider.
        self.escalation_rules = {
            AlertSeverity.CRITICAL: ["nurse_station", "mobile_app", "sms"],
            AlertSeverity.HIGH: ["nurse_station", "mobile_app"],
            AlertSeverity.MEDIUM: ["nurse_station"],
            AlertSeverity.LOW: ["mobile_app"],
        }

    async def send_alert_notification(self, alert: "HealthcareAlert") -> Dict[str, bool]:
        """Send the alert through every enabled channel for its severity.

        Returns a channel -> success mapping covering the channels that
        were actually attempted (disabled channels are skipped).
        """
        targets = self.escalation_rules.get(alert.severity, ["nurse_station"])
        outcome = {}

        for channel in targets:
            if not self.notification_channels.get(channel, False):
                continue  # channel disabled; skip without recording
            delivered = await self._send_to_channel(channel, alert)
            outcome[channel] = delivered
            if delivered:
                self.notifications_sent.append({
                    "alert_id": alert.alert_id,
                    "channel": channel,
                    "timestamp": datetime.utcnow(),
                    "severity": alert.severity.value,
                })

        return outcome

    async def _send_to_channel(self, channel: str, alert: "HealthcareAlert") -> bool:
        """Deliver to one channel; returns False on a simulated failure."""
        await asyncio.sleep(0.01)  # simulated network latency
        # 5% of deliveries fail at random.
        if np.random.random() < 0.05:
            return False
        return True

    def get_notification_stats(self) -> Dict[str, Any]:
        """Get aggregate notification statistics."""
        # Single pass over the send log, pre-seeding every known bucket
        # with zero so absent channels/severities still appear.
        by_channel = {channel: 0 for channel in self.notification_channels}
        by_severity = {severity.value: 0 for severity in AlertSeverity}
        for record in self.notifications_sent:
            by_channel[record["channel"]] += 1
            by_severity[record["severity"]] += 1

        return {
            "total_notifications": len(self.notifications_sent),
            "notifications_by_channel": by_channel,
            "notifications_by_severity": by_severity,
        }
|
||||
|
||||
|
||||
class TestHealthcareFallDetection:
    """Test healthcare fall detection workflow.

    Exercises one MockPatientMonitor plus the notification system:
    fall alerts, normal activity, prolonged inactivity and instability.
    (The "should_fail_initially" suffix marks these as TDD red-phase tests.)
    """

    @pytest.fixture
    def patient_monitor(self):
        """Create patient monitor."""
        return MockPatientMonitor("patient_001", "room_101")

    @pytest.fixture
    def notification_system(self):
        """Create notification system."""
        return MockHealthcareNotificationSystem()

    @pytest.fixture
    def fall_pose_data(self):
        """Create pose data indicating a fall (explicit "fallen" activity)."""
        return {
            "persons": [
                {
                    "person_id": "patient_001",
                    "confidence": 0.92,
                    "bounding_box": {"x": 200, "y": 400, "width": 150, "height": 80},  # Horizontal position
                    "activity": "fallen",
                    # 17 keypoints in COCO-style [x, y, score] triples — TODO confirm layout
                    "keypoints": [[x, y, 0.8] for x, y in zip(range(17), range(17))]
                }
            ],
            "zone_summary": {"room_101": 1},
            "timestamp": datetime.utcnow().isoformat()
        }

    @pytest.fixture
    def normal_pose_data(self):
        """Create normal pose data (upright bounding box, "standing")."""
        return {
            "persons": [
                {
                    "person_id": "patient_001",
                    "confidence": 0.88,
                    "bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
                    "activity": "standing",
                    "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
                }
            ],
            "zone_summary": {"room_101": 1},
            "timestamp": datetime.utcnow().isoformat()
        }

    @pytest.mark.asyncio
    async def test_fall_detection_workflow_should_fail_initially(self, patient_monitor, notification_system, fall_pose_data):
        """Test fall detection workflow - should fail initially."""
        # Start monitoring
        result = await patient_monitor.start_monitoring()

        # This will fail initially
        assert result is True
        assert patient_monitor.is_monitoring is True

        # Process fall pose data
        alert = await patient_monitor.process_pose_data(fall_pose_data)

        # Should generate fall alert (fall detection is the highest-priority check)
        assert alert is not None
        assert alert.alert_type == "fall_detected"
        assert alert.severity == AlertSeverity.CRITICAL
        assert alert.patient_id == "patient_001"

        # Send notification
        notification_results = await notification_system.send_alert_notification(alert)

        # Should notify appropriate channels
        assert len(notification_results) > 0
        assert any(notification_results.values())  # At least one channel should succeed

        # Check statistics
        monitor_stats = patient_monitor.get_monitoring_stats()
        assert monitor_stats["total_alerts"] == 1

        notification_stats = notification_system.get_notification_stats()
        assert notification_stats["total_notifications"] > 0

    @pytest.mark.asyncio
    async def test_normal_activity_monitoring_should_fail_initially(self, patient_monitor, normal_pose_data):
        """Test normal activity monitoring - should fail initially."""
        await patient_monitor.start_monitoring()

        # Process multiple normal pose data samples
        alerts_generated = []

        for i in range(10):
            alert = await patient_monitor.process_pose_data(normal_pose_data)
            if alert:
                alerts_generated.append(alert)

        # This will fail initially
        # Should not generate alerts for normal activity
        assert len(alerts_generated) == 0

        # Should have activity history (one metrics sample per frame)
        stats = patient_monitor.get_monitoring_stats()
        assert stats["activity_samples"] == 10
        assert stats["is_monitoring"] is True

    @pytest.mark.asyncio
    async def test_prolonged_inactivity_detection_should_fail_initially(self, patient_monitor):
        """Test prolonged inactivity detection - should fail initially."""
        await patient_monitor.start_monitoring()

        # Simulate prolonged inactivity
        inactive_pose_data = {
            "persons": [],  # No person detected
            "zone_summary": {"room_101": 0},
            "timestamp": datetime.utcnow().isoformat()
        }

        alerts_generated = []

        # Process multiple inactive samples (the monitor averages the last
        # 10 activity levels before raising an inactivity alert)
        for i in range(15):
            alert = await patient_monitor.process_pose_data(inactive_pose_data)
            if alert:
                alerts_generated.append(alert)

        # This will fail initially
        # Should generate inactivity alert after sufficient samples
        inactivity_alerts = [a for a in alerts_generated if a.alert_type == "prolonged_inactivity"]
        assert len(inactivity_alerts) > 0

        # Check alert properties
        alert = inactivity_alerts[0]
        assert alert.severity == AlertSeverity.MEDIUM
        assert alert.patient_id == "patient_001"

    @pytest.mark.asyncio
    async def test_movement_instability_detection_should_fail_initially(self, patient_monitor):
        """Test movement instability detection - should fail initially."""
        await patient_monitor.start_monitoring()

        # Simulate unstable movement
        unstable_pose_data = {
            "persons": [
                {
                    "person_id": "patient_001",
                    "confidence": 0.65,  # Lower confidence indicates instability
                    "bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
                    "activity": "walking",
                    "keypoints": [[x, y, 0.5] for x, y in zip(range(17), range(17))]  # Low keypoint confidence
                }
            ],
            "zone_summary": {"room_101": 1},
            "timestamp": datetime.utcnow().isoformat()
        }

        # Process unstable pose data
        alert = await patient_monitor.process_pose_data(unstable_pose_data)

        # This will fail initially
        # May generate instability alert based on stability score
        # (conditional because the walking posture's stability may be
        # above the monitor's instability threshold)
        if alert and alert.alert_type == "movement_instability":
            assert alert.severity == AlertSeverity.HIGH
            assert alert.patient_id == "patient_001"
            assert "stability_score" in alert.metadata
|
||||
|
||||
|
||||
class TestHealthcareMultiPatientMonitoring:
    """Test multi-patient monitoring scenarios.

    Three monitors (one per room) share a single notification system.
    """

    @pytest.fixture
    def multi_patient_setup(self):
        """Create multi-patient monitoring setup."""
        patients = {
            "patient_001": MockPatientMonitor("patient_001", "room_101"),
            "patient_002": MockPatientMonitor("patient_002", "room_102"),
            "patient_003": MockPatientMonitor("patient_003", "room_103")
        }

        notification_system = MockHealthcareNotificationSystem()

        return patients, notification_system

    @pytest.mark.asyncio
    async def test_concurrent_patient_monitoring_should_fail_initially(self, multi_patient_setup):
        """Test concurrent patient monitoring - should fail initially."""
        patients, notification_system = multi_patient_setup

        # Start monitoring for all patients
        start_results = []
        for patient_id, monitor in patients.items():
            result = await monitor.start_monitoring()
            start_results.append(result)

        # This will fail initially
        assert all(start_results)
        assert all(monitor.is_monitoring for monitor in patients.values())

        # Simulate concurrent pose data processing (one frame per patient)
        pose_data_samples = [
            {
                "persons": [
                    {
                        "person_id": patient_id,
                        "confidence": 0.85,
                        "bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
                        "activity": "standing"
                    }
                ],
                "zone_summary": {f"room_{101 + i}": 1},
                "timestamp": datetime.utcnow().isoformat()
            }
            for i, patient_id in enumerate(patients.keys())
        ]

        # Process data for all patients concurrently
        tasks = []
        for (patient_id, monitor), pose_data in zip(patients.items(), pose_data_samples):
            task = asyncio.create_task(monitor.process_pose_data(pose_data))
            tasks.append(task)

        alerts = await asyncio.gather(*tasks)

        # Check results (gather returns one entry per task, None for no alert)
        assert len(alerts) == len(patients)

        # Get statistics for all patients
        all_stats = {}
        for patient_id, monitor in patients.items():
            all_stats[patient_id] = monitor.get_monitoring_stats()

        assert len(all_stats) == 3
        assert all(stats["is_monitoring"] for stats in all_stats.values())

    @pytest.mark.asyncio
    async def test_alert_prioritization_should_fail_initially(self, multi_patient_setup):
        """Test alert prioritization across patients - should fail initially."""
        patients, notification_system = multi_patient_setup

        # Start monitoring
        for monitor in patients.values():
            await monitor.start_monitoring()

        # Generate different severity alerts
        alert_scenarios = [
            ("patient_001", "fall_detected", AlertSeverity.CRITICAL),
            ("patient_002", "prolonged_inactivity", AlertSeverity.MEDIUM),
            ("patient_003", "movement_instability", AlertSeverity.HIGH)
        ]

        generated_alerts = []

        for patient_id, alert_type, expected_severity in alert_scenarios:
            # Create appropriate pose data for each scenario
            # (only the fall scenario is guaranteed to alert on a single frame)
            if alert_type == "fall_detected":
                pose_data = {
                    "persons": [{"person_id": patient_id, "confidence": 0.9, "activity": "fallen"}],
                    "zone_summary": {f"room_{patients[patient_id].room_id}": 1}
                }
            else:
                pose_data = {
                    "persons": [{"person_id": patient_id, "confidence": 0.7, "activity": "standing"}],
                    "zone_summary": {f"room_{patients[patient_id].room_id}": 1}
                }

            alert = await patients[patient_id].process_pose_data(pose_data)
            if alert:
                generated_alerts.append(alert)

        # This will fail initially
        # Should have generated alerts
        assert len(generated_alerts) > 0

        # Send notifications for all alerts
        notification_tasks = [
            notification_system.send_alert_notification(alert)
            for alert in generated_alerts
        ]

        notification_results = await asyncio.gather(*notification_tasks)

        # Check notification prioritization
        notification_stats = notification_system.get_notification_stats()
        assert notification_stats["total_notifications"] > 0

        # Critical alerts should use more channels
        critical_notifications = [
            n for n in notification_system.notifications_sent
            if n["severity"] == "critical"
        ]

        if critical_notifications:
            # Critical alerts should be sent to multiple channels
            # (>= 1 because individual channel sends can randomly fail)
            critical_channels = set(n["channel"] for n in critical_notifications)
            assert len(critical_channels) >= 1
|
||||
|
||||
|
||||
class TestHealthcareSystemIntegration:
    """Test healthcare system integration scenarios.

    Builds an in-test facade (HealthcareMonitoringSystem) that wires
    patient monitors to the notification system, then drives it end to end.
    """

    @pytest.mark.asyncio
    async def test_end_to_end_healthcare_workflow_should_fail_initially(self):
        """Test complete end-to-end healthcare workflow - should fail initially."""
        # Setup complete healthcare monitoring system
        class HealthcareMonitoringSystem:
            def __init__(self):
                self.patient_monitors = {}  # patient_id -> MockPatientMonitor
                self.notification_system = MockHealthcareNotificationSystem()
                self.alert_history = []
                self.system_status = "operational"

            async def add_patient(self, patient_id: str, room_id: str) -> bool:
                """Add patient to monitoring system; False if already present."""
                if patient_id in self.patient_monitors:
                    return False

                monitor = MockPatientMonitor(patient_id, room_id)
                self.patient_monitors[patient_id] = monitor
                return await monitor.start_monitoring()

            async def process_pose_update(self, room_id: str, pose_data: Dict[str, Any]) -> List[HealthcareAlert]:
                """Process pose update for room; notifies on every alert."""
                alerts = []

                # Find patients in this room
                room_patients = [
                    (patient_id, monitor) for patient_id, monitor in self.patient_monitors.items()
                    if monitor.room_id == room_id
                ]

                for patient_id, monitor in room_patients:
                    alert = await monitor.process_pose_data(pose_data)
                    if alert:
                        alerts.append(alert)
                        self.alert_history.append(alert)

                        # Send notification
                        await self.notification_system.send_alert_notification(alert)

                return alerts

            def get_system_status(self) -> Dict[str, Any]:
                """Get overall system status."""
                return {
                    "system_status": self.system_status,
                    "total_patients": len(self.patient_monitors),
                    "active_monitors": sum(1 for m in self.patient_monitors.values() if m.is_monitoring),
                    "total_alerts": len(self.alert_history),
                    "notification_stats": self.notification_system.get_notification_stats()
                }

        healthcare_system = HealthcareMonitoringSystem()

        # Add patients to system
        patients = [
            ("patient_001", "room_101"),
            ("patient_002", "room_102"),
            ("patient_003", "room_103")
        ]

        for patient_id, room_id in patients:
            result = await healthcare_system.add_patient(patient_id, room_id)
            assert result is True

        # Simulate pose data updates for different rooms:
        # a fall, a normal frame, and an empty room.
        pose_updates = [
            ("room_101", {
                "persons": [{"person_id": "patient_001", "confidence": 0.9, "activity": "fallen"}],
                "zone_summary": {"room_101": 1}
            }),
            ("room_102", {
                "persons": [{"person_id": "patient_002", "confidence": 0.8, "activity": "standing"}],
                "zone_summary": {"room_102": 1}
            }),
            ("room_103", {
                "persons": [],  # No person detected
                "zone_summary": {"room_103": 0}
            })
        ]

        all_alerts = []
        for room_id, pose_data in pose_updates:
            alerts = await healthcare_system.process_pose_update(room_id, pose_data)
            all_alerts.extend(alerts)

        # This will fail initially
        # Should have processed all updates
        assert len(pose_updates) == 3

        # Check system status
        system_status = healthcare_system.get_system_status()
        assert system_status["total_patients"] == 3
        assert system_status["active_monitors"] == 3
        assert system_status["system_status"] == "operational"

        # Should have generated some alerts
        if all_alerts:
            assert len(all_alerts) > 0
            assert system_status["total_alerts"] > 0

    @pytest.mark.asyncio
    async def test_healthcare_system_resilience_should_fail_initially(self):
        """Test healthcare system resilience - should fail initially."""
        patient_monitor = MockPatientMonitor("patient_001", "room_101")
        notification_system = MockHealthcareNotificationSystem()

        await patient_monitor.start_monitoring()

        # Simulate system stress with rapid pose updates
        rapid_updates = 50
        alerts_generated = []

        for i in range(rapid_updates):
            # Alternate between normal and concerning pose data
            if i % 10 == 0:  # Every 10th update is concerning
                pose_data = {
                    "persons": [{"person_id": "patient_001", "confidence": 0.9, "activity": "fallen"}],
                    "zone_summary": {"room_101": 1}
                }
            else:
                pose_data = {
                    "persons": [{"person_id": "patient_001", "confidence": 0.85, "activity": "standing"}],
                    "zone_summary": {"room_101": 1}
                }

            alert = await patient_monitor.process_pose_data(pose_data)
            if alert:
                alerts_generated.append(alert)
                await notification_system.send_alert_notification(alert)

        # This will fail initially
        # System should handle rapid updates gracefully
        stats = patient_monitor.get_monitoring_stats()
        assert stats["activity_samples"] == rapid_updates
        assert stats["is_monitoring"] is True

        # Should have generated some alerts but not excessive
        assert len(alerts_generated) <= rapid_updates / 5  # At most 20% alert rate

        notification_stats = notification_system.get_notification_stats()
        assert notification_stats["total_notifications"] >= len(alerts_generated)
|
||||
661
tests/fixtures/api_client.py
vendored
Normal file
661
tests/fixtures/api_client.py
vendored
Normal file
@@ -0,0 +1,661 @@
|
||||
"""
|
||||
Test client utilities for API testing.
|
||||
|
||||
Provides mock and real API clients for comprehensive testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Union, AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import websockets
|
||||
import jwt
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
    """Error raised for authentication related failures."""
|
||||
|
||||
|
||||
class APIError(Exception):
    """Error raised for general API failures."""
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Error raised when the simulated rate limit is exceeded."""
|
||||
|
||||
|
||||
@dataclass
class APIResponse:
    """API response wrapper.

    Uniform return type for MockAPIClient requests.
    """
    status_code: int        # HTTP status code of the (mock) response
    data: Dict[str, Any]    # response payload
    headers: Dict[str, str]
    response_time_ms: float  # measured wall-clock request duration, milliseconds
    timestamp: datetime      # when the response object was built (naive UTC)
|
||||
|
||||
|
||||
class MockAPIClient:
|
||||
"""Mock API client for testing."""
|
||||
|
||||
    def __init__(self, base_url: str = "http://localhost:8000"):
        """Create a mock client; no I/O happens until connect() is called."""
        self.base_url = base_url
        self.session = None  # aiohttp client session, created lazily in connect()
        self.auth_token = None  # set by the /auth/login mock response
        self.refresh_token = None
        self.token_expires_at = None
        self.request_history = []  # record of every simulated request
        self.response_delays = {}  # endpoint -> artificial delay in ms
        self.error_simulation = {}  # endpoint -> {"type", "probability"}
        # Sliding one-minute window consumed by _check_rate_limit().
        self.rate_limit_config = {
            "enabled": False,
            "requests_per_minute": 60,
            "current_count": 0,
            "window_start": time.time()
        }
|
||||
|
||||
    async def __aenter__(self):
        """Async context manager entry: opens the underlying session."""
        await self.connect()
        return self
|
||||
|
||||
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: closes the underlying session."""
        await self.disconnect()
|
||||
|
||||
    async def connect(self):
        """Initialize connection by creating an aiohttp client session."""
        self.session = aiohttp.ClientSession()
|
||||
|
||||
    async def disconnect(self):
        """Close the underlying session, if one was opened."""
        if self.session:
            await self.session.close()
|
||||
|
||||
    def set_response_delay(self, endpoint: str, delay_ms: float):
        """Set artificial delay (milliseconds) applied to requests for *endpoint*."""
        self.response_delays[endpoint] = delay_ms
|
||||
|
||||
def simulate_error(self, endpoint: str, error_type: str, probability: float = 1.0):
|
||||
"""Simulate errors for endpoint."""
|
||||
self.error_simulation[endpoint] = {
|
||||
"type": error_type,
|
||||
"probability": probability
|
||||
}
|
||||
|
||||
def enable_rate_limiting(self, requests_per_minute: int = 60):
|
||||
"""Enable rate limiting simulation."""
|
||||
self.rate_limit_config.update({
|
||||
"enabled": True,
|
||||
"requests_per_minute": requests_per_minute,
|
||||
"current_count": 0,
|
||||
"window_start": time.time()
|
||||
})
|
||||
|
||||
async def _check_rate_limit(self):
|
||||
"""Check rate limiting."""
|
||||
if not self.rate_limit_config["enabled"]:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
window_duration = 60 # 1 minute
|
||||
|
||||
# Reset window if needed
|
||||
if current_time - self.rate_limit_config["window_start"] > window_duration:
|
||||
self.rate_limit_config["current_count"] = 0
|
||||
self.rate_limit_config["window_start"] = current_time
|
||||
|
||||
# Check limit
|
||||
if self.rate_limit_config["current_count"] >= self.rate_limit_config["requests_per_minute"]:
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
|
||||
self.rate_limit_config["current_count"] += 1
|
||||
|
||||
async def _simulate_network_delay(self, endpoint: str):
|
||||
"""Simulate network delay."""
|
||||
delay = self.response_delays.get(endpoint, 0)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay / 1000) # Convert ms to seconds
|
||||
|
||||
async def _check_error_simulation(self, endpoint: str):
|
||||
"""Check if error should be simulated."""
|
||||
if endpoint in self.error_simulation:
|
||||
config = self.error_simulation[endpoint]
|
||||
if random.random() < config["probability"]:
|
||||
error_type = config["type"]
|
||||
if error_type == "timeout":
|
||||
raise asyncio.TimeoutError("Simulated timeout")
|
||||
elif error_type == "connection":
|
||||
raise aiohttp.ClientConnectionError("Simulated connection error")
|
||||
elif error_type == "server_error":
|
||||
raise APIError("Simulated server error")
|
||||
|
||||
async def _make_request(self, method: str, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make HTTP request with simulation."""
|
||||
start_time = time.time()
|
||||
|
||||
# Check rate limiting
|
||||
await self._check_rate_limit()
|
||||
|
||||
# Simulate network delay
|
||||
await self._simulate_network_delay(endpoint)
|
||||
|
||||
# Check error simulation
|
||||
await self._check_error_simulation(endpoint)
|
||||
|
||||
# Record request
|
||||
request_record = {
|
||||
"method": method,
|
||||
"endpoint": endpoint,
|
||||
"timestamp": datetime.utcnow(),
|
||||
"kwargs": kwargs
|
||||
}
|
||||
self.request_history.append(request_record)
|
||||
|
||||
# Generate mock response
|
||||
response_data = await self._generate_mock_response(method, endpoint, kwargs)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = (end_time - start_time) * 1000
|
||||
|
||||
return APIResponse(
|
||||
status_code=response_data["status_code"],
|
||||
data=response_data["data"],
|
||||
headers=response_data.get("headers", {}),
|
||||
response_time_ms=response_time,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
|
||||
async def _generate_mock_response(self, method: str, endpoint: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate mock response based on endpoint."""
|
||||
if endpoint == "/health":
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/auth/login":
|
||||
if method == "POST":
|
||||
# Generate mock JWT tokens
|
||||
payload = {
|
||||
"user_id": "test_user",
|
||||
"exp": datetime.utcnow() + timedelta(hours=1)
|
||||
}
|
||||
access_token = jwt.encode(payload, "secret", algorithm="HS256")
|
||||
refresh_token = jwt.encode({"user_id": "test_user"}, "secret", algorithm="HS256")
|
||||
|
||||
self.auth_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
self.token_expires_at = payload["exp"]
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/auth/refresh":
|
||||
if method == "POST" and self.refresh_token:
|
||||
# Generate new access token
|
||||
payload = {
|
||||
"user_id": "test_user",
|
||||
"exp": datetime.utcnow() + timedelta(hours=1)
|
||||
}
|
||||
access_token = jwt.encode(payload, "secret", algorithm="HS256")
|
||||
|
||||
self.auth_token = access_token
|
||||
self.token_expires_at = payload["exp"]
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/pose/detect":
|
||||
if method == "POST":
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"persons": [
|
||||
{
|
||||
"person_id": "person_1",
|
||||
"confidence": 0.85,
|
||||
"bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
|
||||
"keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))],
|
||||
"activity": "standing"
|
||||
}
|
||||
],
|
||||
"processing_time_ms": 45.2,
|
||||
"model_version": "v1.0",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/config":
|
||||
if method == "GET":
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"model_config": {
|
||||
"confidence_threshold": 0.7,
|
||||
"nms_threshold": 0.5,
|
||||
"max_persons": 10
|
||||
},
|
||||
"processing_config": {
|
||||
"batch_size": 1,
|
||||
"use_gpu": True,
|
||||
"preprocessing": "standard"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Default response
|
||||
return {
|
||||
"status_code": 404,
|
||||
"data": {"error": "Endpoint not found"}
|
||||
}
|
||||
|
||||
async def get(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make GET request."""
|
||||
return await self._make_request("GET", endpoint, **kwargs)
|
||||
|
||||
async def post(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make POST request."""
|
||||
return await self._make_request("POST", endpoint, **kwargs)
|
||||
|
||||
async def put(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make PUT request."""
|
||||
return await self._make_request("PUT", endpoint, **kwargs)
|
||||
|
||||
async def delete(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make DELETE request."""
|
||||
return await self._make_request("DELETE", endpoint, **kwargs)
|
||||
|
||||
async def login(self, username: str, password: str) -> bool:
|
||||
"""Authenticate with API."""
|
||||
response = await self.post("/auth/login", json={
|
||||
"username": username,
|
||||
"password": password
|
||||
})
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise AuthenticationError("Login failed")
|
||||
|
||||
async def refresh_auth_token(self) -> bool:
|
||||
"""Refresh authentication token."""
|
||||
if not self.refresh_token:
|
||||
raise AuthenticationError("No refresh token available")
|
||||
|
||||
response = await self.post("/auth/refresh", json={
|
||||
"refresh_token": self.refresh_token
|
||||
})
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise AuthenticationError("Token refresh failed")
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if client is authenticated."""
|
||||
if not self.auth_token or not self.token_expires_at:
|
||||
return False
|
||||
|
||||
return datetime.utcnow() < self.token_expires_at
|
||||
|
||||
def get_request_history(self) -> List[Dict[str, Any]]:
|
||||
"""Get request history."""
|
||||
return self.request_history.copy()
|
||||
|
||||
def clear_request_history(self):
|
||||
"""Clear request history."""
|
||||
self.request_history.clear()
|
||||
|
||||
|
||||
class MockWebSocketClient:
    """In-memory stand-in for a WebSocket client.

    Records every outbound message, optionally synthesizes server replies
    (subscribe -> confirmation, pose_request -> canned pose data,
    ping -> pong), and lets tests inject inbound messages directly.
    """

    def __init__(self, uri: str = "ws://localhost:8000/ws"):
        self.uri = uri
        self.websocket = None          # placeholder for a real connection object
        self.is_connected = False
        self.messages_received = []    # pending inbound queue, oldest first
        self.messages_sent = []
        self.connection_errors = []
        self.auto_respond = True       # synthesize replies to sent messages
        self.response_delay = 0.01     # seconds before an auto-response lands

    async def connect(self) -> bool:
        """Pretend to open the connection; returns True on success."""
        try:
            await asyncio.sleep(0.01)  # mimic handshake latency
            self.is_connected = True
            return True
        except Exception as exc:
            self.connection_errors.append(str(exc))
            return False

    async def disconnect(self):
        """Drop the simulated connection."""
        self.is_connected = False
        self.websocket = None

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Record an outbound message and, if enabled, queue an auto-reply.

        Raises:
            ConnectionError: if the client is not connected.
        """
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        self.messages_sent.append({
            "message": message,
            "timestamp": datetime.utcnow(),
        })

        if self.auto_respond:
            await asyncio.sleep(self.response_delay)
            reply = await self._generate_auto_response(message)
            if reply:
                self.messages_received.append({
                    "message": reply,
                    "timestamp": datetime.utcnow(),
                })

        return True

    async def receive_message(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """Pop the oldest queued inbound message, polling until *timeout*.

        Returns None when nothing arrives in time.

        Raises:
            ConnectionError: if the client is not connected.
        """
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        deadline = time.time() + timeout
        while time.time() < deadline:
            if self.messages_received:
                return self.messages_received.pop(0)["message"]
            await asyncio.sleep(0.01)

        return None

    async def _generate_auto_response(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build a canned server reply for *message*, or None if unrecognized."""
        kind = message.get("type")

        if kind == "subscribe":
            return {
                "type": "subscription_confirmed",
                "channel": message.get("channel"),
                "timestamp": datetime.utcnow().isoformat(),
            }

        if kind == "pose_request":
            return {
                "type": "pose_data",
                "data": {
                    "persons": [
                        {
                            "person_id": "person_1",
                            "confidence": 0.88,
                            "bounding_box": {"x": 150, "y": 200, "width": 80, "height": 180},
                            "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))],
                        }
                    ],
                    "timestamp": datetime.utcnow().isoformat(),
                },
                "request_id": message.get("request_id"),
            }

        if kind == "ping":
            return {
                "type": "pong",
                "timestamp": datetime.utcnow().isoformat(),
            }

        return None

    def set_auto_respond(self, enabled: bool, delay_ms: float = 10):
        """Toggle auto-replies and set their latency in milliseconds."""
        self.auto_respond = enabled
        self.response_delay = delay_ms / 1000

    def inject_message(self, message: Dict[str, Any]):
        """Queue *message* as though the server had pushed it."""
        self.messages_received.append({
            "message": message,
            "timestamp": datetime.utcnow(),
        })

    def get_sent_messages(self) -> List[Dict[str, Any]]:
        """Snapshot of every message sent so far."""
        return self.messages_sent.copy()

    def get_received_messages(self) -> List[Dict[str, Any]]:
        """Snapshot of the inbound messages still queued (not yet consumed)."""
        return self.messages_received.copy()

    def clear_message_history(self):
        """Forget all sent messages and pending inbound messages."""
        self.messages_sent.clear()
        self.messages_received.clear()
class APITestClient:
    """High-level test client that pairs the mock HTTP client with a mock
    WebSocket client and offers workflow-level helpers (auth, streaming,
    load simulation, metrics).
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        # Bug fix: the previous base_url.replace("http", "ws") substituted
        # *every* "http" substring, mangling hosts/paths that happened to
        # contain "http" (e.g. "http://http-proxy" -> "ws://ws-proxy").
        # Rewrite only the scheme instead.
        scheme, sep, rest = base_url.partition("://")
        if sep:
            ws_scheme = "wss" if scheme == "https" else "ws"
            self.ws_url = f"{ws_scheme}://{rest}/ws"
        else:
            self.ws_url = base_url + "/ws"  # no scheme: just append the path
        self.http_client = MockAPIClient(base_url)
        self.ws_client = MockWebSocketClient(self.ws_url)
        self.test_session_id = None

    async def __aenter__(self):
        """Async context manager entry: connect both clients."""
        await self.setup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect both clients."""
        await self.teardown()

    async def setup(self):
        """Connect HTTP and WebSocket clients and mint a session id."""
        await self.http_client.connect()
        await self.ws_client.connect()
        self.test_session_id = f"test_session_{int(time.time())}"

    async def teardown(self):
        """Disconnect both clients (WebSocket first)."""
        await self.ws_client.disconnect()
        await self.http_client.disconnect()

    async def authenticate(self, username: str = "test_user", password: str = "test_pass") -> bool:
        """Log in via the mock HTTP client."""
        return await self.http_client.login(username, password)

    async def test_health_endpoint(self) -> APIResponse:
        """GET /health."""
        return await self.http_client.get("/health")

    async def test_pose_detection(self, csi_data: Dict[str, Any]) -> APIResponse:
        """POST CSI data to /pose/detect."""
        return await self.http_client.post("/pose/detect", json=csi_data)

    async def test_websocket_streaming(self, duration_seconds: int = 5) -> List[Dict[str, Any]]:
        """Subscribe to the pose stream and collect messages for *duration_seconds*."""
        await self.ws_client.send_message({
            "type": "subscribe",
            "channel": "pose_stream",
            "session_id": self.test_session_id,
        })

        messages = []
        end_time = time.time() + duration_seconds
        while time.time() < end_time:
            message = await self.ws_client.receive_message(timeout=0.1)
            if message:
                messages.append(message)
        return messages

    async def simulate_concurrent_requests(self, num_requests: int = 10) -> List[APIResponse]:
        """Fire *num_requests* GET /health calls in parallel.

        Exceptions are returned in-place (return_exceptions=True) rather
        than raised, so callers can inspect partial failures.
        """
        tasks = [
            asyncio.create_task(self.http_client.get("/health"))
            for _ in range(num_requests)
        ]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def simulate_websocket_load(self, num_connections: int = 5, duration_seconds: int = 3) -> Dict[str, Any]:
        """Open several WebSocket clients and ping from all of them concurrently.

        Returns per-connection and total message counts.
        """
        ws_clients = []
        for _ in range(num_connections):
            client = MockWebSocketClient(self.ws_url)
            await client.connect()
            ws_clients.append(client)

        message_counts = []
        try:
            tasks = [
                asyncio.create_task(
                    self._send_messages_for_duration(client, duration_seconds, idx))
                for idx, client in enumerate(ws_clients)
            ]
            message_counts = await asyncio.gather(*tasks)
        finally:
            # Always tear the extra connections down, even on failure.
            for client in ws_clients:
                await client.disconnect()

        return {
            "num_connections": num_connections,
            "duration_seconds": duration_seconds,
            "messages_per_connection": message_counts,
            "total_messages": sum(message_counts),
        }

    async def _send_messages_for_duration(self, client: MockWebSocketClient, duration: int, client_id: int) -> int:
        """Send pings from *client* at ~10/second for *duration* seconds."""
        message_count = 0
        end_time = time.time() + duration

        while time.time() < end_time:
            await client.send_message({
                "type": "ping",
                "client_id": client_id,
                "message_id": message_count,
            })
            message_count += 1
            await asyncio.sleep(0.1)  # 10 messages per second

        return message_count

    def configure_error_simulation(self, endpoint: str, error_type: str, probability: float = 0.1):
        """Forward error-injection configuration to the HTTP client."""
        self.http_client.simulate_error(endpoint, error_type, probability)

    def configure_rate_limiting(self, requests_per_minute: int = 60):
        """Forward rate-limit configuration to the HTTP client."""
        self.http_client.enable_rate_limiting(requests_per_minute)

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Summarize HTTP and WebSocket activity for this session.

        NOTE(review): response-time stats depend on history records carrying
        a "response_time_ms" key; records without it count as 0 ms.
        """
        http_history = self.http_client.get_request_history()
        ws_sent = self.ws_client.get_sent_messages()
        ws_received = self.ws_client.get_received_messages()

        if http_history:
            response_times = [r.get("response_time_ms", 0) for r in http_history]
            http_metrics = {
                "total_requests": len(http_history),
                "avg_response_time_ms": sum(response_times) / len(response_times),
                "min_response_time_ms": min(response_times),
                "max_response_time_ms": max(response_times),
            }
        else:
            http_metrics = {"total_requests": 0}

        ws_metrics = {
            "messages_sent": len(ws_sent),
            "messages_received": len(ws_received),
            "connection_active": self.ws_client.is_connected,
        }

        return {
            "session_id": self.test_session_id,
            "http_metrics": http_metrics,
            "websocket_metrics": ws_metrics,
            "timestamp": datetime.utcnow().isoformat(),
        }
# Utility functions for test data generation
|
||||
def generate_test_csi_data() -> Dict[str, Any]:
    """Build a randomized 4x64 CSI payload suitable for exercising the API."""
    import numpy as np

    antennas, subcarriers = 4, 64
    payload = {
        "timestamp": datetime.utcnow().isoformat(),
        "router_id": "test_router_001",
        "amplitude": np.random.uniform(0, 1, (antennas, subcarriers)).tolist(),
        "phase": np.random.uniform(-np.pi, np.pi, (antennas, subcarriers)).tolist(),
        "frequency": 5.8e9,
        "bandwidth": 80e6,
        "num_antennas": antennas,
        "num_subcarriers": subcarriers,
    }
    return payload
def create_test_user_credentials() -> Dict[str, str]:
    """Return a fixed set of throwaway credentials for test logins."""
    credentials = {
        "username": "test_user",
        "password": "test_password_123",
        "email": "test@example.com",
    }
    return credentials
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1) -> bool:
    """Poll *condition_func* until it returns truthy or *timeout* elapses.

    Args:
        condition_func: Zero-argument callable, sync or async, whose
            truthiness signals the condition.
        timeout: Maximum seconds to wait.
        interval: Seconds to sleep between polls.

    Returns:
        True if the condition became truthy within the timeout, else False.
    """
    is_async = asyncio.iscoroutinefunction(condition_func)
    deadline = time.time() + timeout

    while True:
        # Evaluate at least once, even when timeout <= 0 — the original
        # skipped the check entirely for a non-positive timeout and always
        # returned False, even if the condition was already satisfied.
        satisfied = await condition_func() if is_async else condition_func()
        if satisfied:
            return True
        if time.time() >= deadline:
            return False
        await asyncio.sleep(interval)
487
tests/fixtures/csi_data.py
vendored
Normal file
487
tests/fixtures/csi_data.py
vendored
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
Test data generation utilities for CSI data.
|
||||
|
||||
Provides realistic CSI data samples for testing pose estimation pipeline.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import json
|
||||
import random
|
||||
|
||||
|
||||
class CSIDataGenerator:
    """Synthesize realistic-looking CSI samples for pipeline tests.

    Supports empty-room baselines, single-person activities, multi-person
    mixtures, full time series, additive noise, and hardware artifacts.
    """

    def __init__(self,
                 frequency: float = 5.8e9,
                 bandwidth: float = 80e6,
                 num_antennas: int = 4,
                 num_subcarriers: int = 64):
        self.frequency = frequency
        self.bandwidth = bandwidth
        self.num_antennas = num_antennas
        self.num_subcarriers = num_subcarriers
        self.sample_rate = 1000  # samples per second
        self.noise_level = 0.1   # default additive-noise strength

        # Statistical fingerprints for each scenario, precomputed once.
        self._initialize_patterns()

    def _initialize_patterns(self):
        """Define the statistical fingerprints for each scenario."""
        # Baseline: nobody in the room.
        self.empty_room_pattern = {
            "amplitude_mean": 0.3,
            "amplitude_std": 0.05,
            "phase_variance": 0.1,
            "temporal_stability": 0.95,
        }

        # One occupant, keyed by activity.
        self.single_person_patterns = {
            "standing": {
                "amplitude_mean": 0.5,
                "amplitude_std": 0.08,
                "phase_variance": 0.2,
                "temporal_stability": 0.85,
                "movement_frequency": 0.1,
            },
            "walking": {
                "amplitude_mean": 0.6,
                "amplitude_std": 0.15,
                "phase_variance": 0.4,
                "temporal_stability": 0.6,
                "movement_frequency": 2.0,
            },
            "sitting": {
                "amplitude_mean": 0.4,
                "amplitude_std": 0.06,
                "phase_variance": 0.15,
                "temporal_stability": 0.9,
                "movement_frequency": 0.05,
            },
            "fallen": {
                "amplitude_mean": 0.35,
                "amplitude_std": 0.04,
                "phase_variance": 0.08,
                "temporal_stability": 0.95,
                "movement_frequency": 0.02,
            },
        }

        # Crowding multipliers, keyed by occupant count.
        self.multi_person_patterns = {
            2: {"amplitude_multiplier": 1.4, "phase_complexity": 1.6},
            3: {"amplitude_multiplier": 1.7, "phase_complexity": 2.1},
            4: {"amplitude_multiplier": 2.0, "phase_complexity": 2.8},
        }

    def generate_empty_room_sample(self, timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Produce one CSI sample for an unoccupied room."""
        if timestamp is None:
            timestamp = datetime.utcnow()

        pattern = self.empty_room_pattern
        shape = (self.num_antennas, self.num_subcarriers)

        amplitude = np.clip(
            np.random.normal(pattern["amplitude_mean"], pattern["amplitude_std"], shape),
            0, 1,
        )
        phase = np.random.uniform(-np.pi, np.pi, shape)

        # Blend with the previous sample so consecutive samples look correlated.
        if hasattr(self, '_last_empty_sample'):
            alpha = pattern["temporal_stability"]
            amplitude = alpha * self._last_empty_sample["amplitude"] + (1 - alpha) * amplitude
            phase = alpha * self._last_empty_sample["phase"] + (1 - alpha) * phase

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": "empty_room",
            "signal_quality": np.random.uniform(0.85, 0.95),
        }

        self._last_empty_sample = {"amplitude": amplitude, "phase": phase}
        return sample

    def generate_single_person_sample(self,
                                      activity: str = "standing",
                                      timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Produce one CSI sample for a lone occupant doing *activity*.

        Raises:
            ValueError: for an activity with no configured pattern.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        if activity not in self.single_person_patterns:
            raise ValueError(f"Unknown activity: {activity}")

        pattern = self.single_person_patterns[activity]
        shape = (self.num_antennas, self.num_subcarriers)

        amplitude = np.random.normal(
            pattern["amplitude_mean"], pattern["amplitude_std"], shape)

        # Sinusoidal modulation stands in for body movement at this activity's tempo.
        movement_freq = pattern["movement_frequency"]
        amplitude += 0.1 * np.sin(2 * np.pi * movement_freq * timestamp.timestamp())
        amplitude = np.clip(amplitude, 0, 1)

        phase = np.random.uniform(-np.pi, np.pi, shape)
        phase += np.random.normal(0, pattern["phase_variance"], shape)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi  # wrap to [-pi, pi]

        # Blend with the last sample of the same activity for temporal coherence.
        cache_attr = f'_last_{activity}_sample'
        if hasattr(self, cache_attr):
            alpha = pattern["temporal_stability"]
            previous = getattr(self, cache_attr)
            amplitude = alpha * previous["amplitude"] + (1 - alpha) * amplitude
            phase = alpha * previous["phase"] + (1 - alpha) * phase

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": f"single_person_{activity}",
            "signal_quality": np.random.uniform(0.7, 0.9),
            "activity": activity,
        }

        setattr(self, cache_attr, {"amplitude": amplitude, "phase": phase})
        return sample

    def generate_multi_person_sample(self,
                                     num_persons: int = 2,
                                     activities: Optional[List[str]] = None,
                                     timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Produce one CSI sample for 2-4 overlapping occupants.

        Raises:
            ValueError: if num_persons is outside 2-4, or the activities
                list does not match num_persons.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        if num_persons < 2 or num_persons > 4:
            raise ValueError("Number of persons must be between 2 and 4")

        if activities is None:
            activities = random.choices(
                list(self.single_person_patterns.keys()), k=num_persons)

        if len(activities) != num_persons:
            raise ValueError("Number of activities must match number of persons")

        shape = (self.num_antennas, self.num_subcarriers)

        # Empty-room baseline that each occupant perturbs.
        amplitude = np.random.normal(
            self.empty_room_pattern["amplitude_mean"],
            self.empty_room_pattern["amplitude_std"],
            shape,
        )
        phase = np.random.uniform(-np.pi, np.pi, shape)

        for i, activity in enumerate(activities):
            person_pattern = self.single_person_patterns[activity]

            # Attenuated (0.7x) single-person contribution.
            person_amplitude = np.random.normal(
                person_pattern["amplitude_mean"] * 0.7,
                person_pattern["amplitude_std"],
                shape,
            )

            # Roll subcarriers so each occupant sits at a distinct "location".
            spatial_offset = i * self.num_subcarriers // num_persons
            person_amplitude = np.roll(person_amplitude, spatial_offset, axis=1)

            # Per-person movement, phase-shifted so occupants don't move in sync.
            movement_freq = person_pattern["movement_frequency"]
            time_factor = timestamp.timestamp() + i * 0.5
            person_amplitude += 0.05 * np.sin(2 * np.pi * movement_freq * time_factor)

            amplitude += person_amplitude

            person_phase = np.random.normal(
                0, person_pattern["phase_variance"], shape)
            phase += np.roll(person_phase, spatial_offset, axis=1)

        # Crowding makes both amplitude and phase structure more complex.
        crowd = self.multi_person_patterns[num_persons]
        amplitude *= crowd["amplitude_multiplier"]
        phase *= crowd["phase_complexity"]

        amplitude = np.clip(amplitude, 0, 1)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi

        return {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": f"multi_person_{num_persons}",
            "signal_quality": np.random.uniform(0.6, 0.8),
            "num_persons": num_persons,
            "activities": activities,
        }

    def generate_time_series(self,
                             duration_seconds: int = 10,
                             scenario: str = "single_person_walking",
                             **kwargs) -> List[Dict[str, Any]]:
        """Produce duration_seconds * sample_rate consecutive samples.

        Raises:
            ValueError: for an unrecognized scenario string.
        """
        samples = []
        start_time = datetime.utcnow()

        for i in range(duration_seconds * self.sample_rate):
            timestamp = start_time + timedelta(seconds=i / self.sample_rate)

            if scenario == "empty_room":
                sample = self.generate_empty_room_sample(timestamp)
            elif scenario.startswith("single_person_"):
                activity = scenario.replace("single_person_", "")
                sample = self.generate_single_person_sample(activity, timestamp)
            elif scenario.startswith("multi_person_"):
                num_persons = int(scenario.split("_")[-1])
                sample = self.generate_multi_person_sample(
                    num_persons, timestamp=timestamp, **kwargs)
            else:
                raise ValueError(f"Unknown scenario: {scenario}")

            samples.append(sample)

        return samples

    def add_noise(self, sample: Dict[str, Any], noise_level: Optional[float] = None) -> Dict[str, Any]:
        """Return a copy of *sample* with Gaussian noise on amplitude and phase.

        Signal quality is scaled down proportionally to the noise level.
        """
        if noise_level is None:
            noise_level = self.noise_level

        noisy_sample = sample.copy()

        amplitude = np.array(sample["amplitude"])
        amplitude = amplitude + np.random.normal(0, noise_level, amplitude.shape)
        noisy_sample["amplitude"] = np.clip(amplitude, 0, 1).tolist()

        phase = np.array(sample["phase"])
        phase = phase + np.random.normal(0, noise_level * np.pi, phase.shape)
        noisy_sample["phase"] = (np.mod(phase + np.pi, 2 * np.pi) - np.pi).tolist()

        noisy_sample["signal_quality"] *= (1 - noise_level)
        return noisy_sample

    def simulate_hardware_artifacts(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *sample* with antenna coupling, frequency-dependent
        gain ripple, and per-subcarrier phase drift applied."""
        artifact_sample = sample.copy()

        amplitude = np.array(sample["amplitude"])
        phase = np.array(sample["phase"])

        # Cross-antenna leakage: mix rows with a near-identity random matrix.
        coupling_matrix = np.random.uniform(
            0.95, 1.05, (self.num_antennas, self.num_antennas))
        amplitude = coupling_matrix @ amplitude

        # Gain ripple across the band.
        freq_response = 1 + 0.1 * np.sin(np.linspace(0, 2 * np.pi, self.num_subcarriers))
        amplitude *= freq_response[np.newaxis, :]

        # Linear phase drift across subcarriers.
        phase_drift = np.random.uniform(-0.1, 0.1) * np.arange(self.num_subcarriers)
        phase += phase_drift[np.newaxis, :]

        artifact_sample["amplitude"] = np.clip(amplitude, 0, 1).tolist()
        artifact_sample["phase"] = (np.mod(phase + np.pi, 2 * np.pi) - np.pi).tolist()

        return artifact_sample
# Convenience functions for common test scenarios
|
||||
def generate_fall_detection_sequence() -> List[Dict[str, Any]]:
    """Build a CSI time series that walks through a fall:
    standing -> walking -> fall transition -> lying still."""
    generator = CSIDataGenerator()

    phases = [
        (5, "single_person_standing"),  # baseline upright activity
        (3, "single_person_walking"),
        (1, "single_person_fallen"),    # the fall transition itself
        (3, "single_person_fallen"),    # post-fall immobility
    ]

    sequence: List[Dict[str, Any]] = []
    for seconds, scenario in phases:
        sequence.extend(generator.generate_time_series(seconds, scenario))
    return sequence
def generate_multi_person_scenario() -> List[Dict[str, Any]]:
    """Build a CSI time series of a room filling up: empty -> 1 -> 2 -> 3 people."""
    generator = CSIDataGenerator()

    sequence: List[Dict[str, Any]] = []
    sequence.extend(generator.generate_time_series(2, "empty_room"))
    sequence.extend(generator.generate_time_series(3, "single_person_walking"))
    sequence.extend(generator.generate_time_series(
        5, "multi_person_2", activities=["standing", "walking"]))
    sequence.extend(generator.generate_time_series(
        4, "multi_person_3", activities=["standing", "walking", "sitting"]))
    return sequence
def generate_noisy_environment_data() -> List[Dict[str, Any]]:
    """Generate CSI data corrupted at several noise levels.

    A clean walking sequence is generated once, then the first ten samples
    are re-emitted at each noise level.

    Returns:
        List of noisy CSI samples, grouped by increasing noise level.
    """
    generator = CSIDataGenerator()

    # Clean reference data to corrupt.
    clean_samples = generator.generate_time_series(5, "single_person_walking")

    # For each level, corrupt the first 10 clean samples.
    noise_levels = (0.05, 0.1, 0.2, 0.3)
    return [
        generator.add_noise(sample, level)
        for level in noise_levels
        for sample in clean_samples[:10]
    ]
|
||||
|
||||
|
||||
def generate_hardware_test_data() -> List[Dict[str, Any]]:
    """Generate CSI data with simulated hardware artifacts.

    Returns:
        List of standing-scenario CSI samples, each passed through the
        generator's hardware-artifact simulation.
    """
    generator = CSIDataGenerator()

    # Base samples to distort.
    base_samples = generator.generate_time_series(3, "single_person_standing")

    # Inject artifacts into every base sample.
    return [generator.simulate_hardware_artifacts(s) for s in base_samples]
|
||||
|
||||
|
||||
# Test data validation utilities
|
||||
def validate_csi_sample(sample: Dict[str, Any]) -> bool:
    """Validate CSI sample structure and data ranges.

    Checks that all required fields are present, that the amplitude and
    phase arrays match the advertised (num_antennas, num_subcarriers)
    shape, and that values fall in the expected ranges (amplitude in
    [0, 1], phase in [-pi, pi]).

    Args:
        sample: CSI sample dictionary (as produced by the test generators).

    Returns:
        True if the sample is structurally valid, False otherwise.
        Malformed payloads return False instead of raising.
    """
    required_fields = [
        "timestamp", "router_id", "amplitude", "phase",
        "frequency", "bandwidth", "num_antennas", "num_subcarriers",
    ]

    # Check required fields.
    if any(field not in sample for field in required_fields):
        return False

    # Convert to float arrays. Ragged or non-numeric payloads raise here;
    # treat them as invalid rather than letting the exception escape.
    try:
        amplitude = np.asarray(sample["amplitude"], dtype=float)
        phase = np.asarray(sample["phase"], dtype=float)
    except (TypeError, ValueError):
        return False

    # Check shapes against the advertised antenna/subcarrier counts.
    expected_shape = (sample["num_antennas"], sample["num_subcarriers"])
    if amplitude.shape != expected_shape or phase.shape != expected_shape:
        return False

    # Empty arrays have no min/max (np.min raises on size 0); reject them.
    if amplitude.size == 0 or phase.size == 0:
        return False

    # Check value ranges.
    if not (0 <= amplitude.min() and amplitude.max() <= 1):
        return False

    if not (-np.pi <= phase.min() and phase.max() <= np.pi):
        return False

    return True
|
||||
|
||||
|
||||
def extract_features_from_csi(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Extract scalar features from a CSI sample for testing.

    Args:
        sample: CSI sample with "amplitude" and "phase" 2-D arrays
            (antennas x subcarriers).

    Returns:
        Dict of float-valued summary statistics over amplitude and phase.
    """
    amp = np.asarray(sample["amplitude"])
    ph = np.asarray(sample["phase"])

    return {
        "amplitude_mean": float(amp.mean()),
        "amplitude_std": float(amp.std()),
        "amplitude_max": float(amp.max()),
        "amplitude_min": float(amp.min()),
        "phase_variance": float(ph.var()),
        "phase_range": float(ph.max() - ph.min()),
        # Total power across all antenna/subcarrier cells.
        "signal_energy": float((amp ** 2).sum()),
        # Magnitude of the mean phasor: 1.0 when all phases agree.
        "phase_coherence": float(abs(np.exp(1j * ph).mean())),
        # Mean pairwise correlation between antenna rows.
        "spatial_correlation": float(np.corrcoef(amp).mean()),
        # Spread of the per-subcarrier mean amplitude.
        "frequency_diversity": float(amp.mean(axis=0).std()),
    }
|
||||
338
tests/integration/test_api_endpoints.py
Normal file
338
tests/integration/test_api_endpoints.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Integration tests for WiFi-DensePose API endpoints.
|
||||
|
||||
Tests all REST API endpoints with real service dependencies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
import httpx
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_pose_service,
|
||||
get_stream_service,
|
||||
get_hardware_service,
|
||||
get_current_user
|
||||
)
|
||||
from src.api.routers.health import router as health_router
|
||||
from src.api.routers.pose import router as pose_router
|
||||
from src.api.routers.stream import router as stream_router
|
||||
|
||||
|
||||
class TestAPIEndpoints:
    """Integration tests for API endpoints.

    All service dependencies are replaced with AsyncMock instances via
    FastAPI's dependency_overrides, so only routing/serialization is
    exercised. Test names end in "should_fail_initially" — this is a
    TDD convention in this suite, not an xfail marker.
    """

    @pytest.fixture
    def app(self):
        """Create FastAPI app with test dependencies."""
        app = FastAPI()
        app.include_router(health_router, prefix="/health", tags=["health"])
        app.include_router(pose_router, prefix="/pose", tags=["pose"])
        app.include_router(stream_router, prefix="/stream", tags=["stream"])
        return app

    @pytest.fixture
    def mock_pose_service(self):
        """Mock pose service with healthy status and an empty pose result."""
        service = AsyncMock()
        service.health_check.return_value = {
            "status": "healthy",
            "message": "Service operational",
            "uptime_seconds": 3600.0,
            "metrics": {"processed_frames": 1000}
        }
        service.is_ready.return_value = True
        service.estimate_poses.return_value = {
            "timestamp": datetime.utcnow(),
            "frame_id": "test-frame-001",
            "persons": [],
            "zone_summary": {"zone1": 0},
            "processing_time_ms": 50.0,
            "metadata": {}
        }
        return service

    @pytest.fixture
    def mock_stream_service(self):
        """Mock stream service reporting healthy and active."""
        service = AsyncMock()
        service.health_check.return_value = {
            "status": "healthy",
            "message": "Stream service operational",
            "uptime_seconds": 1800.0
        }
        service.is_ready.return_value = True
        service.get_status.return_value = {
            "is_active": True,
            "active_streams": [],
            "uptime_seconds": 1800.0
        }
        service.is_active.return_value = True
        return service

    @pytest.fixture
    def mock_hardware_service(self):
        """Mock hardware service reporting three connected routers."""
        service = AsyncMock()
        service.health_check.return_value = {
            "status": "healthy",
            "message": "Hardware connected",
            "uptime_seconds": 7200.0,
            "metrics": {"connected_routers": 3}
        }
        service.is_ready.return_value = True
        return service

    @pytest.fixture
    def mock_user(self):
        """Mock authenticated (non-admin) user with read/write permissions."""
        return {
            "id": "test-user-001",
            "username": "testuser",
            "email": "test@example.com",
            "is_admin": False,
            "is_active": True,
            "permissions": ["read", "write"]
        }

    @pytest.fixture
    def client(self, app, mock_pose_service, mock_stream_service, mock_hardware_service, mock_user):
        """Create test client with all service and auth dependencies mocked."""
        app.dependency_overrides[get_pose_service] = lambda: mock_pose_service
        app.dependency_overrides[get_stream_service] = lambda: mock_stream_service
        app.dependency_overrides[get_hardware_service] = lambda: mock_hardware_service
        app.dependency_overrides[get_current_user] = lambda: mock_user

        # Context manager triggers the app's startup/shutdown events.
        with TestClient(app) as client:
            yield client

    def test_health_check_endpoint_should_fail_initially(self, client):
        """Test health check endpoint - should fail initially."""
        # This test should fail because we haven't implemented the endpoint properly
        response = client.get("/health/health")

        # This assertion will fail initially, driving us to implement the endpoint
        assert response.status_code == 200
        assert "status" in response.json()
        assert "components" in response.json()
        assert "system_metrics" in response.json()

    def test_readiness_check_endpoint_should_fail_initially(self, client):
        """Test readiness check endpoint - should fail initially."""
        response = client.get("/health/ready")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "ready" in data
        assert "checks" in data
        assert isinstance(data["checks"], dict)

    def test_liveness_check_endpoint_should_fail_initially(self, client):
        """Test liveness check endpoint - should fail initially."""
        response = client.get("/health/live")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "status" in data
        assert data["status"] == "alive"

    def test_version_info_endpoint_should_fail_initially(self, client):
        """Test version info endpoint - should fail initially."""
        response = client.get("/health/version")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "name" in data
        assert "version" in data
        assert "environment" in data

    def test_pose_current_endpoint_should_fail_initially(self, client):
        """Test current pose estimation endpoint - should fail initially."""
        response = client.get("/pose/current")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "timestamp" in data
        assert "frame_id" in data
        assert "persons" in data
        assert "zone_summary" in data

    def test_pose_analyze_endpoint_should_fail_initially(self, client):
        """Test pose analysis endpoint - should fail initially."""
        request_data = {
            "zone_ids": ["zone1", "zone2"],
            "confidence_threshold": 0.7,
            "max_persons": 10,
            "include_keypoints": True,
            "include_segmentation": False
        }

        response = client.post("/pose/analyze", json=request_data)

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "timestamp" in data
        assert "persons" in data

    def test_zone_occupancy_endpoint_should_fail_initially(self, client):
        """Test zone occupancy endpoint - should fail initially."""
        response = client.get("/pose/zones/zone1/occupancy")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "zone_id" in data
        assert "current_occupancy" in data

    def test_zones_summary_endpoint_should_fail_initially(self, client):
        """Test zones summary endpoint - should fail initially."""
        response = client.get("/pose/zones/summary")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "total_persons" in data
        assert "zones" in data

    def test_stream_status_endpoint_should_fail_initially(self, client):
        """Test stream status endpoint - should fail initially."""
        response = client.get("/stream/status")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "is_active" in data
        assert "connected_clients" in data

    def test_stream_start_endpoint_should_fail_initially(self, client):
        """Test stream start endpoint - should fail initially."""
        response = client.post("/stream/start")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "message" in data

    def test_stream_stop_endpoint_should_fail_initially(self, client):
        """Test stream stop endpoint - should fail initially."""
        response = client.post("/stream/stop")

        # This will fail initially
        assert response.status_code == 200
        data = response.json()
        assert "message" in data
|
||||
|
||||
|
||||
class TestAPIErrorHandling:
    """Test API error handling scenarios.

    Verifies that the health endpoint degrades gracefully (still 200,
    but reports "unhealthy") when an underlying service raises.
    """

    @pytest.fixture
    def app_with_failing_services(self):
        """Create app with failing service dependencies."""
        app = FastAPI()
        app.include_router(health_router, prefix="/health", tags=["health"])
        app.include_router(pose_router, prefix="/pose", tags=["pose"])

        # Mock failing services: every health_check call raises.
        failing_pose_service = AsyncMock()
        failing_pose_service.health_check.side_effect = Exception("Service unavailable")

        app.dependency_overrides[get_pose_service] = lambda: failing_pose_service

        return app

    def test_health_check_with_failing_service_should_fail_initially(self, app_with_failing_services):
        """Test health check with failing service - should fail initially."""
        with TestClient(app_with_failing_services) as client:
            response = client.get("/health/health")

            # This will fail initially
            # Health endpoint is expected to report failures in-band (200 +
            # "unhealthy" status) rather than returning a 5xx.
            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "unhealthy"
            assert "hardware" in data["components"]
            assert data["components"]["pose"]["status"] == "unhealthy"
|
||||
|
||||
|
||||
class TestAPIAuthentication:
    """Test API authentication scenarios.

    Overrides get_current_user with a canned admin identity, so this only
    checks that authenticated requests reach the endpoint — not the token
    validation path itself (covered in the JWT tests).
    """

    @pytest.fixture
    def app_with_auth(self):
        """Create app with authentication enabled."""
        app = FastAPI()
        app.include_router(pose_router, prefix="/pose", tags=["pose"])

        # Mock authenticated user dependency (admin with all permissions).
        def get_authenticated_user():
            return {
                "id": "auth-user-001",
                "username": "authuser",
                "is_admin": True,
                "permissions": ["read", "write", "admin"]
            }

        app.dependency_overrides[get_current_user] = get_authenticated_user

        return app

    def test_authenticated_endpoint_access_should_fail_initially(self, app_with_auth):
        """Test authenticated endpoint access - should fail initially."""
        with TestClient(app_with_auth) as client:
            response = client.post("/pose/analyze", json={
                "confidence_threshold": 0.8,
                "include_keypoints": True
            })

            # This will fail initially
            # NOTE(review): pose service dependency is NOT overridden here,
            # so this presumably relies on the real get_pose_service — confirm.
            assert response.status_code == 200
|
||||
|
||||
|
||||
class TestAPIValidation:
    """Test API request validation (Pydantic model constraints → HTTP 422)."""

    @pytest.fixture
    def validation_app(self):
        """Create app for validation testing."""
        app = FastAPI()
        app.include_router(pose_router, prefix="/pose", tags=["pose"])

        # Mock service — validation happens before the service is called.
        mock_service = AsyncMock()
        app.dependency_overrides[get_pose_service] = lambda: mock_service

        return app

    def test_invalid_confidence_threshold_should_fail_initially(self, validation_app):
        """Test invalid confidence threshold validation - should fail initially."""
        with TestClient(validation_app) as client:
            response = client.post("/pose/analyze", json={
                "confidence_threshold": 1.5,  # Invalid: > 1.0
                "include_keypoints": True
            })

            # This will fail initially
            assert response.status_code == 422
            # NOTE(review): Pydantic error "msg" text is version-dependent and
            # typically does not contain the phrase "validation error" —
            # confirm this substring against the installed Pydantic version.
            assert "validation error" in response.json()["detail"][0]["msg"].lower()

    def test_invalid_max_persons_should_fail_initially(self, validation_app):
        """Test invalid max_persons validation - should fail initially."""
        with TestClient(validation_app) as client:
            response = client.post("/pose/analyze", json={
                "max_persons": 0,  # Invalid: < 1
                "include_keypoints": True
            })

            # This will fail initially
            assert response.status_code == 422
|
||||
571
tests/integration/test_authentication.py
Normal file
571
tests/integration/test_authentication.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""
|
||||
Integration tests for authentication and authorization.
|
||||
|
||||
Tests JWT authentication flow, user permissions, and access control.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import jwt
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
|
||||
class MockJWTToken:
    """Lightweight JWT helper for tests.

    Encodes the given payload once at construction time with HS256 and
    exposes the encoded token plus a decode helper.
    """

    def __init__(self, payload: Dict[str, Any], secret: str = "test-secret"):
        self.secret = secret
        self.payload = payload
        # Pre-compute the encoded token so tests can grab it directly.
        self.token = jwt.encode(payload, secret, algorithm="HS256")

    def decode(self, token: str, secret: str) -> Dict[str, Any]:
        """Decode *token* with *secret* and return its claims."""
        return jwt.decode(token, secret, algorithms=["HS256"])
|
||||
|
||||
|
||||
class TestJWTAuthentication:
    """Test JWT authentication functionality.

    Uses a self-contained MockJWTService (HS256, fixed test secret) so
    token create/verify/refresh behavior can be exercised without the
    real auth stack.
    """

    @pytest.fixture
    def valid_user_payload(self):
        """Valid user payload for JWT token (expires one hour from now)."""
        return {
            "sub": "user-001",
            "username": "testuser",
            "email": "test@example.com",
            "is_admin": False,
            "is_active": True,
            "permissions": ["read", "write"],
            "exp": datetime.utcnow() + timedelta(hours=1),
            "iat": datetime.utcnow()
        }

    @pytest.fixture
    def admin_user_payload(self):
        """Admin user payload for JWT token."""
        return {
            "sub": "admin-001",
            "username": "admin",
            "email": "admin@example.com",
            "is_admin": True,
            "is_active": True,
            "permissions": ["read", "write", "admin"],
            "exp": datetime.utcnow() + timedelta(hours=1),
            "iat": datetime.utcnow()
        }

    @pytest.fixture
    def expired_user_payload(self):
        """Expired user payload for JWT token (exp one hour in the past)."""
        return {
            "sub": "user-002",
            "username": "expireduser",
            "email": "expired@example.com",
            "is_admin": False,
            "is_active": True,
            "permissions": ["read"],
            "exp": datetime.utcnow() - timedelta(hours=1),  # Expired
            "iat": datetime.utcnow() - timedelta(hours=2)
        }

    @pytest.fixture
    def mock_jwt_service(self):
        """Mock JWT service with create/verify/refresh operations."""
        class MockJWTService:
            def __init__(self):
                # Fixed test credentials; never use in production code.
                self.secret = "test-secret-key"
                self.algorithm = "HS256"

            def create_token(self, user_data: Dict[str, Any]) -> str:
                """Create JWT token with fresh exp/iat claims."""
                payload = {
                    **user_data,
                    "exp": datetime.utcnow() + timedelta(hours=1),
                    "iat": datetime.utcnow()
                }
                return jwt.encode(payload, self.secret, algorithm=self.algorithm)

            def verify_token(self, token: str) -> Dict[str, Any]:
                """Verify JWT token, mapping JWT errors to HTTP 401."""
                try:
                    payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
                    return payload
                except jwt.ExpiredSignatureError:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Token has expired"
                    )
                except jwt.InvalidTokenError:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Invalid token"
                    )

            def refresh_token(self, token: str) -> str:
                """Refresh JWT token (re-issue with new exp/iat)."""
                payload = self.verify_token(token)
                # Remove exp and iat for new token
                payload.pop("exp", None)
                payload.pop("iat", None)
                return self.create_token(payload)

        return MockJWTService()

    def test_jwt_token_creation_should_fail_initially(self, mock_jwt_service, valid_user_payload):
        """Test JWT token creation - should fail initially."""
        token = mock_jwt_service.create_token(valid_user_payload)

        # This will fail initially
        assert isinstance(token, str)
        assert len(token) > 0

        # Verify token can be decoded (round-trip check)
        decoded = mock_jwt_service.verify_token(token)
        assert decoded["sub"] == valid_user_payload["sub"]
        assert decoded["username"] == valid_user_payload["username"]

    def test_jwt_token_verification_should_fail_initially(self, mock_jwt_service, valid_user_payload):
        """Test JWT token verification - should fail initially."""
        token = mock_jwt_service.create_token(valid_user_payload)
        decoded = mock_jwt_service.verify_token(token)

        # This will fail initially
        assert decoded["sub"] == valid_user_payload["sub"]
        assert decoded["is_admin"] == valid_user_payload["is_admin"]
        assert "exp" in decoded
        assert "iat" in decoded

    def test_expired_token_rejection_should_fail_initially(self, mock_jwt_service, expired_user_payload):
        """Test expired token rejection - should fail initially."""
        # Create token with expired payload, bypassing the service's
        # create_token (which would overwrite exp with a fresh value).
        token = jwt.encode(expired_user_payload, mock_jwt_service.secret, algorithm=mock_jwt_service.algorithm)

        # This should fail initially
        with pytest.raises(HTTPException) as exc_info:
            mock_jwt_service.verify_token(token)

        assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
        assert "expired" in exc_info.value.detail.lower()

    def test_invalid_token_rejection_should_fail_initially(self, mock_jwt_service):
        """Test invalid token rejection - should fail initially."""
        invalid_token = "invalid.jwt.token"

        # This should fail initially
        with pytest.raises(HTTPException) as exc_info:
            mock_jwt_service.verify_token(invalid_token)

        assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
        assert "invalid" in exc_info.value.detail.lower()

    def test_token_refresh_should_fail_initially(self, mock_jwt_service, valid_user_payload):
        """Test token refresh functionality - should fail initially."""
        original_token = mock_jwt_service.create_token(original_payload := valid_user_payload) if False else mock_jwt_service.create_token(valid_user_payload)

        # Wait a moment to ensure different timestamps
        # NOTE(review): iat has one-second resolution in JWT; a 0.1s sleep
        # may not guarantee distinct iat values — potential flake.
        import time
        time.sleep(0.1)

        refreshed_token = mock_jwt_service.refresh_token(original_token)

        # This will fail initially
        assert refreshed_token != original_token

        # Verify both tokens are valid but have different timestamps
        original_payload = mock_jwt_service.verify_token(original_token)
        refreshed_payload = mock_jwt_service.verify_token(refreshed_token)

        assert original_payload["sub"] == refreshed_payload["sub"]
        assert original_payload["iat"] != refreshed_payload["iat"]
|
||||
|
||||
|
||||
class TestUserAuthentication:
    """Test user authentication scenarios.

    Uses an in-memory MockUserService with two canned accounts
    ("testuser" and "admin"); the only accepted password is the literal
    "correct_password".
    """

    @pytest.fixture
    def mock_user_service(self):
        """Mock user service."""
        class MockUserService:
            def __init__(self):
                # In-memory user store keyed by username.
                self.users = {
                    "testuser": {
                        "id": "user-001",
                        "username": "testuser",
                        "email": "test@example.com",
                        "password_hash": "hashed_password",
                        "is_admin": False,
                        "is_active": True,
                        "permissions": ["read", "write"],
                        "zones": ["zone1", "zone2"],
                        "created_at": datetime.utcnow()
                    },
                    "admin": {
                        "id": "admin-001",
                        "username": "admin",
                        "email": "admin@example.com",
                        "password_hash": "admin_hashed_password",
                        "is_admin": True,
                        "is_active": True,
                        "permissions": ["read", "write", "admin"],
                        "zones": [],  # Admin has access to all zones
                        "created_at": datetime.utcnow()
                    }
                }

            async def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
                """Authenticate user with username and password."""
                user = self.users.get(username)
                if not user:
                    return None

                # Mock password verification: a single magic password passes.
                if password == "correct_password":
                    return user
                return None

            async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
                """Get user by ID (linear scan over the small test store)."""
                for user in self.users.values():
                    if user["id"] == user_id:
                        return user
                return None

            async def update_user_activity(self, user_id: str):
                """Update user last activity timestamp in place."""
                user = await self.get_user_by_id(user_id)
                if user:
                    user["last_activity"] = datetime.utcnow()

        return MockUserService()

    @pytest.mark.asyncio
    async def test_user_authentication_success_should_fail_initially(self, mock_user_service):
        """Test successful user authentication - should fail initially."""
        user = await mock_user_service.authenticate_user("testuser", "correct_password")

        # This will fail initially
        assert user is not None
        assert user["username"] == "testuser"
        assert user["is_active"] is True
        assert "read" in user["permissions"]

    @pytest.mark.asyncio
    async def test_user_authentication_failure_should_fail_initially(self, mock_user_service):
        """Test failed user authentication - should fail initially."""
        user = await mock_user_service.authenticate_user("testuser", "wrong_password")

        # This will fail initially
        assert user is None

        # Test with non-existent user
        user = await mock_user_service.authenticate_user("nonexistent", "any_password")
        assert user is None

    @pytest.mark.asyncio
    async def test_admin_user_authentication_should_fail_initially(self, mock_user_service):
        """Test admin user authentication - should fail initially."""
        admin = await mock_user_service.authenticate_user("admin", "correct_password")

        # This will fail initially
        assert admin is not None
        assert admin["is_admin"] is True
        assert "admin" in admin["permissions"]
        assert admin["zones"] == []  # Admin has access to all zones
|
||||
|
||||
|
||||
class TestAuthorizationDependencies:
    """Test authorization dependency functions.

    Each test defines a local mock of the corresponding FastAPI dependency
    (get_current_user, active-user guard, admin guard, permission check)
    and asserts its accept/reject behavior directly.
    """

    @pytest.fixture
    def mock_request(self):
        """Mock FastAPI request with an empty user on request.state."""
        class MockRequest:
            def __init__(self):
                self.state = MagicMock()
                self.state.user = None

        return MockRequest()

    @pytest.fixture
    def mock_credentials(self):
        """Factory for Bearer-scheme HTTP authorization credentials."""
        def create_credentials(token: str):
            return HTTPAuthorizationCredentials(
                scheme="Bearer",
                credentials=token
            )
        return create_credentials

    @pytest.mark.asyncio
    async def test_get_current_user_with_valid_token_should_fail_initially(self, mock_request, mock_credentials):
        """Test get_current_user with valid token - should fail initially."""
        # Mock the get_current_user dependency: only the literal token
        # "valid_token" resolves to a user.
        async def mock_get_current_user(request, credentials):
            if not credentials:
                return None

            # Mock token validation
            if credentials.credentials == "valid_token":
                return {
                    "id": "user-001",
                    "username": "testuser",
                    "is_admin": False,
                    "is_active": True,
                    "permissions": ["read", "write"]
                }
            return None

        credentials = mock_credentials("valid_token")
        user = await mock_get_current_user(mock_request, credentials)

        # This will fail initially
        assert user is not None
        assert user["username"] == "testuser"
        assert user["is_active"] is True

    @pytest.mark.asyncio
    async def test_get_current_user_without_credentials_should_fail_initially(self, mock_request):
        """Test get_current_user without credentials - should fail initially."""
        async def mock_get_current_user(request, credentials):
            if not credentials:
                return None
            return {"id": "user-001"}

        user = await mock_get_current_user(mock_request, None)

        # This will fail initially
        assert user is None

    @pytest.mark.asyncio
    async def test_require_active_user_should_fail_initially(self):
        """Test require active user dependency - should fail initially."""
        async def mock_get_current_active_user(current_user):
            # No user at all -> 401 (not authenticated).
            if not current_user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Authentication required"
                )

            # Authenticated but deactivated -> 403.
            if not current_user.get("is_active", True):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Inactive user"
                )

            return current_user

        # Test with active user
        active_user = {"id": "user-001", "is_active": True}
        result = await mock_get_current_active_user(active_user)

        # This will fail initially
        assert result == active_user

        # Test with inactive user
        inactive_user = {"id": "user-002", "is_active": False}
        with pytest.raises(HTTPException) as exc_info:
            await mock_get_current_active_user(inactive_user)

        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

        # Test with no user
        with pytest.raises(HTTPException) as exc_info:
            await mock_get_current_active_user(None)

        assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED

    @pytest.mark.asyncio
    async def test_require_admin_user_should_fail_initially(self):
        """Test require admin user dependency - should fail initially."""
        async def mock_get_admin_user(current_user):
            if not current_user.get("is_admin", False):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Admin privileges required"
                )
            return current_user

        # Test with admin user
        admin_user = {"id": "admin-001", "is_admin": True}
        result = await mock_get_admin_user(admin_user)

        # This will fail initially
        assert result == admin_user

        # Test with regular user
        regular_user = {"id": "user-001", "is_admin": False}
        with pytest.raises(HTTPException) as exc_info:
            await mock_get_admin_user(regular_user)

        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

    @pytest.mark.asyncio
    async def test_permission_checking_should_fail_initially(self):
        """Test permission checking functionality - should fail initially."""
        # Factory returning a dependency that requires one named permission.
        def require_permission(permission: str):
            async def check_permission(current_user):
                user_permissions = current_user.get("permissions", [])

                # Admin users have all permissions
                if current_user.get("is_admin", False):
                    return current_user

                # Check specific permission
                if permission not in user_permissions:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Permission '{permission}' required"
                    )

                return current_user

            return check_permission

        # Test with user having required permission
        user_with_permission = {
            "id": "user-001",
            "permissions": ["read", "write"],
            "is_admin": False
        }

        check_read = require_permission("read")
        result = await check_read(user_with_permission)

        # This will fail initially
        assert result == user_with_permission

        # Test with user missing permission
        check_admin = require_permission("admin")
        with pytest.raises(HTTPException) as exc_info:
            await check_admin(user_with_permission)

        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
        assert "admin" in exc_info.value.detail

        # Test with admin user (should have all permissions)
        admin_user = {"id": "admin-001", "is_admin": True, "permissions": ["read"]}
        result = await check_admin(admin_user)
        assert result == admin_user
|
||||
|
||||
|
||||
class TestZoneAndRouterAccess:
    """Test zone and router access control."""

    @pytest.fixture
    def mock_domain_config(self):
        """Mock domain configuration with a mix of enabled and disabled entries."""

        class MockDomainConfig:
            def __init__(self):
                self.zones = {
                    "zone1": {"id": "zone1", "name": "Zone 1", "enabled": True},
                    "zone2": {"id": "zone2", "name": "Zone 2", "enabled": True},
                    "zone3": {"id": "zone3", "name": "Zone 3", "enabled": False},
                }
                self.routers = {
                    "router1": {"id": "router1", "name": "Router 1", "enabled": True},
                    "router2": {"id": "router2", "name": "Router 2", "enabled": False},
                }

            def get_zone(self, zone_id: str):
                return self.zones.get(zone_id)

            def get_router(self, router_id: str):
                return self.routers.get(router_id)

        return MockDomainConfig()

    @pytest.mark.asyncio
    async def test_zone_access_validation_should_fail_initially(self, mock_domain_config):
        """Test zone access validation - should fail initially."""

        async def validate_zone_access(zone_id: str, current_user=None):
            # Unknown zones are a 404; disabled zones a 403.
            zone = mock_domain_config.get_zone(zone_id)
            if not zone:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Zone '{zone_id}' not found",
                )
            if not zone["enabled"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Zone '{zone_id}' is disabled",
                )

            if current_user:
                # Admins may enter any zone.
                if current_user.get("is_admin", False):
                    return zone_id
                allowed_zones = current_user.get("zones", [])
                if allowed_zones and zone_id not in allowed_zones:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Access denied to zone '{zone_id}'",
                    )

            return zone_id

        # Anonymous access to an enabled zone succeeds.
        assert await validate_zone_access("zone1") == "zone1"

        # Unknown zone -> 404.
        with pytest.raises(HTTPException) as exc_info:
            await validate_zone_access("nonexistent")
        assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND

        # Disabled zone -> 403.
        with pytest.raises(HTTPException) as exc_info:
            await validate_zone_access("zone3")
        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

        # A user whose zone list contains the zone is allowed in.
        user_with_access = {"id": "user-001", "zones": ["zone1", "zone2"]}
        assert await validate_zone_access("zone1", user_with_access) == "zone1"

        # A user whose zone list excludes the zone is rejected.
        with pytest.raises(HTTPException) as exc_info:
            await validate_zone_access("zone2", {"id": "user-002", "zones": ["zone1"]})
        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

    @pytest.mark.asyncio
    async def test_router_access_validation_should_fail_initially(self, mock_domain_config):
        """Test router access validation - should fail initially."""

        async def validate_router_access(router_id: str, current_user=None):
            router = mock_domain_config.get_router(router_id)
            if not router:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Router '{router_id}' not found",
                )
            if not router["enabled"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Router '{router_id}' is disabled",
                )
            return router_id

        # Enabled router is accessible.
        assert await validate_router_access("router1") == "router1"

        # Disabled router -> 403.
        with pytest.raises(HTTPException) as exc_info:
            await validate_router_access("router2")
        assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
447
tests/integration/test_full_system_integration.py
Normal file
447
tests/integration/test_full_system_integration.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
Full system integration tests for WiFi-DensePose API
|
||||
Tests the complete integration of all components working together.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import httpx
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.app import app
|
||||
from src.database.connection import get_database_manager
|
||||
from src.services.orchestrator import get_service_orchestrator
|
||||
from src.tasks.cleanup import get_cleanup_manager
|
||||
from src.tasks.monitoring import get_monitoring_manager
|
||||
from src.tasks.backup import get_backup_manager
|
||||
|
||||
|
||||
class TestFullSystemIntegration:
    """Test complete system integration."""

    @pytest.fixture
    async def settings(self):
        """Get test settings."""
        cfg = get_settings()
        cfg.environment = "test"
        cfg.debug = True
        cfg.database_url = "sqlite+aiosqlite:///test_integration.db"
        cfg.redis_enabled = False
        return cfg

    @pytest.fixture
    async def db_manager(self, settings):
        """Get database manager for testing."""
        manager = get_database_manager(settings)
        await manager.initialize()
        yield manager
        await manager.close_all_connections()

    @pytest.fixture
    async def client(self, settings):
        """Get test HTTP client."""
        async with httpx.AsyncClient(app=app, base_url="http://test") as http_client:
            yield http_client

    @pytest.fixture
    async def orchestrator(self, settings, db_manager):
        """Get service orchestrator for testing."""
        svc = get_service_orchestrator(settings)
        await svc.initialize()
        yield svc
        await svc.shutdown()

    async def test_application_startup_and_shutdown(self, settings, db_manager):
        """Test complete application startup and shutdown sequence."""
        # The database must be reachable before services come up.
        await db_manager.test_connection()
        stats = await db_manager.get_connection_stats()
        assert stats["database"]["connected"] is True

        svc = get_service_orchestrator(settings)
        await svc.initialize()

        # Services should report at least a degraded-but-alive state.
        health = await svc.get_health_status()
        assert health["status"] in ("healthy", "warning")

        await svc.shutdown()

        # Connection stats remain queryable after shutdown.
        assert await db_manager.get_connection_stats() is not None

    async def test_api_endpoints_integration(self, client, settings, db_manager):
        """Test API endpoints work with database integration."""
        response = await client.get("/health")
        assert response.status_code == 200
        health = response.json()
        assert "status" in health
        assert "timestamp" in health

        assert (await client.get("/metrics")).status_code == 200

        # Both collection endpoints return a JSON list under their own key.
        for endpoint, key in (("/api/v1/devices", "devices"), ("/api/v1/sessions", "sessions")):
            response = await client.get(endpoint)
            assert response.status_code == 200
            body = response.json()
            assert key in body
            assert isinstance(body[key], list)

    @patch('src.core.router_interface.RouterInterface')
    @patch('src.core.csi_processor.CSIProcessor')
    @patch('src.core.pose_estimator.PoseEstimator')
    async def test_data_processing_pipeline(
        self,
        mock_pose_estimator,
        mock_csi_processor,
        mock_router_interface,
        client,
        settings,
        db_manager
    ):
        """Test complete data processing pipeline integration."""
        # Wire up the hardware/processing mocks.
        router = MagicMock()
        mock_router_interface.return_value = router
        router.connect.return_value = True
        router.start_capture.return_value = True
        router.get_csi_data.return_value = {
            "timestamp": time.time(),
            "csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
            "rssi": -45,
            "noise_floor": -90
        }

        processor = MagicMock()
        mock_csi_processor.return_value = processor
        processor.process_csi_data.return_value = {
            "processed_csi": [[1.1, 2.1], [3.1, 4.1]],
            "quality_score": 0.85,
            "phase_sanitized": True
        }

        estimator = MagicMock()
        mock_pose_estimator.return_value = estimator
        estimator.estimate_pose.return_value = {
            "pose_data": {
                "keypoints": [[100, 200], [150, 250]],
                "confidence": 0.9
            },
            "processing_time": 0.05
        }

        # Register a device.
        response = await client.post("/api/v1/devices", json={
            "name": "test_router",
            "ip_address": "192.168.1.1",
            "device_type": "router",
            "model": "test_model"
        })
        assert response.status_code == 201
        device_id = response.json()["device"]["id"]

        # Open a capture session against that device.
        response = await client.post("/api/v1/sessions", json={
            "device_id": device_id,
            "session_type": "pose_detection",
            "configuration": {
                "sampling_rate": 1000,
                "duration": 60
            }
        })
        assert response.status_code == 201
        session_id = response.json()["session"]["id"]

        # Submit one CSI frame.
        response = await client.post("/api/v1/csi-data", json={
            "session_id": session_id,
            "timestamp": time.time(),
            "csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
            "rssi": -45,
            "noise_floor": -90
        })
        assert response.status_code == 201

        # Pose detections should be queryable for the session.
        response = await client.get(f"/api/v1/sessions/{session_id}/pose-detections")
        assert response.status_code == 200

        # Close out the session.
        response = await client.patch(
            f"/api/v1/sessions/{session_id}",
            json={"status": "completed"}
        )
        assert response.status_code == 200

    async def test_background_tasks_integration(self, settings, db_manager):
        """Test background tasks integration."""
        # Each manager exposes stats and a run_all_tasks entry point.
        for factory in (get_cleanup_manager, get_monitoring_manager, get_backup_manager):
            manager = factory(settings)
            assert "manager" in manager.get_stats()
            result = await manager.run_all_tasks()
            assert result["success"] is True

    async def test_error_handling_integration(self, client, settings, db_manager):
        """Test error handling across the system."""
        # Validation failures surface as 422 with a detail payload.
        response = await client.post("/api/v1/devices", json={
            "name": "",  # Invalid empty name
            "ip_address": "invalid_ip",
            "device_type": "unknown_type"
        })
        assert response.status_code == 422
        assert "detail" in response.json()

        # Missing resources surface as 404.
        response = await client.get("/api/v1/devices/99999")
        assert response.status_code == 404

        # Malformed session payloads are also rejected with 422.
        response = await client.post("/api/v1/sessions", json={
            "device_id": "invalid_uuid",
            "session_type": "invalid_type"
        })
        assert response.status_code == 422

    async def test_authentication_and_authorization(self, client, settings):
        """Test authentication and authorization integration."""
        # Admin routes must reject anonymous callers.
        response = await client.get("/api/v1/admin/system-info")
        assert response.status_code in (401, 403)

        # ...and callers presenting a bogus bearer token.
        response = await client.get(
            "/api/v1/admin/system-info",
            headers={"Authorization": "Bearer invalid_token"},
        )
        assert response.status_code in (401, 403)

    async def test_rate_limiting_integration(self, client, settings):
        """Test rate limiting integration."""
        # Fire a burst of requests; rate limiting may throttle some of them
        # depending on configuration, but at least one must succeed.
        statuses = [(await client.get("/health")).status_code for _ in range(10)]
        assert 200 in statuses

    async def test_monitoring_and_metrics_integration(self, client, settings, db_manager):
        """Test monitoring and metrics collection integration."""
        response = await client.get("/metrics")
        assert response.status_code == 200
        metrics_text = response.text

        # Prometheus exposition format always carries HELP/TYPE headers.
        assert "# HELP" in metrics_text
        assert "# TYPE" in metrics_text

        response = await client.get("/health?detailed=true")
        assert response.status_code == 200
        health = response.json()
        for section in ("database", "services", "system"):
            assert section in health

    async def test_configuration_management_integration(self, settings):
        """Test configuration management integration."""
        assert settings.environment == "test"
        assert settings.debug is True
        assert "test_integration.db" in settings.database_url
        assert settings.redis_enabled is False
        assert settings.log_level in ("DEBUG", "INFO", "WARNING", "ERROR")

    async def test_database_migration_integration(self, settings, db_manager):
        """Test database migration integration."""
        await db_manager.test_connection()

        async with db_manager.get_async_session() as session:
            from sqlalchemy import text

            # List user tables (SQLite-specific catalog query).
            tables_query = text("""
                SELECT name FROM sqlite_master
                WHERE type='table' AND name NOT LIKE 'sqlite_%'
            """)
            result = await session.execute(tables_query)
            tables = [row[0] for row in result.fetchall()]

            for table in ("devices", "sessions", "csi_data", "pose_detections"):
                assert table in tables

    async def test_concurrent_operations_integration(self, client, settings, db_manager):
        """Test concurrent operations integration."""

        async def create_device(name: str):
            # One POST per device; returns just the status code.
            payload = {
                "name": f"test_device_{name}",
                "ip_address": f"192.168.1.{name}",
                "device_type": "router",
                "model": "test_model"
            }
            response = await client.post("/api/v1/devices", json=payload)
            return response.status_code

        # Five concurrent registrations should all succeed.
        results = await asyncio.gather(*(create_device(str(i)) for i in range(5)))
        assert all(code == 201 for code in results)

        response = await client.get("/api/v1/devices")
        assert response.status_code == 200
        assert len(response.json()["devices"]) >= 5

    async def test_system_resource_management(self, settings, db_manager, orchestrator):
        """Test system resource management integration."""
        stats = await db_manager.get_connection_stats()
        assert "database" in stats
        assert "pool_size" in stats["database"]

        health = await orchestrator.get_health_status()
        assert "memory_usage" in health
        assert "cpu_usage" in health

        await orchestrator.cleanup_resources()

        # Stats must still be retrievable after cleanup.
        assert await db_manager.get_connection_stats() is not None
|
||||
@pytest.mark.integration
class TestSystemPerformance:
    """Test system performance under load."""

    async def test_api_response_times(self, client):
        """Test API response times under normal load."""
        started = time.time()
        response = await client.get("/health")
        elapsed = time.time() - started

        assert response.status_code == 200
        assert elapsed < 1.0  # Should respond within 1 second

    async def test_database_query_performance(self, db_manager):
        """Test database query performance."""
        async with db_manager.get_async_session() as session:
            from sqlalchemy import text

            started = time.time()
            result = await session.execute(text("SELECT 1"))
            elapsed = time.time() - started

            assert result.scalar() == 1
            assert elapsed < 0.1  # Should complete within 100ms

    async def test_memory_usage_stability(self, orchestrator):
        """Test memory usage remains stable."""
        import os
        import psutil

        process = psutil.Process(os.getpid())
        baseline = process.memory_info().rss

        # Exercise the orchestrator repeatedly to surface leaks.
        for _ in range(10):
            assert await orchestrator.get_health_status() is not None

        growth = process.memory_info().rss - baseline

        # Memory increase should be reasonable (less than 50MB)
        assert growth < 50 * 1024 * 1024
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly for a quick verbose run.
    pytest.main([__file__, "-v"])
||||
663
tests/integration/test_hardware_integration.py
Normal file
663
tests/integration/test_hardware_integration.py
Normal file
@@ -0,0 +1,663 @@
|
||||
"""
|
||||
Integration tests for hardware integration and router communication.
|
||||
|
||||
Tests WiFi router communication, CSI data collection, and hardware management.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import socket
from datetime import datetime, timedelta, timezone
from typing import Dict, Any, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
|
||||
|
||||
|
||||
class MockRouterInterface:
    """Mock WiFi router interface for testing.

    Simulates the connect/authenticate/stream lifecycle. The first
    ``connect`` attempt deliberately fails so tests can exercise retry
    paths, and ``start_csi_streaming`` always fails to exercise error
    handling.
    """

    def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
        self.router_id = router_id
        self.ip_address = ip_address
        self.is_connected = False
        self.is_authenticated = False
        self.csi_streaming = False
        # Counts every connect() call, successful or not.
        self.connection_attempts = 0
        self.last_heartbeat: Optional[datetime] = None
        self.firmware_version = "1.2.3"
        self.capabilities = ["csi", "beamforming", "mimo"]

    async def connect(self) -> bool:
        """Connect to the router; the first attempt always fails by design."""
        self.connection_attempts += 1

        # Simulate connection failure for testing
        if self.connection_attempts == 1:
            return False

        await asyncio.sleep(0.1)  # Simulate connection time
        self.is_connected = True
        return True

    async def authenticate(self, username: str, password: str) -> bool:
        """Authenticate with the router; requires an established connection."""
        if not self.is_connected:
            return False

        # Only the canonical test credentials are accepted.
        if username == "admin" and password == "correct_password":
            self.is_authenticated = True
            return True

        return False

    async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
        """Start CSI data streaming; always returns False for error-path tests."""
        if not self.is_authenticated:
            return False

        # This should fail initially to test proper error handling
        return False

    async def stop_csi_streaming(self) -> bool:
        """Stop CSI data streaming; False when streaming was not active."""
        if self.csi_streaming:
            self.csi_streaming = False
            return True
        return False

    async def get_status(self) -> Dict[str, Any]:
        """Return a status snapshot (static telemetry values for the mock)."""
        return {
            "router_id": self.router_id,
            "ip_address": self.ip_address,
            "is_connected": self.is_connected,
            "is_authenticated": self.is_authenticated,
            "csi_streaming": self.csi_streaming,
            "firmware_version": self.firmware_version,
            "uptime_seconds": 3600,
            "signal_strength": -45.2,
            "temperature": 42.5,
            "cpu_usage": 15.3
        }

    async def send_heartbeat(self) -> bool:
        """Record a heartbeat timestamp; fails when disconnected."""
        if not self.is_connected:
            return False

        # Fix: datetime.utcnow() is deprecated (Python 3.12+) and returns a
        # naive datetime; record a timezone-aware UTC timestamp instead.
        self.last_heartbeat = datetime.now(timezone.utc)
        return True
|
||||
class TestRouterConnection:
    """Test router connection functionality."""

    @pytest.fixture
    def router_interface(self):
        """Create router interface for testing."""
        return MockRouterInterface("router_001", "192.168.1.100")

    @pytest.mark.asyncio
    async def test_router_connection_should_fail_initially(self, router_interface):
        """Test router connection - should fail initially."""
        # The mock rejects the very first attempt by design.
        assert await router_interface.connect() is False
        assert router_interface.is_connected is False
        assert router_interface.connection_attempts == 1

        # A retry succeeds.
        assert await router_interface.connect() is True
        assert router_interface.is_connected is True

    @pytest.mark.asyncio
    async def test_router_authentication_should_fail_initially(self, router_interface):
        """Test router authentication - should fail initially."""
        # Establish the link first (the mock needs two attempts).
        await router_interface.connect()
        await router_interface.connect()

        # Bad credentials are rejected.
        assert await router_interface.authenticate("admin", "wrong_password") is False
        assert router_interface.is_authenticated is False

        # Good credentials are accepted.
        assert await router_interface.authenticate("admin", "correct_password") is True
        assert router_interface.is_authenticated is True

    @pytest.mark.asyncio
    async def test_csi_streaming_start_should_fail_initially(self, router_interface):
        """Test CSI streaming start - should fail initially."""
        # Connect (twice, for the mock) and authenticate.
        await router_interface.connect()
        await router_interface.connect()
        await router_interface.authenticate("admin", "correct_password")

        streaming_config = {
            "frequency": 5.8e9,
            "bandwidth": 80e6,
            "sample_rate": 1000,
            "antenna_config": "4x4_mimo"
        }

        # The mock always refuses to start streaming.
        assert await router_interface.start_csi_streaming(streaming_config) is False
        assert router_interface.csi_streaming is False

    @pytest.mark.asyncio
    async def test_router_status_retrieval_should_fail_initially(self, router_interface):
        """Test router status retrieval - should fail initially."""
        status_snapshot = await router_interface.get_status()

        assert isinstance(status_snapshot, dict)
        assert status_snapshot["router_id"] == "router_001"
        assert status_snapshot["ip_address"] == "192.168.1.100"
        # Telemetry fields must all be present.
        for field in ("firmware_version", "uptime_seconds", "signal_strength",
                      "temperature", "cpu_usage"):
            assert field in status_snapshot

    @pytest.mark.asyncio
    async def test_heartbeat_mechanism_should_fail_initially(self, router_interface):
        """Test heartbeat mechanism - should fail initially."""
        # Heartbeats require an active connection.
        assert await router_interface.send_heartbeat() is False
        assert router_interface.last_heartbeat is None

        # After connecting, the heartbeat is recorded.
        await router_interface.connect()
        await router_interface.connect()
        assert await router_interface.send_heartbeat() is True
        assert router_interface.last_heartbeat is not None
||||
class TestMultiRouterManagement:
    """Test management of multiple routers."""

    @pytest.fixture
    def router_manager(self):
        """Create router manager for testing."""

        class RouterManager:
            def __init__(self):
                self.routers = {}
                self.active_connections = 0

            async def add_router(self, router_id: str, ip_address: str) -> bool:
                """Register a router unless the id is already taken."""
                if router_id in self.routers:
                    return False
                self.routers[router_id] = MockRouterInterface(router_id, ip_address)
                return True

            async def connect_router(self, router_id: str) -> bool:
                """Connect to a specific router, retrying once for the mock."""
                if router_id not in self.routers:
                    return False

                router = self.routers[router_id]
                connected = await router.connect()
                if not connected:
                    connected = await router.connect()

                if connected:
                    self.active_connections += 1
                return connected

            async def authenticate_router(self, router_id: str, username: str, password: str) -> bool:
                """Authenticate with a router."""
                if router_id not in self.routers:
                    return False
                return await self.routers[router_id].authenticate(username, password)

            async def get_all_status(self) -> Dict[str, Dict[str, Any]]:
                """Get status of all routers."""
                return {
                    router_id: await router.get_status()
                    for router_id, router in self.routers.items()
                }

            async def start_all_csi_streaming(self, config: Dict[str, Any]) -> Dict[str, bool]:
                """Start CSI streaming on all authenticated routers."""
                results = {}
                for router_id, router in self.routers.items():
                    if router.is_authenticated:
                        results[router_id] = await router.start_csi_streaming(config)
                    else:
                        results[router_id] = False
                return results

        return RouterManager()

    @pytest.mark.asyncio
    async def test_multiple_router_addition_should_fail_initially(self, router_manager):
        """Test adding multiple routers - should fail initially."""
        # Two distinct routers register cleanly.
        assert await router_manager.add_router("router_001", "192.168.1.100") is True
        assert "router_001" in router_manager.routers

        assert await router_manager.add_router("router_002", "192.168.1.101") is True
        assert "router_002" in router_manager.routers

        # A duplicate id is refused and the registry is unchanged.
        assert await router_manager.add_router("router_001", "192.168.1.102") is False
        assert len(router_manager.routers) == 2

    @pytest.mark.asyncio
    async def test_concurrent_router_connections_should_fail_initially(self, router_manager):
        """Test concurrent router connections - should fail initially."""
        for router_id, ip in (("router_001", "192.168.1.100"),
                              ("router_002", "192.168.1.101"),
                              ("router_003", "192.168.1.102")):
            await router_manager.add_router(router_id, ip)

        # Connect to every router at the same time.
        results = await asyncio.gather(
            router_manager.connect_router("router_001"),
            router_manager.connect_router("router_002"),
            router_manager.connect_router("router_003"),
        )

        assert len(results) == 3
        assert all(results)  # All connections should succeed
        assert router_manager.active_connections == 3

    @pytest.mark.asyncio
    async def test_router_status_aggregation_should_fail_initially(self, router_manager):
        """Test router status aggregation - should fail initially."""
        await router_manager.add_router("router_001", "192.168.1.100")
        await router_manager.add_router("router_002", "192.168.1.101")
        await router_manager.connect_router("router_001")
        await router_manager.connect_router("router_002")

        all_status = await router_manager.get_all_status()

        assert isinstance(all_status, dict)
        assert len(all_status) == 2
        assert "router_001" in all_status
        assert "router_002" in all_status

        # Every snapshot carries identity fields and reflects the connection.
        for router_id, snapshot in all_status.items():
            assert "router_id" in snapshot
            assert "ip_address" in snapshot
            assert "is_connected" in snapshot
            assert snapshot["is_connected"] is True
|
||||
class TestCSIDataCollection:
    """Test CSI data collection from routers."""

    @pytest.fixture
    def csi_collector(self):
        """Create CSI data collector."""

        class CSICollector:
            """Minimal in-memory CSI collector used only by these tests."""

            def __init__(self):
                self.collected_data = []     # frames gathered so far
                self.is_collecting = False   # collection loop state
                self.collection_rate = 0     # frames/sec (never updated here)

            async def start_collection(self, router_interfaces: List[MockRouterInterface]) -> bool:
                """Start CSI data collection.

                Deliberately unimplemented (always False) so the TDD test
                fails until a real collector exists.
                """
                return False

            async def stop_collection(self) -> bool:
                """Stop collection; True only if it was actually running."""
                if not self.is_collecting:
                    return False
                self.is_collecting = False
                return True

            async def collect_frame(self, router_interface: MockRouterInterface) -> Optional[Dict[str, Any]]:
                """Collect one CSI frame, or None when streaming is off."""
                if not router_interface.csi_streaming:
                    return None

                # Synthesize a plausible-looking CSI frame.
                return {
                    "timestamp": datetime.utcnow().isoformat(),
                    "router_id": router_interface.router_id,
                    "amplitude": np.random.rand(64, 32).tolist(),
                    "phase": np.random.rand(64, 32).tolist(),
                    "frequency": 5.8e9,
                    "bandwidth": 80e6,
                    "antenna_count": 4,
                    "subcarrier_count": 64,
                    "signal_quality": np.random.uniform(0.7, 0.95),
                }

            def get_collection_stats(self) -> Dict[str, Any]:
                """Get collection statistics."""
                last = self.collected_data[-1]["timestamp"] if self.collected_data else None
                return {
                    "total_frames": len(self.collected_data),
                    "collection_rate": self.collection_rate,
                    "is_collecting": self.is_collecting,
                    "last_collection": last,
                }

        return CSICollector()

    @pytest.mark.asyncio
    async def test_csi_collection_start_should_fail_initially(self, csi_collector):
        """Test CSI collection start - should fail initially."""
        interfaces = [
            MockRouterInterface("router_001", "192.168.1.100"),
            MockRouterInterface("router_002", "192.168.1.101"),
        ]

        started = await csi_collector.start_collection(interfaces)

        # This will fail initially because the collector is designed to return False
        assert started is False
        assert csi_collector.is_collecting is False

    @pytest.mark.asyncio
    async def test_single_frame_collection_should_fail_initially(self, csi_collector):
        """Test single frame collection - should fail initially."""
        router = MockRouterInterface("router_001", "192.168.1.100")

        # With streaming disabled no frame can be produced.
        frame = await csi_collector.collect_frame(router)
        # This will fail initially
        assert frame is None

        # Enable CSI streaming (manually for testing) and retry.
        router.csi_streaming = True
        frame = await csi_collector.collect_frame(router)

        assert frame is not None
        for key in ("timestamp", "router_id", "amplitude", "phase"):
            assert key in frame
        assert frame["router_id"] == "router_001"

    @pytest.mark.asyncio
    async def test_collection_statistics_should_fail_initially(self, csi_collector):
        """Test collection statistics - should fail initially."""
        stats = csi_collector.get_collection_stats()

        # This will fail initially
        assert isinstance(stats, dict)
        for key in ("total_frames", "collection_rate", "is_collecting", "last_collection"):
            assert key in stats

        # A fresh collector has gathered nothing yet.
        assert stats["total_frames"] == 0
        assert stats["is_collecting"] is False
        assert stats["last_collection"] is None
|
||||
|
||||
|
||||
class TestHardwareErrorHandling:
    """Test hardware error handling scenarios."""

    @pytest.fixture
    def unreliable_router(self):
        """Create unreliable router for error testing."""

        class UnreliableRouter(MockRouterInterface):
            """Router mock that randomly fails ~30% of its operations."""

            def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
                super().__init__(router_id, ip_address)
                self.failure_rate = 0.3    # 30% failure rate
                self.connection_drops = 0  # heartbeat-detected disconnects

            async def connect(self) -> bool:
                """Connect, failing randomly at the configured rate."""
                if np.random.random() < self.failure_rate:
                    return False
                return await super().connect()

            async def send_heartbeat(self) -> bool:
                """Heartbeat that may randomly drop the connection."""
                if np.random.random() < self.failure_rate:
                    self.is_connected = False
                    self.connection_drops += 1
                    return False
                return await super().send_heartbeat()

            async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
                """CSI streaming stub; always False so the TDD test fails."""
                if np.random.random() < self.failure_rate:
                    return False
                # Still return False for initial test failure
                return False

        return UnreliableRouter("unreliable_router", "192.168.1.200")

    @pytest.mark.asyncio
    async def test_connection_retry_mechanism_should_fail_initially(self, unreliable_router):
        """Test connection retry mechanism - should fail initially."""
        max_retries = 5
        connected = False

        for _ in range(max_retries):
            if await unreliable_router.connect():
                connected = True
                break
            # Back off briefly before the next attempt.
            await asyncio.sleep(0.1)

        # This will fail initially due to randomness, but should eventually pass.
        # The test demonstrates the need for retry logic.
        assert connected or unreliable_router.connection_attempts >= max_retries

    @pytest.mark.asyncio
    async def test_connection_drop_detection_should_fail_initially(self, unreliable_router):
        """Test connection drop detection - should fail initially."""
        # Establish connection (twice, to make success more likely).
        await unreliable_router.connect()
        await unreliable_router.connect()  # Ensure connection

        drops_before = unreliable_router.connection_drops

        # Send multiple heartbeats to trigger potential drops.
        for _ in range(10):
            await unreliable_router.send_heartbeat()
            await asyncio.sleep(0.01)

        # This will fail initially
        # Should detect connection drops
        assert unreliable_router.connection_drops >= drops_before  # May have detected drops

    @pytest.mark.asyncio
    async def test_hardware_timeout_handling_should_fail_initially(self):
        """Test hardware timeout handling - should fail initially."""

        async def slow_operation():
            """Simulate slow hardware operation."""
            await asyncio.sleep(2.0)  # 2 second delay
            return "success"

        # A 1s timeout must fire before the 2s operation completes.
        try:
            await asyncio.wait_for(slow_operation(), timeout=1.0)
        except asyncio.TimeoutError:
            pass  # Timeout was properly handled
        else:
            assert False, "Operation should have timed out"

    @pytest.mark.asyncio
    async def test_network_error_simulation_should_fail_initially(self):
        """Test network error simulation - should fail initially."""

        class NetworkErrorRouter(MockRouterInterface):
            async def connect(self) -> bool:
                """Always raise, simulating an unreachable network."""
                raise ConnectionError("Network unreachable")

        router = NetworkErrorRouter("error_router", "192.168.1.999")

        # This will fail initially
        with pytest.raises(ConnectionError, match="Network unreachable"):
            await router.connect()
|
||||
|
||||
|
||||
class TestHardwareConfiguration:
    """Test hardware configuration management."""

    @pytest.fixture
    def config_manager(self):
        """Create configuration manager."""

        class ConfigManager:
            """Per-router radio configuration store with range validation."""

            def __init__(self):
                # Returned (as a copy) for any router without an explicit config.
                self.default_config = {
                    "frequency": 5.8e9,
                    "bandwidth": 80e6,
                    "sample_rate": 1000,
                    "antenna_config": "4x4_mimo",
                    "power_level": 20,
                    "channel": 36,
                }
                self.router_configs = {}  # router_id -> stored config

            def get_router_config(self, router_id: str) -> Dict[str, Any]:
                """Return the router's config, or a copy of the defaults."""
                return self.router_configs.get(router_id, self.default_config.copy())

            def set_router_config(self, router_id: str, config: Dict[str, Any]) -> bool:
                """Store a router config; False when required fields are missing."""
                required = ("frequency", "bandwidth", "sample_rate")
                if any(field not in config for field in required):
                    return False
                self.router_configs[router_id] = config
                return True

            def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
                """Range-check any fields present; collect every violation."""
                errors = []

                if "frequency" in config:
                    if not (2.4e9 <= config["frequency"] <= 6e9):
                        errors.append("Frequency must be between 2.4GHz and 6GHz")

                if "bandwidth" in config:
                    if config["bandwidth"] not in [20e6, 40e6, 80e6, 160e6]:
                        errors.append("Bandwidth must be 20, 40, 80, or 160 MHz")

                if "sample_rate" in config:
                    if not (100 <= config["sample_rate"] <= 10000):
                        errors.append("Sample rate must be between 100 and 10000 Hz")

                return {"valid": len(errors) == 0, "errors": errors}

        return ConfigManager()

    def test_default_configuration_should_fail_initially(self, config_manager):
        """Test default configuration retrieval - should fail initially."""
        config = config_manager.get_router_config("new_router")

        # This will fail initially
        assert isinstance(config, dict)
        for key in ("frequency", "bandwidth", "sample_rate", "antenna_config"):
            assert key in config
        assert config["frequency"] == 5.8e9
        assert config["bandwidth"] == 80e6

    def test_configuration_validation_should_fail_initially(self, config_manager):
        """Test configuration validation - should fail initially."""
        # A config inside every allowed range must validate cleanly.
        ok = config_manager.validate_config({
            "frequency": 5.8e9,
            "bandwidth": 80e6,
            "sample_rate": 1000,
        })

        # This will fail initially
        assert ok["valid"] is True
        assert len(ok["errors"]) == 0

        # Every field out of range -> one error per field.
        bad = config_manager.validate_config({
            "frequency": 10e9,   # Too high
            "bandwidth": 100e6,  # Invalid
            "sample_rate": 50,   # Too low
        })
        assert bad["valid"] is False
        assert len(bad["errors"]) == 3

    def test_router_specific_configuration_should_fail_initially(self, config_manager):
        """Test router-specific configuration - should fail initially."""
        router_id = "router_001"
        custom = {
            "frequency": 2.4e9,
            "bandwidth": 40e6,
            "sample_rate": 500,
            "antenna_config": "2x2_mimo",
        }

        # This will fail initially
        assert config_manager.set_router_config(router_id, custom) is True

        # The stored config must round-trip unchanged.
        stored = config_manager.get_router_config(router_id)
        assert stored["frequency"] == 2.4e9
        assert stored["bandwidth"] == 40e6
        assert stored["antenna_config"] == "2x2_mimo"

        # Configs missing required fields are rejected.
        assert config_manager.set_router_config(router_id, {"frequency": 5.8e9}) is False
|
||||
577
tests/integration/test_pose_pipeline.py
Normal file
577
tests/integration/test_pose_pipeline.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Integration tests for end-to-end pose estimation pipeline.
|
||||
|
||||
Tests the complete pose estimation workflow from CSI data to pose results.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class CSIData:
    """CSI data structure for testing.

    One frame of channel-state information from a single router.
    """
    timestamp: datetime       # capture time
    router_id: str            # source router identifier
    amplitude: np.ndarray     # amplitude matrix (tests use shape (64, 32))
    phase: np.ndarray         # phase matrix (tests use shape (64, 32))
    frequency: float          # carrier frequency in Hz (e.g. 5.8e9)
    bandwidth: float          # channel bandwidth in Hz (e.g. 80e6)
    antenna_count: int        # number of antennas
    subcarrier_count: int     # number of OFDM subcarriers
|
||||
|
||||
|
||||
@dataclass
class PoseResult:
    """Pose estimation result structure.

    The aggregated output of one pipeline pass over a single CSI frame.
    """
    timestamp: datetime              # when processing of the frame started
    frame_id: str                    # unique frame identifier ("frame_<ms>")
    persons: List[Dict[str, Any]]    # per-person detections
    zone_summary: Dict[str, int]     # zone_id -> person count
    processing_time_ms: float        # wall-clock pipeline latency
    confidence_scores: List[float]   # one score per detected person
    metadata: Dict[str, Any]         # provenance (router, model version, ...)
|
||||
|
||||
|
||||
class MockCSIProcessor:
    """Mock CSI data processor."""

    def __init__(self):
        self.is_initialized = False    # flipped by initialize()
        self.processing_enabled = True  # toggled via set_processing_enabled()

    async def initialize(self):
        """Initialize the processor."""
        self.is_initialized = True

    async def process_csi_data(self, csi_data: "CSIData") -> Dict[str, Any]:
        """Process CSI data into features.

        Raises:
            RuntimeError: when the processor is uninitialized or disabled.
        """
        if not self.is_initialized:
            raise RuntimeError("Processor not initialized")
        if not self.processing_enabled:
            raise RuntimeError("Processing disabled")

        # Simulate processing latency.
        await asyncio.sleep(0.01)

        features = np.random.rand(64, 32).tolist()  # Mock feature matrix
        return {
            "features": features,
            "quality_score": 0.85,
            "signal_strength": -45.2,
            "noise_level": -78.1,
            "processed_at": datetime.utcnow().isoformat(),
        }

    def set_processing_enabled(self, enabled: bool):
        """Enable/disable processing."""
        self.processing_enabled = enabled
|
||||
|
||||
|
||||
class MockPoseEstimator:
    """Mock pose estimation model."""

    def __init__(self):
        self.is_loaded = False          # flipped by load_model()
        self.model_version = "1.0.0"
        self.confidence_threshold = 0.5  # detections below this are dropped

    async def load_model(self):
        """Load the pose estimation model."""
        await asyncio.sleep(0.1)  # Simulate model loading
        self.is_loaded = True

    async def estimate_poses(self, features: np.ndarray) -> Dict[str, Any]:
        """Estimate poses from features.

        Returns randomly generated detections; only persons whose confidence
        meets ``confidence_threshold`` are included.

        Raises:
            RuntimeError: when the model has not been loaded.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        # Simulate inference time.
        await asyncio.sleep(0.05)

        persons = []
        for i in range(np.random.randint(0, 4)):  # 0-3 persons per frame
            confidence = np.random.uniform(0.3, 0.95)
            if confidence < self.confidence_threshold:
                continue
            persons.append({
                "person_id": f"person_{i}",
                "confidence": confidence,
                "bounding_box": {
                    "x": np.random.uniform(0, 800),
                    "y": np.random.uniform(0, 600),
                    "width": np.random.uniform(50, 200),
                    "height": np.random.uniform(100, 400),
                },
                # Two keypoints: head in the upper band, torso in the middle.
                "keypoints": [
                    {
                        "name": name,
                        "x": np.random.uniform(0, 800),
                        "y": np.random.uniform(y_lo, y_hi),
                        "confidence": np.random.uniform(0.5, 0.95),
                    }
                    for name, (y_lo, y_hi) in (("head", (0, 200)), ("torso", (200, 400)))
                ],
                "activity": "standing" if np.random.random() > 0.2 else "sitting",
            })

        return {
            "persons": persons,
            "processing_time_ms": np.random.uniform(20, 80),
            "model_version": self.model_version,
            "confidence_threshold": self.confidence_threshold,
        }

    def set_confidence_threshold(self, threshold: float):
        """Set confidence threshold."""
        self.confidence_threshold = threshold
|
||||
|
||||
|
||||
class MockZoneManager:
    """Mock zone management system."""

    def __init__(self):
        # Fixed test zones; bounds are [x1, y1, x2, y2].  Note zone3 overlaps
        # zone1/zone2 and is checked last, so the first matching zone wins.
        self.zones = {
            "zone1": {"id": "zone1", "name": "Zone 1", "bounds": [0, 0, 400, 600]},
            "zone2": {"id": "zone2", "name": "Zone 2", "bounds": [400, 0, 800, 600]},
            "zone3": {"id": "zone3", "name": "Zone 3", "bounds": [0, 300, 800, 600]},
        }

    def assign_persons_to_zones(self, persons: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Assign detected persons to zones.

        Mutates each person dict in place, setting ``"zone_id"`` (None when
        the bounding-box centre falls outside every zone), and returns a
        zone_id -> person-count summary.
        """
        counts = dict.fromkeys(self.zones, 0)

        for person in persons:
            box = person["bounding_box"]
            cx = box["x"] + box["width"] / 2
            cy = box["y"] + box["height"] / 2

            person["zone_id"] = None
            for zone_id, zone in self.zones.items():
                x1, y1, x2, y2 = zone["bounds"]
                if x1 <= cx <= x2 and y1 <= cy <= y2:
                    counts[zone_id] += 1
                    person["zone_id"] = zone_id
                    break

        return counts
|
||||
|
||||
|
||||
class TestPosePipelineIntegration:
    """Integration tests for the complete pose estimation pipeline."""

    @staticmethod
    def _make_csi_data(router_id: str = "router_001") -> CSIData:
        """Build one synthetic 64x32 CSI frame for the given router."""
        return CSIData(
            timestamp=datetime.utcnow(),
            router_id=router_id,
            amplitude=np.random.rand(64, 32),
            phase=np.random.rand(64, 32),
            frequency=5.8e9,   # 5.8 GHz
            bandwidth=80e6,    # 80 MHz
            antenna_count=4,
            subcarrier_count=64,
        )

    @pytest.fixture
    def csi_processor(self):
        """Create CSI processor."""
        return MockCSIProcessor()

    @pytest.fixture
    def pose_estimator(self):
        """Create pose estimator."""
        return MockPoseEstimator()

    @pytest.fixture
    def zone_manager(self):
        """Create zone manager."""
        return MockZoneManager()

    @pytest.fixture
    def sample_csi_data(self):
        """Create sample CSI data."""
        return self._make_csi_data()

    @pytest.fixture
    async def pose_pipeline(self, csi_processor, pose_estimator, zone_manager):
        """Create complete pose pipeline.

        NOTE(review): an ``async def`` fixture under a plain ``@pytest.fixture``
        only resolves to the pipeline object (rather than a coroutine) when
        pytest-asyncio runs in auto mode — confirm the project configuration.
        """

        class PosePipeline:
            """CSI -> features -> poses -> zones orchestrator (test double)."""

            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                """Initialize every pipeline stage."""
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

            async def process_frame(self, csi_data: CSIData) -> PoseResult:
                """Run one CSI frame through every pipeline stage."""
                if not self.is_initialized:
                    raise RuntimeError("Pipeline not initialized")

                started = datetime.utcnow()

                # CSI -> feature matrix
                processed = await self.csi_processor.process_csi_data(csi_data)
                features = np.array(processed["features"])

                # Features -> per-person poses
                pose_data = await self.pose_estimator.estimate_poses(features)

                # Poses -> per-zone occupancy
                zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])

                elapsed_ms = (datetime.utcnow() - started).total_seconds() * 1000

                return PoseResult(
                    timestamp=started,
                    frame_id=f"frame_{int(started.timestamp() * 1000)}",
                    persons=pose_data["persons"],
                    zone_summary=zone_summary,
                    processing_time_ms=elapsed_ms,
                    confidence_scores=[p["confidence"] for p in pose_data["persons"]],
                    metadata={
                        "csi_quality": processed["quality_score"],
                        "signal_strength": processed["signal_strength"],
                        "model_version": pose_data["model_version"],
                        "router_id": csi_data.router_id,
                    },
                )

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
        await pipeline.initialize()
        return pipeline

    @pytest.mark.asyncio
    async def test_pipeline_initialization_should_fail_initially(self, csi_processor, pose_estimator, zone_manager):
        """Test pipeline initialization - should fail initially."""

        class PosePipeline:
            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)

        # Nothing is initialized before initialize() runs.
        assert not pipeline.is_initialized
        assert not csi_processor.is_initialized
        assert not pose_estimator.is_loaded

        await pipeline.initialize()

        # This will fail initially
        assert pipeline.is_initialized
        assert csi_processor.is_initialized
        assert pose_estimator.is_loaded

    @pytest.mark.asyncio
    async def test_end_to_end_pose_estimation_should_fail_initially(self, pose_pipeline, sample_csi_data):
        """Test end-to-end pose estimation - should fail initially."""
        result = await pose_pipeline.process_frame(sample_csi_data)

        # This will fail initially
        assert isinstance(result, PoseResult)
        assert result.timestamp is not None
        assert result.frame_id.startswith("frame_")
        assert isinstance(result.persons, list)
        assert isinstance(result.zone_summary, dict)
        assert result.processing_time_ms > 0
        assert isinstance(result.confidence_scores, list)
        assert isinstance(result.metadata, dict)

        # Every configured zone appears with a non-negative integer count.
        for zone_id in ("zone1", "zone2", "zone3"):
            assert zone_id in result.zone_summary
            assert isinstance(result.zone_summary[zone_id], int)
            assert result.zone_summary[zone_id] >= 0

        # Metadata carries provenance from every stage.
        for key in ("csi_quality", "signal_strength", "model_version", "router_id"):
            assert key in result.metadata
        assert result.metadata["router_id"] == sample_csi_data.router_id

    @pytest.mark.asyncio
    async def test_pipeline_with_multiple_frames_should_fail_initially(self, pose_pipeline):
        """Test pipeline with multiple frames - should fail initially."""
        results = []
        for i in range(5):
            # Alternate between router_001 and router_002.
            frame = self._make_csi_data(f"router_{i % 2 + 1:03d}")
            results.append(await pose_pipeline.process_frame(frame))

        # This will fail initially
        assert len(results) == 5

        # Verify each result.
        for i, result in enumerate(results):
            if i > 0:
                assert result.frame_id != results[0].frame_id
            assert result.metadata["router_id"] in ("router_001", "router_002")
            assert result.processing_time_ms > 0

    @pytest.mark.asyncio
    async def test_pipeline_error_handling_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
        """Test pipeline error handling - should fail initially."""

        class PosePipeline:
            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

            async def process_frame(self, csi_data):
                if not self.is_initialized:
                    raise RuntimeError("Pipeline not initialized")
                processed = await self.csi_processor.process_csi_data(csi_data)
                return await self.pose_estimator.estimate_poses(np.array(processed["features"]))

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)

        # Processing before initialization must be rejected.
        with pytest.raises(RuntimeError, match="Pipeline not initialized"):
            await pipeline.process_frame(sample_csi_data)

        await pipeline.initialize()

        # A disabled CSI processor must propagate its error.
        csi_processor.set_processing_enabled(False)
        with pytest.raises(RuntimeError, match="Processing disabled"):
            await pipeline.process_frame(sample_csi_data)

        # This assertion will fail initially
        assert True  # Test completed successfully

    @pytest.mark.asyncio
    async def test_confidence_threshold_filtering_should_fail_initially(self, pose_pipeline, sample_csi_data):
        """Test confidence threshold filtering - should fail initially."""
        # A strict threshold should admit fewer (or equal) detections...
        pose_pipeline.pose_estimator.set_confidence_threshold(0.9)
        result = await pose_pipeline.process_frame(sample_csi_data)
        high_confidence_count = len(result.persons)

        # ...than a permissive one.
        pose_pipeline.pose_estimator.set_confidence_threshold(0.1)
        result = await pose_pipeline.process_frame(sample_csi_data)
        low_confidence_count = len(result.persons)

        # This will fail initially
        assert low_confidence_count >= high_confidence_count

        # All detected persons should meet the threshold.
        for person in result.persons:
            assert person["confidence"] >= 0.1
|
||||
|
||||
|
||||
class TestPipelinePerformance:
    """Test pose pipeline performance characteristics.

    NOTE: pytest fixtures defined inside another test class (such as the
    ``pose_pipeline`` fixture on ``TestPosePipelineIntegration``) are scoped
    to that class only.  This class therefore defines its own
    ``pose_pipeline`` fixture; without it every test here errors with
    "fixture 'pose_pipeline' not found" instead of running.
    """

    @staticmethod
    def _make_csi_data(router_id: str = "router_001") -> CSIData:
        """Build one synthetic 64x32 CSI frame for the given router."""
        return CSIData(
            timestamp=datetime.utcnow(),
            router_id=router_id,
            amplitude=np.random.rand(64, 32),
            phase=np.random.rand(64, 32),
            frequency=5.8e9,
            bandwidth=80e6,
            antenna_count=4,
            subcarrier_count=64,
        )

    @pytest.fixture
    async def pose_pipeline(self):
        """Create and initialize a complete pose pipeline from the module mocks.

        NOTE(review): an ``async def`` fixture under plain ``@pytest.fixture``
        requires pytest-asyncio auto mode — confirm the project configuration.
        """

        class PosePipeline:
            """CSI -> features -> poses -> zones orchestrator (test double)."""

            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                """Initialize every pipeline stage."""
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

            async def process_frame(self, csi_data: CSIData) -> PoseResult:
                """Run one CSI frame through every pipeline stage."""
                if not self.is_initialized:
                    raise RuntimeError("Pipeline not initialized")

                started = datetime.utcnow()
                processed = await self.csi_processor.process_csi_data(csi_data)
                pose_data = await self.pose_estimator.estimate_poses(np.array(processed["features"]))
                zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])
                elapsed_ms = (datetime.utcnow() - started).total_seconds() * 1000

                return PoseResult(
                    timestamp=started,
                    frame_id=f"frame_{int(started.timestamp() * 1000)}",
                    persons=pose_data["persons"],
                    zone_summary=zone_summary,
                    processing_time_ms=elapsed_ms,
                    confidence_scores=[p["confidence"] for p in pose_data["persons"]],
                    metadata={
                        "csi_quality": processed["quality_score"],
                        "signal_strength": processed["signal_strength"],
                        "model_version": pose_data["model_version"],
                        "router_id": csi_data.router_id,
                    },
                )

        pipeline = PosePipeline(MockCSIProcessor(), MockPoseEstimator(), MockZoneManager())
        await pipeline.initialize()
        return pipeline

    @pytest.mark.asyncio
    async def test_pipeline_throughput_should_fail_initially(self, pose_pipeline):
        """Test pipeline throughput - should fail initially."""
        frame_count = 10
        start_time = datetime.utcnow()

        for _ in range(frame_count):
            await pose_pipeline.process_frame(self._make_csi_data())

        total_time = (datetime.utcnow() - start_time).total_seconds()
        fps = frame_count / total_time

        # This will fail initially
        assert fps > 5.0  # Should process at least 5 FPS
        assert total_time < 5.0  # Should complete 10 frames in under 5 seconds

    @pytest.mark.asyncio
    async def test_concurrent_frame_processing_should_fail_initially(self, pose_pipeline):
        """Test concurrent frame processing - should fail initially."""

        async def process_single_frame(frame_id: int):
            # Rotate across three routers.
            frame = self._make_csi_data(f"router_{frame_id % 3 + 1:03d}")
            result = await pose_pipeline.process_frame(frame)
            return result.frame_id

        # Process frames concurrently.
        results = await asyncio.gather(*(process_single_frame(i) for i in range(5)))

        # This will fail initially
        assert len(results) == 5
        assert len(set(results)) == 5  # All frame IDs should be unique

    @pytest.mark.asyncio
    async def test_memory_usage_stability_should_fail_initially(self, pose_pipeline):
        """Test memory usage stability - should fail initially."""
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        # Process many frames, sampling memory growth every 10th frame.
        for i in range(50):
            await pose_pipeline.process_frame(self._make_csi_data())

            if i % 10 == 0:
                memory_increase = process.memory_info().rss - initial_memory
                # This will fail initially
                # Memory increase should be reasonable (less than 100MB)
                assert memory_increase < 100 * 1024 * 1024

        total_increase = process.memory_info().rss - initial_memory

        # Total memory increase should be reasonable
        assert total_increase < 200 * 1024 * 1024  # Less than 200MB increase
|
||||
|
||||
|
||||
class TestPipelineDataFlow:
    """Test data flow through the pipeline.

    NOTE: fixtures are defined locally because pytest class-scoped fixtures
    from ``TestPosePipelineIntegration`` are not visible to other test
    classes; without these definitions every test here errors with
    "fixture ... not found" instead of running.
    """

    @pytest.fixture
    def csi_processor(self):
        """Create CSI processor."""
        return MockCSIProcessor()

    @pytest.fixture
    def pose_estimator(self):
        """Create pose estimator."""
        return MockPoseEstimator()

    @pytest.fixture
    def zone_manager(self):
        """Create zone manager."""
        return MockZoneManager()

    @pytest.fixture
    def sample_csi_data(self):
        """Create sample CSI data."""
        return CSIData(
            timestamp=datetime.utcnow(),
            router_id="router_001",
            amplitude=np.random.rand(64, 32),
            phase=np.random.rand(64, 32),
            frequency=5.8e9,
            bandwidth=80e6,
            antenna_count=4,
            subcarrier_count=64,
        )

    @pytest.fixture
    async def pose_pipeline(self, csi_processor, pose_estimator, zone_manager):
        """Create an initialized pipeline over the locally defined mocks.

        NOTE(review): an ``async def`` fixture under plain ``@pytest.fixture``
        requires pytest-asyncio auto mode — confirm the project configuration.
        """

        class PosePipeline:
            """CSI -> features -> poses -> zones orchestrator (test double)."""

            def __init__(self, csi_processor, pose_estimator, zone_manager):
                self.csi_processor = csi_processor
                self.pose_estimator = pose_estimator
                self.zone_manager = zone_manager
                self.is_initialized = False

            async def initialize(self):
                """Initialize every pipeline stage."""
                await self.csi_processor.initialize()
                await self.pose_estimator.load_model()
                self.is_initialized = True

            async def process_frame(self, csi_data: CSIData) -> PoseResult:
                """Run one CSI frame through every pipeline stage."""
                if not self.is_initialized:
                    raise RuntimeError("Pipeline not initialized")

                started = datetime.utcnow()
                processed = await self.csi_processor.process_csi_data(csi_data)
                pose_data = await self.pose_estimator.estimate_poses(np.array(processed["features"]))
                zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])
                elapsed_ms = (datetime.utcnow() - started).total_seconds() * 1000

                return PoseResult(
                    timestamp=started,
                    frame_id=f"frame_{int(started.timestamp() * 1000)}",
                    persons=pose_data["persons"],
                    zone_summary=zone_summary,
                    processing_time_ms=elapsed_ms,
                    confidence_scores=[p["confidence"] for p in pose_data["persons"]],
                    metadata={
                        "csi_quality": processed["quality_score"],
                        "signal_strength": processed["signal_strength"],
                        "model_version": pose_data["model_version"],
                        "router_id": csi_data.router_id,
                    },
                )

        pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
        await pipeline.initialize()
        return pipeline

    @pytest.mark.asyncio
    async def test_data_transformation_chain_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
        """Test data transformation through the pipeline - should fail initially."""
        # Step 1: CSI processing
        await csi_processor.initialize()
        processed_data = await csi_processor.process_csi_data(sample_csi_data)

        # This will fail initially
        assert "features" in processed_data
        assert "quality_score" in processed_data
        assert isinstance(processed_data["features"], list)
        assert 0 <= processed_data["quality_score"] <= 1

        # Step 2: Pose estimation
        await pose_estimator.load_model()
        pose_data = await pose_estimator.estimate_poses(np.array(processed_data["features"]))

        assert "persons" in pose_data
        assert "processing_time_ms" in pose_data
        assert isinstance(pose_data["persons"], list)

        # Step 3: Zone assignment
        zone_summary = zone_manager.assign_persons_to_zones(pose_data["persons"])

        assert isinstance(zone_summary, dict)
        assert all(isinstance(count, int) for count in zone_summary.values())

        # Each assigned person must reference a known zone.
        for person in pose_data["persons"]:
            if person.get("zone_id"):
                assert person["zone_id"] in zone_summary

    @pytest.mark.asyncio
    async def test_pipeline_state_consistency_should_fail_initially(self, pose_pipeline, sample_csi_data):
        """Test pipeline state consistency - should fail initially."""
        # Process the same frame multiple times.
        results = [await pose_pipeline.process_frame(sample_csi_data) for _ in range(3)]

        # This will fail initially
        assert len(results) == 3

        # Same input frame -> same source router on every pass.
        router_ids = [r.metadata["router_id"] for r in results]
        assert all(rid == router_ids[0] for rid in router_ids)

        # Processing times should be reasonable and similar.
        processing_times = [r.processing_time_ms for r in results]
        assert all(10 <= pt <= 200 for pt in processing_times)  # Between 10ms and 200ms
|
||||
565
tests/integration/test_rate_limiting.py
Normal file
565
tests/integration/test_rate_limiting.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""
|
||||
Integration tests for rate limiting functionality.
|
||||
|
||||
Tests rate limit behavior, throttling, and quota management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import time
|
||||
|
||||
from fastapi import HTTPException, status, Request, Response
|
||||
|
||||
|
||||
class MockRateLimiter:
    """Mock rate limiter for testing.

    Tracks per-client request timestamps in memory and enforces two fixed
    windows: a per-minute and a per-hour request quota.  Clients can also
    be refused outright via ``block_client``.
    """

    def __init__(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        # Maps client key -> list of request timestamps (newest appended last).
        self.request_history: Dict[str, List[datetime]] = {}
        # Client ids refused regardless of their request rate.
        self.blocked_clients = set()

    def _get_client_key(self, client_id: str, endpoint: str = None) -> str:
        """Build the history key; endpoint-scoped keys isolate per-endpoint quotas."""
        return f"{client_id}:{endpoint}" if endpoint else client_id

    def _cleanup_old_requests(self, client_key: str):
        """Drop request timestamps older than one hour for the given key."""
        if client_key not in self.request_history:
            return

        # Anything older than an hour can no longer influence either window.
        # (Removed an unused `minute_ago` local that was computed here.)
        hour_ago = datetime.utcnow() - timedelta(hours=1)
        self.request_history[client_key] = [
            req_time for req_time in self.request_history[client_key]
            if req_time > hour_ago
        ]

    def check_rate_limit(self, client_id: str, endpoint: str = None) -> Dict[str, Any]:
        """Check whether a request from *client_id* is allowed and record it.

        Returns a dict with ``allowed`` plus either quota metadata (when
        allowed) or ``reason``/``retry_after`` details (when rejected).
        Only allowed requests are appended to the client's history.
        """
        client_key = self._get_client_key(client_id, endpoint)

        # Hard blocks apply to the client as a whole, not per endpoint.
        if client_id in self.blocked_clients:
            return {
                "allowed": False,
                "reason": "Client blocked",
                "retry_after": 3600  # 1 hour
            }

        self._cleanup_old_requests(client_key)

        if client_key not in self.request_history:
            self.request_history[client_key] = []

        now = datetime.utcnow()
        minute_ago = now - timedelta(minutes=1)

        # Requests inside the trailing one-minute window.
        recent_requests = [
            req_time for req_time in self.request_history[client_key]
            if req_time > minute_ago
        ]

        # History was just pruned to the last hour, so its length is the hour count.
        hour_requests = len(self.request_history[client_key])

        if len(recent_requests) >= self.requests_per_minute:
            return {
                "allowed": False,
                "reason": "Rate limit exceeded (per minute)",
                "retry_after": 60,
                "current_requests": len(recent_requests),
                "limit": self.requests_per_minute
            }

        if hour_requests >= self.requests_per_hour:
            return {
                "allowed": False,
                "reason": "Rate limit exceeded (per hour)",
                "retry_after": 3600,
                "current_requests": hour_requests,
                "limit": self.requests_per_hour
            }

        # Record this request only after it passed both quota checks.
        self.request_history[client_key].append(now)

        return {
            "allowed": True,
            "remaining_minute": self.requests_per_minute - len(recent_requests) - 1,
            "remaining_hour": self.requests_per_hour - hour_requests - 1,
            # NOTE(review): minute_ago + 1 minute equals `now`, so this is
            # not a real window reset; a true reset would be the oldest
            # in-window request + 1 minute.  Kept for compatibility.
            "reset_time": minute_ago + timedelta(minutes=1)
        }

    def block_client(self, client_id: str):
        """Block a client unconditionally until ``unblock_client`` is called."""
        self.blocked_clients.add(client_id)

    def unblock_client(self, client_id: str):
        """Unblock a client; a no-op when the client was not blocked."""
        self.blocked_clients.discard(client_id)
|
||||
|
||||
|
||||
class TestRateLimitingBasic:
    """Test basic rate limiting functionality."""

    @pytest.fixture
    def rate_limiter(self):
        """Limiter with small quotas so limits are easy to hit in tests."""
        return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)

    def test_rate_limit_within_bounds_should_fail_initially(self, rate_limiter):
        """Test rate limiting within bounds - should fail initially."""
        client_id = "test-client-001"

        # Stay comfortably under the 5-per-minute quota.
        for outcome in (rate_limiter.check_rate_limit(client_id) for _ in range(3)):
            # This will fail initially
            assert outcome["allowed"] is True
            assert "remaining_minute" in outcome
            assert "remaining_hour" in outcome

    def test_rate_limit_per_minute_exceeded_should_fail_initially(self, rate_limiter):
        """Test per-minute rate limit exceeded - should fail initially."""
        client_id = "test-client-002"

        # Consume the entire per-minute quota.
        for _ in range(5):
            assert rate_limiter.check_rate_limit(client_id)["allowed"] is True

        # The sixth request must be rejected.
        rejection = rate_limiter.check_rate_limit(client_id)

        # This will fail initially
        assert rejection["allowed"] is False
        assert "per minute" in rejection["reason"]
        assert rejection["retry_after"] == 60
        assert rejection["current_requests"] == 5
        assert rejection["limit"] == 5

    def test_rate_limit_per_hour_exceeded_should_fail_initially(self, rate_limiter):
        """Test per-hour rate limit exceeded - should fail initially."""
        # A dedicated limiter whose hourly quota is tiny.
        hourly_limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=3)
        client_id = "test-client-003"

        # Use up the hourly quota.
        for _ in range(3):
            assert hourly_limiter.check_rate_limit(client_id)["allowed"] is True

        # One more request must trip the hourly limit.
        rejection = hourly_limiter.check_rate_limit(client_id)

        # This will fail initially
        assert rejection["allowed"] is False
        assert "per hour" in rejection["reason"]
        assert rejection["retry_after"] == 3600

    def test_blocked_client_should_fail_initially(self, rate_limiter):
        """Test blocked client handling - should fail initially."""
        client_id = "blocked-client"

        rate_limiter.block_client(client_id)

        # A blocked client is refused before any quota is consulted.
        verdict = rate_limiter.check_rate_limit(client_id)

        # This will fail initially
        assert verdict["allowed"] is False
        assert verdict["reason"] == "Client blocked"
        assert verdict["retry_after"] == 3600

        # Lifting the block restores normal service.
        rate_limiter.unblock_client(client_id)
        assert rate_limiter.check_rate_limit(client_id)["allowed"] is True

    def test_endpoint_specific_rate_limiting_should_fail_initially(self, rate_limiter):
        """Test endpoint-specific rate limiting - should fail initially."""
        client_id = "test-client-004"
        pose_endpoint = "/api/pose/current"
        stream_endpoint = "/api/stream/status"

        first_pose = rate_limiter.check_rate_limit(client_id, pose_endpoint)
        first_stream = rate_limiter.check_rate_limit(client_id, stream_endpoint)

        # This will fail initially
        assert first_pose["allowed"] is True
        assert first_stream["allowed"] is True

        # Exhaust the pose endpoint's quota; the stream endpoint keeps its own.
        for _ in range(4):
            rate_limiter.check_rate_limit(client_id, pose_endpoint)

        pose_verdict = rate_limiter.check_rate_limit(client_id, pose_endpoint)
        stream_verdict = rate_limiter.check_rate_limit(client_id, stream_endpoint)

        assert pose_verdict["allowed"] is False
        assert stream_verdict["allowed"] is True
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
    """Test rate limiting middleware functionality."""

    @pytest.fixture
    def rate_limiter(self):
        """Rate limiter backing the middleware under test.

        BUG FIX: this fixture was missing.  The ``rate_limit_middleware``
        fixture below requests a ``rate_limiter`` fixture, but the only
        definition lived on ``TestRateLimitingBasic``; pytest does not
        share class-scoped fixtures across classes, so every test in this
        class errored with "fixture 'rate_limiter' not found".
        """
        return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)

    @pytest.fixture
    def mock_request(self):
        """Mock FastAPI request.

        Returns the *class* (a factory) so tests can build requests with
        custom client IPs, paths, and headers.
        """
        class MockRequest:
            def __init__(self, client_ip="127.0.0.1", path="/api/test", method="GET"):
                self.client = MagicMock()
                self.client.host = client_ip
                self.url = MagicMock()
                self.url.path = path
                self.method = method
                self.headers = {}
                self.state = MagicMock()

        return MockRequest

    @pytest.fixture
    def mock_response(self):
        """Mock FastAPI response (status code + mutable headers only)."""
        class MockResponse:
            def __init__(self):
                self.status_code = 200
                self.headers = {}

        return MockResponse()

    @pytest.fixture
    def rate_limit_middleware(self, rate_limiter):
        """Create rate limiting middleware wired to the test rate limiter."""
        class RateLimitMiddleware:
            def __init__(self, rate_limiter):
                self.rate_limiter = rate_limiter

            async def __call__(self, request, call_next):
                # Identify the caller and the endpoint being hit.
                client_id = self._get_client_id(request)
                endpoint = request.url.path

                # Consult (and update) the limiter before doing any work.
                limit_result = self.rate_limiter.check_rate_limit(client_id, endpoint)

                if not limit_result["allowed"]:
                    # Short-circuit with 429 and standard rate-limit headers.
                    response = Response(
                        content=f"Rate limit exceeded: {limit_result['reason']}",
                        status_code=status.HTTP_429_TOO_MANY_REQUESTS
                    )
                    response.headers["Retry-After"] = str(limit_result["retry_after"])
                    response.headers["X-RateLimit-Limit"] = str(limit_result.get("limit", "unknown"))
                    response.headers["X-RateLimit-Remaining"] = "0"
                    return response

                # Within quota: run the downstream handler.
                response = await call_next(request)

                # Annotate successful responses with quota headers.
                response.headers["X-RateLimit-Limit"] = str(self.rate_limiter.requests_per_minute)
                response.headers["X-RateLimit-Remaining"] = str(limit_result.get("remaining_minute", 0))
                response.headers["X-RateLimit-Reset"] = str(int(limit_result.get("reset_time", datetime.utcnow()).timestamp()))

                return response

            def _get_client_id(self, request):
                """Get client identifier from request.

                Precedence: API key header, then authenticated user id
                (from request.state), then the caller's IP address.
                """
                # Check for API key in headers
                api_key = request.headers.get("X-API-Key")
                if api_key:
                    return f"api:{api_key}"

                # Check for user ID in request state (from auth).
                # NOTE(review): with the MagicMock-backed request.state above,
                # this branch is always truthy for plain mock requests.
                if hasattr(request.state, "user") and request.state.user:
                    return f"user:{request.state.user.get('id', 'unknown')}"

                # Fall back to IP address
                return f"ip:{request.client.host}"

        return RateLimitMiddleware(rate_limiter)

    @pytest.mark.asyncio
    async def test_middleware_allows_normal_requests_should_fail_initially(
        self, rate_limit_middleware, mock_request, mock_response
    ):
        """Test middleware allows normal requests - should fail initially."""
        request = mock_request()

        async def mock_call_next(req):
            return mock_response

        response = await rate_limit_middleware(request, mock_call_next)

        # This will fail initially
        assert response.status_code == 200
        assert "X-RateLimit-Limit" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert "X-RateLimit-Reset" in response.headers

    @pytest.mark.asyncio
    async def test_middleware_blocks_excessive_requests_should_fail_initially(
        self, rate_limit_middleware, mock_request
    ):
        """Test middleware blocks excessive requests - should fail initially."""
        request = mock_request()

        async def mock_call_next(req):
            response = Response(content="OK", status_code=200)
            return response

        # Make requests up to the 5-per-minute fixture limit.
        for i in range(5):
            response = await rate_limit_middleware(request, mock_call_next)
            assert response.status_code == 200

        # Next request should be rejected by the limiter.
        response = await rate_limit_middleware(request, mock_call_next)

        # This will fail initially
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        assert "Retry-After" in response.headers
        assert "X-RateLimit-Remaining" in response.headers
        assert response.headers["X-RateLimit-Remaining"] == "0"

    @pytest.mark.asyncio
    async def test_middleware_client_identification_should_fail_initially(
        self, rate_limit_middleware, mock_request
    ):
        """Test middleware client identification - should fail initially."""
        # Test API key identification
        request_with_api_key = mock_request()
        request_with_api_key.headers["X-API-Key"] = "test-api-key-123"

        # Test user identification
        request_with_user = mock_request()
        request_with_user.state.user = {"id": "user-123"}

        # Test IP identification
        request_with_ip = mock_request(client_ip="192.168.1.100")

        async def mock_call_next(req):
            return Response(content="OK", status_code=200)

        # Each should be treated as a distinct client with its own quota.
        response1 = await rate_limit_middleware(request_with_api_key, mock_call_next)
        response2 = await rate_limit_middleware(request_with_user, mock_call_next)
        response3 = await rate_limit_middleware(request_with_ip, mock_call_next)

        # This will fail initially
        assert response1.status_code == 200
        assert response2.status_code == 200
        assert response3.status_code == 200
|
||||
|
||||
|
||||
class TestRateLimitingStrategies:
    """Test different rate limiting strategies."""

    @pytest.fixture
    def sliding_window_limiter(self):
        """Create sliding window rate limiter."""
        class SlidingWindowLimiter:
            # Admits at most max_requests within any trailing window of
            # window_size_seconds, tracked independently per client id.
            def __init__(self, window_size_seconds: int = 60, max_requests: int = 10):
                self.window_size = window_size_seconds
                self.max_requests = max_requests
                # client_id -> list of time.time() stamps still inside the window
                self.request_times = {}

            def check_limit(self, client_id: str) -> Dict[str, Any]:
                """Admit or reject one request for *client_id*, recording it if admitted."""
                now = time.time()

                if client_id not in self.request_times:
                    self.request_times[client_id] = []

                # Remove old requests outside the window
                cutoff_time = now - self.window_size
                self.request_times[client_id] = [
                    req_time for req_time in self.request_times[client_id]
                    if req_time > cutoff_time
                ]

                # Check if we're at the limit
                if len(self.request_times[client_id]) >= self.max_requests:
                    # Retry once the oldest in-window request ages out.
                    oldest_request = min(self.request_times[client_id])
                    retry_after = int(oldest_request + self.window_size - now)

                    return {
                        "allowed": False,
                        "retry_after": max(retry_after, 1),
                        "current_requests": len(self.request_times[client_id]),
                        "limit": self.max_requests
                    }

                # Record this request
                self.request_times[client_id].append(now)

                return {
                    "allowed": True,
                    "remaining": self.max_requests - len(self.request_times[client_id]),
                    "window_reset": int(now + self.window_size)
                }

        # Small window and limit so tests trip the limiter quickly.
        return SlidingWindowLimiter(window_size_seconds=10, max_requests=3)

    @pytest.fixture
    def token_bucket_limiter(self):
        """Create token bucket rate limiter."""
        class TokenBucketLimiter:
            # Classic token bucket: each request costs one token; tokens
            # refill continuously at refill_rate up to capacity.
            def __init__(self, capacity: int = 10, refill_rate: float = 1.0):
                self.capacity = capacity
                self.refill_rate = refill_rate  # tokens per second
                # client_id -> {"tokens": float, "last_refill": time.time()}
                self.buckets = {}

            def check_limit(self, client_id: str) -> Dict[str, Any]:
                """Consume one token for *client_id* if one is available."""
                now = time.time()

                if client_id not in self.buckets:
                    # New clients start with a full bucket.
                    self.buckets[client_id] = {
                        "tokens": self.capacity,
                        "last_refill": now
                    }

                bucket = self.buckets[client_id]

                # Refill tokens based on time elapsed
                time_elapsed = now - bucket["last_refill"]
                tokens_to_add = time_elapsed * self.refill_rate
                bucket["tokens"] = min(self.capacity, bucket["tokens"] + tokens_to_add)
                bucket["last_refill"] = now

                # Check if we have tokens available
                if bucket["tokens"] < 1:
                    return {
                        "allowed": False,
                        # NOTE(review): int() truncates, so this can report 0
                        # seconds even when a fractional wait is required.
                        "retry_after": int((1 - bucket["tokens"]) / self.refill_rate),
                        "tokens_remaining": bucket["tokens"]
                    }

                # Consume a token
                bucket["tokens"] -= 1

                return {
                    "allowed": True,
                    "tokens_remaining": bucket["tokens"]
                }

        return TokenBucketLimiter(capacity=5, refill_rate=0.5)  # 0.5 tokens per second

    def test_sliding_window_limiter_should_fail_initially(self, sliding_window_limiter):
        """Test sliding window rate limiter - should fail initially."""
        client_id = "sliding-test-client"

        # Make requests up to limit
        for i in range(3):
            result = sliding_window_limiter.check_limit(client_id)

            # This will fail initially
            assert result["allowed"] is True
            assert "remaining" in result

        # Next request should be blocked
        result = sliding_window_limiter.check_limit(client_id)
        assert result["allowed"] is False
        assert result["current_requests"] == 3
        assert result["limit"] == 3

    def test_token_bucket_limiter_should_fail_initially(self, token_bucket_limiter):
        """Test token bucket rate limiter - should fail initially."""
        client_id = "bucket-test-client"

        # Make requests up to capacity
        for i in range(5):
            result = token_bucket_limiter.check_limit(client_id)

            # This will fail initially
            assert result["allowed"] is True
            assert "tokens_remaining" in result

        # Next request should be blocked (no tokens left)
        result = token_bucket_limiter.check_limit(client_id)
        assert result["allowed"] is False
        assert result["tokens_remaining"] < 1

    @pytest.mark.asyncio
    async def test_token_bucket_refill_should_fail_initially(self, token_bucket_limiter):
        """Test token bucket refill mechanism - should fail initially."""
        client_id = "refill-test-client"

        # Exhaust all tokens
        for i in range(5):
            token_bucket_limiter.check_limit(client_id)

        # Should be blocked
        result = token_bucket_limiter.check_limit(client_id)
        assert result["allowed"] is False

        # Wait for refill (simulate time passing)
        await asyncio.sleep(2.1)  # Wait for 1 token to be refilled (0.5 tokens/sec * 2.1 sec > 1)

        # Should now be allowed
        result = token_bucket_limiter.check_limit(client_id)

        # This will fail initially
        assert result["allowed"] is True
|
||||
|
||||
|
||||
class TestRateLimitingPerformance:
    """Test rate limiting performance characteristics."""

    @pytest.mark.asyncio
    async def test_concurrent_rate_limit_checks_should_fail_initially(self):
        """Test concurrent rate limit checks - should fail initially."""
        limiter = MockRateLimiter(requests_per_minute=100, requests_per_hour=1000)

        async def check_one(client_id: str, request_id: int):
            # Every task uses a unique client key, so none should be throttled.
            verdict = limiter.check_rate_limit(f"{client_id}-{request_id}")
            return verdict["allowed"]

        # Fan out fifty checks concurrently.
        verdicts = await asyncio.gather(
            *(check_one("concurrent-client", idx) for idx in range(50))
        )

        # This will fail initially
        assert len(verdicts) == 50
        assert all(verdicts)  # distinct clients must all be admitted

    @pytest.mark.asyncio
    async def test_rate_limiter_memory_cleanup_should_fail_initially(self):
        """Test rate limiter memory cleanup - should fail initially."""
        limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=100)

        # Touch one hundred distinct clients so the history map fills up.
        for idx in range(100):
            limiter.check_rate_limit(f"client-{idx}")

        history_size_before = len(limiter.request_history)

        # Force the hourly pruning pass for every tracked key.
        for tracked_key in list(limiter.request_history.keys()):
            limiter._cleanup_old_requests(tracked_key)

        # This will fail initially
        assert history_size_before == 100

        # Only keys that still hold timestamps count as live memory.
        # (In a real implementation old timestamps would have been purged.)
        history_size_after = sum(
            1 for timestamps in limiter.request_history.values() if timestamps
        )

        assert history_size_after <= history_size_before
|
||||
729
tests/integration/test_streaming_pipeline.py
Normal file
729
tests/integration/test_streaming_pipeline.py
Normal file
@@ -0,0 +1,729 @@
|
||||
"""
|
||||
Integration tests for real-time streaming pipeline.
|
||||
|
||||
Tests the complete real-time data flow from CSI collection to client delivery.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class StreamFrame:
    """Streaming frame data structure.

    One unit of the real-time pipeline: the pose estimate derived from a
    single CSI capture, plus the metadata needed to route and monitor it.
    """
    frame_id: str  # unique identifier for this frame
    timestamp: datetime  # capture time of the underlying sample
    router_id: str  # router that produced the CSI data
    pose_data: Dict[str, Any]  # pose payload (e.g. "persons", "zone_summary")
    processing_time_ms: float  # cumulative pipeline latency for this frame
    quality_score: float  # estimate quality; observed values lie in 0..1
|
||||
|
||||
|
||||
class MockStreamBuffer:
    """Mock streaming buffer for testing.

    Bounded asyncio queue that drops (and counts) frames on overflow
    instead of blocking the producer.
    """

    def __init__(self, max_size: int = 100):
        self.max_size = max_size
        self.buffer = asyncio.Queue(maxsize=max_size)
        self.dropped_frames = 0
        self.total_frames = 0

    async def put_frame(self, frame: StreamFrame) -> bool:
        """Enqueue *frame*; report False (and count a drop) when the buffer is full."""
        self.total_frames += 1

        if self.buffer.full():
            self.dropped_frames += 1
            return False

        self.buffer.put_nowait(frame)
        return True

    async def get_frame(self, timeout: float = 1.0) -> Optional[StreamFrame]:
        """Dequeue the next frame, or return None once *timeout* seconds elapse."""
        try:
            frame = await asyncio.wait_for(self.buffer.get(), timeout=timeout)
        except asyncio.TimeoutError:
            return None
        return frame

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of buffer occupancy and drop statistics."""
        seen = self.total_frames
        return {
            "buffer_size": self.buffer.qsize(),
            "max_size": self.max_size,
            "total_frames": seen,
            "dropped_frames": self.dropped_frames,
            # max() guards against division by zero before any frame arrived.
            "drop_rate": self.dropped_frames / max(seen, 1),
        }
|
||||
|
||||
|
||||
class MockStreamProcessor:
    """Mock stream processor for testing.

    Pulls frames from an input buffer, applies a simulated processing
    step, and pushes the result to an output buffer until stopped.
    """

    def __init__(self):
        # Flag polled by the processing loop; cleared by stop_processing().
        self.is_running = False
        self.processing_rate = 30  # FPS
        self.frame_counter = 0
        self.error_rate = 0.0  # probability of silently skipping a frame

    async def start_processing(self, input_buffer: MockStreamBuffer, output_buffer: MockStreamBuffer):
        """Start stream processing.

        Runs until ``stop_processing`` clears the flag.  Frames may be
        skipped by the injected error rate or by input-buffer timeouts.
        """
        self.is_running = True

        while self.is_running:
            try:
                # Get frame from input; the short timeout keeps the loop
                # responsive to stop_processing().
                frame = await input_buffer.get_frame(timeout=0.1)
                if frame is None:
                    continue

                # Simulate processing error
                if np.random.random() < self.error_rate:
                    continue  # Skip frame due to error

                # Process frame
                processed_frame = await self._process_frame(frame)

                # Put to output buffer
                await output_buffer.put_frame(processed_frame)

                # Control processing rate
                await asyncio.sleep(1.0 / self.processing_rate)

            except Exception as e:
                # Handle processing errors
                # NOTE(review): broad swallow keeps the mock loop alive but
                # hides real bugs; acceptable only in a test double.
                continue

    async def _process_frame(self, frame: StreamFrame) -> StreamFrame:
        """Process a single frame.

        Returns a new StreamFrame with a "processed_" id prefix, extra
        metadata, added latency, and a slightly reduced quality score.
        """
        # Simulate processing time
        await asyncio.sleep(0.01)

        # Add processing metadata
        processed_pose_data = frame.pose_data.copy()
        processed_pose_data["processed_at"] = datetime.utcnow().isoformat()
        processed_pose_data["processor_id"] = "stream_processor_001"

        return StreamFrame(
            frame_id=f"processed_{frame.frame_id}",
            timestamp=frame.timestamp,
            router_id=frame.router_id,
            pose_data=processed_pose_data,
            processing_time_ms=frame.processing_time_ms + 10,  # Add processing overhead
            quality_score=frame.quality_score * 0.95  # Slight quality degradation
        )

    def stop_processing(self):
        """Stop stream processing (the loop exits on its next iteration)."""
        self.is_running = False

    def set_error_rate(self, error_rate: float):
        """Set processing error rate (probability in [0, 1] of dropping a frame)."""
        self.error_rate = error_rate
|
||||
|
||||
|
||||
class MockWebSocketManager:
    """Mock WebSocket manager for testing.

    Tracks connected clients and simulates broadcasting pose frames,
    including a small random send-failure rate.
    """

    def __init__(self):
        self.connected_clients = {}
        self.message_queue = asyncio.Queue()
        self.total_messages_sent = 0
        self.failed_sends = 0

    async def add_client(self, client_id: str, websocket_mock) -> bool:
        """Register a client; reject duplicate ids."""
        if client_id in self.connected_clients:
            return False

        self.connected_clients[client_id] = {
            "websocket": websocket_mock,
            "connected_at": datetime.utcnow(),
            "messages_sent": 0,
            "last_ping": datetime.utcnow(),
        }
        return True

    async def remove_client(self, client_id: str) -> bool:
        """Deregister a client; False when the id is unknown."""
        if client_id not in self.connected_clients:
            return False

        del self.connected_clients[client_id]
        return True

    async def broadcast_frame(self, frame: StreamFrame) -> Dict[str, bool]:
        """Send *frame* to every connected client; map client id -> success."""
        payload = {
            "type": "pose_update",
            "frame_id": frame.frame_id,
            "timestamp": frame.timestamp.isoformat(),
            "router_id": frame.router_id,
            "pose_data": frame.pose_data,
            "processing_time_ms": frame.processing_time_ms,
            "quality_score": frame.quality_score,
        }

        delivery = {}
        for client_id, client_info in self.connected_clients.items():
            try:
                delivered = await self._send_to_client(client_id, payload)
            except Exception:
                # Treat any send exception as a failed delivery.
                delivery[client_id] = False
                self.failed_sends += 1
                continue

            delivery[client_id] = delivered
            if delivered:
                client_info["messages_sent"] += 1
                self.total_messages_sent += 1
            else:
                self.failed_sends += 1

        return delivery

    async def _send_to_client(self, client_id: str, message: Dict[str, Any]) -> bool:
        """Simulate one WebSocket send with a 5% random failure rate."""
        if np.random.random() < 0.05:  # 5% failure rate
            return False

        await asyncio.sleep(0.001)  # simulated network latency
        return True

    def get_client_stats(self) -> Dict[str, Any]:
        """Aggregate connection and delivery statistics."""
        now = datetime.utcnow()
        per_client = {
            client_id: {
                "messages_sent": info["messages_sent"],
                "connected_duration": (now - info["connected_at"]).total_seconds(),
            }
            for client_id, info in self.connected_clients.items()
        }
        return {
            "connected_clients": len(self.connected_clients),
            "total_messages_sent": self.total_messages_sent,
            "failed_sends": self.failed_sends,
            "clients": per_client,
        }
|
||||
|
||||
|
||||
class TestStreamingPipelineBasic:
|
||||
"""Test basic streaming pipeline functionality."""
|
||||
|
||||
    @pytest.fixture
    def stream_buffer(self):
        """Create stream buffer."""
        # Capacity 50: small enough to overflow in tests without large fixtures.
        return MockStreamBuffer(max_size=50)
|
||||
|
||||
    @pytest.fixture
    def stream_processor(self):
        """Create stream processor."""
        # Defaults: 30 FPS target rate, zero injected errors.
        return MockStreamProcessor()
|
||||
|
||||
    @pytest.fixture
    def websocket_manager(self):
        """Create WebSocket manager."""
        # Fresh manager per test: no clients, zeroed counters.
        return MockWebSocketManager()
|
||||
|
||||
    @pytest.fixture
    def sample_frame(self):
        """Create sample stream frame."""
        # One person detected in zone1; values chosen to resemble a
        # plausible mid-quality pipeline output.
        return StreamFrame(
            frame_id="frame_001",
            timestamp=datetime.utcnow(),
            router_id="router_001",
            pose_data={
                "persons": [
                    {
                        "person_id": "person_1",
                        "confidence": 0.85,
                        "bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
                        "activity": "standing"
                    }
                ],
                "zone_summary": {"zone1": 1, "zone2": 0}
            },
            processing_time_ms=45.2,
            quality_score=0.92
        )
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_frame_operations_should_fail_initially(self, stream_buffer, sample_frame):
|
||||
"""Test buffer frame operations - should fail initially."""
|
||||
# Put frame in buffer
|
||||
result = await stream_buffer.put_frame(sample_frame)
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
|
||||
# Get frame from buffer
|
||||
retrieved_frame = await stream_buffer.get_frame()
|
||||
assert retrieved_frame is not None
|
||||
assert retrieved_frame.frame_id == sample_frame.frame_id
|
||||
assert retrieved_frame.router_id == sample_frame.router_id
|
||||
|
||||
# Buffer should be empty now
|
||||
empty_frame = await stream_buffer.get_frame(timeout=0.1)
|
||||
assert empty_frame is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_overflow_handling_should_fail_initially(self, sample_frame):
|
||||
"""Test buffer overflow handling - should fail initially."""
|
||||
small_buffer = MockStreamBuffer(max_size=2)
|
||||
|
||||
# Fill buffer to capacity
|
||||
result1 = await small_buffer.put_frame(sample_frame)
|
||||
result2 = await small_buffer.put_frame(sample_frame)
|
||||
|
||||
# This will fail initially
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
|
||||
# Next frame should be dropped
|
||||
result3 = await small_buffer.put_frame(sample_frame)
|
||||
assert result3 is False
|
||||
|
||||
# Check statistics
|
||||
stats = small_buffer.get_stats()
|
||||
assert stats["total_frames"] == 3
|
||||
assert stats["dropped_frames"] == 1
|
||||
assert stats["drop_rate"] > 0
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_stream_processing_should_fail_initially(self, stream_processor, sample_frame):
        """Test stream processing - should fail initially."""
        input_buffer = MockStreamBuffer()
        output_buffer = MockStreamBuffer()

        # Add frame to input buffer
        await input_buffer.put_frame(sample_frame)

        # Start processing task
        processing_task = asyncio.create_task(
            stream_processor.start_processing(input_buffer, output_buffer)
        )

        # Wait for processing
        # NOTE(review): 0.2s assumes one frame clears the ~10ms processing
        # step plus the 1/30s rate delay; timing-sensitive on slow hosts.
        await asyncio.sleep(0.2)

        # Stop processing
        stream_processor.stop_processing()
        await processing_task

        # Check output
        processed_frame = await output_buffer.get_frame(timeout=0.1)

        # This will fail initially
        assert processed_frame is not None
        assert processed_frame.frame_id.startswith("processed_")
        assert "processed_at" in processed_frame.pose_data
        assert processed_frame.processing_time_ms > sample_frame.processing_time_ms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_client_management_should_fail_initially(self, websocket_manager):
|
||||
"""Test WebSocket client management - should fail initially."""
|
||||
mock_websocket = MagicMock()
|
||||
|
||||
# Add client
|
||||
result = await websocket_manager.add_client("client_001", mock_websocket)
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
assert "client_001" in websocket_manager.connected_clients
|
||||
|
||||
# Try to add duplicate client
|
||||
result = await websocket_manager.add_client("client_001", mock_websocket)
|
||||
assert result is False
|
||||
|
||||
# Remove client
|
||||
result = await websocket_manager.remove_client("client_001")
|
||||
assert result is True
|
||||
assert "client_001" not in websocket_manager.connected_clients
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_frame_broadcasting_should_fail_initially(self, websocket_manager, sample_frame):
|
||||
"""Test frame broadcasting - should fail initially."""
|
||||
# Add multiple clients
|
||||
for i in range(3):
|
||||
await websocket_manager.add_client(f"client_{i:03d}", MagicMock())
|
||||
|
||||
# Broadcast frame
|
||||
results = await websocket_manager.broadcast_frame(sample_frame)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(success, bool) for success in results.values())
|
||||
|
||||
# Check statistics
|
||||
stats = websocket_manager.get_client_stats()
|
||||
assert stats["connected_clients"] == 3
|
||||
assert stats["total_messages_sent"] >= 0
|
||||
|
||||
|
||||
class TestStreamingPipelineIntegration:
    """Test complete streaming pipeline integration."""

    @pytest.fixture
    async def streaming_pipeline(self):
        """Create complete streaming pipeline.

        Wires: input buffer -> processor -> output buffer -> WebSocket fan-out.
        """

        class StreamingPipeline:
            """Minimal in-memory streaming pipeline built from the mocks."""

            def __init__(self):
                self.input_buffer = MockStreamBuffer(max_size=100)
                self.output_buffer = MockStreamBuffer(max_size=100)
                self.processor = MockStreamProcessor()
                self.websocket_manager = MockWebSocketManager()
                self.is_running = False
                # Background task handles, created by start().
                self.processing_task = None
                self.broadcasting_task = None

            async def start(self):
                """Start the processing and broadcasting background tasks.

                Returns:
                    False if the pipeline is already running, else True.
                """
                if self.is_running:
                    return False

                self.is_running = True

                self.processing_task = asyncio.create_task(
                    self.processor.start_processing(self.input_buffer, self.output_buffer)
                )
                self.broadcasting_task = asyncio.create_task(
                    self._broadcast_loop()
                )
                return True

            async def stop(self):
                """Stop the pipeline and reap its background tasks.

                Returns:
                    False if the pipeline was not running, else True.
                """
                if not self.is_running:
                    return False

                self.is_running = False
                self.processor.stop_processing()

                # FIX: the original only cancelled the tasks without awaiting
                # them, leaving the cancellation unprocessed and producing
                # "Task was destroyed but it is pending" warnings when the
                # event loop shuts down. Cancel AND await, swallowing only
                # the expected CancelledError (same pattern as the hardware
                # mock's disconnect()).
                for task in (self.processing_task, self.broadcasting_task):
                    if task:
                        task.cancel()
                        try:
                            await task
                        except asyncio.CancelledError:
                            pass

                return True

            async def add_frame(self, frame: StreamFrame) -> bool:
                """Feed a frame into the pipeline's input buffer."""
                return await self.input_buffer.put_frame(frame)

            async def add_client(self, client_id: str, websocket_mock) -> bool:
                """Register a WebSocket client for broadcasts."""
                return await self.websocket_manager.add_client(client_id, websocket_mock)

            async def _broadcast_loop(self):
                """Forward processed frames to all clients until stopped."""
                while self.is_running:
                    try:
                        frame = await self.output_buffer.get_frame(timeout=0.1)
                        if frame:
                            await self.websocket_manager.broadcast_frame(frame)
                    except asyncio.TimeoutError:
                        continue
                    except Exception:
                        # Best-effort loop: broadcasting errors are tolerated.
                        continue

            def get_pipeline_stats(self) -> Dict[str, Any]:
                """Snapshot of running state and per-component statistics."""
                return {
                    "is_running": self.is_running,
                    "input_buffer": self.input_buffer.get_stats(),
                    "output_buffer": self.output_buffer.get_stats(),
                    "websocket_clients": self.websocket_manager.get_client_stats(),
                }

        return StreamingPipeline()

    @pytest.mark.asyncio
    async def test_end_to_end_streaming_should_fail_initially(self, streaming_pipeline):
        """Test end-to-end streaming - should fail initially."""
        # Start pipeline
        result = await streaming_pipeline.start()
        assert result is True
        assert streaming_pipeline.is_running is True

        # Add clients
        for i in range(2):
            await streaming_pipeline.add_client(f"client_{i}", MagicMock())

        # Add frames
        for i in range(5):
            frame = StreamFrame(
                frame_id=f"frame_{i:03d}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={"persons": [], "zone_summary": {}},
                processing_time_ms=30.0,
                quality_score=0.9,
            )
            await streaming_pipeline.add_frame(frame)

        # Wait for processing, then shut down.
        await asyncio.sleep(0.5)
        await streaming_pipeline.stop()

        # All inputs counted; both clients still registered.
        stats = streaming_pipeline.get_pipeline_stats()
        assert stats["input_buffer"]["total_frames"] == 5
        assert stats["websocket_clients"]["connected_clients"] == 2

    @pytest.mark.asyncio
    async def test_pipeline_performance_should_fail_initially(self, streaming_pipeline):
        """Test pipeline performance - should fail initially."""
        await streaming_pipeline.start()

        # Fan out to ten clients so broadcasting does real work.
        for i in range(10):
            await streaming_pipeline.add_client(f"client_{i:03d}", MagicMock())

        # Measure throughput over 50 frames.
        start_time = datetime.utcnow()
        frame_count = 50

        for i in range(frame_count):
            frame = StreamFrame(
                frame_id=f"perf_frame_{i:03d}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={"persons": [], "zone_summary": {}},
                processing_time_ms=25.0,
                quality_score=0.88,
            )
            await streaming_pipeline.add_frame(frame)

        # Give the pipeline time to drain.
        await asyncio.sleep(2.0)

        end_time = datetime.utcnow()
        duration = (end_time - start_time).total_seconds()

        await streaming_pipeline.stop()

        # Check performance metrics
        stats = streaming_pipeline.get_pipeline_stats()
        throughput = frame_count / duration

        assert throughput > 10  # Should process at least 10 FPS
        assert stats["input_buffer"]["drop_rate"] < 0.1  # Less than 10% drop rate

    @pytest.mark.asyncio
    async def test_pipeline_error_recovery_should_fail_initially(self, streaming_pipeline):
        """Test pipeline error recovery - should fail initially."""
        await streaming_pipeline.start()

        # Inject a 50% per-frame failure rate into the processor.
        streaming_pipeline.processor.set_error_rate(0.5)

        # Add frames
        for i in range(20):
            frame = StreamFrame(
                frame_id=f"error_frame_{i:03d}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={"persons": [], "zone_summary": {}},
                processing_time_ms=30.0,
                quality_score=0.9,
            )
            await streaming_pipeline.add_frame(frame)

        # Wait for processing
        await asyncio.sleep(1.0)

        await streaming_pipeline.stop()

        # Pipeline should continue running despite errors: every input is
        # counted, and at least some frames make it through.
        stats = streaming_pipeline.get_pipeline_stats()
        assert stats["input_buffer"]["total_frames"] == 20
        assert stats["output_buffer"]["total_frames"] > 0
|
||||
|
||||
|
||||
class TestStreamingLatency:
    """Test streaming latency characteristics."""

    @pytest.mark.asyncio
    async def test_end_to_end_latency_should_fail_initially(self):
        """Per-frame latency stays within bounds, individually and on average."""

        class LatencyTracker:
            """Records simulated per-frame processing latencies (ms)."""

            def __init__(self):
                self.latencies = []

            async def measure_latency(self, frame: StreamFrame) -> float:
                """Time a simulated 50 ms processing pass for one frame."""
                begun = datetime.utcnow()
                await asyncio.sleep(0.05)  # stand-in for the real pipeline
                elapsed_ms = (datetime.utcnow() - begun).total_seconds() * 1000
                self.latencies.append(elapsed_ms)
                return elapsed_ms

        tracker = LatencyTracker()

        # Measure latency for ten frames, checking each individually.
        for index in range(10):
            probe = StreamFrame(
                frame_id=f"latency_frame_{index}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={},
                processing_time_ms=0,
                quality_score=1.0,
            )
            measured = await tracker.measure_latency(probe)
            # Each frame must complete within 200 ms.
            assert measured > 0
            assert measured < 200

        # The mean must be comfortably below 100 ms.
        average = sum(tracker.latencies) / len(tracker.latencies)
        assert average < 100

    @pytest.mark.asyncio
    async def test_concurrent_stream_handling_should_fail_initially(self):
        """Three router streams can be ingested concurrently without loss."""

        async def process_stream(stream_id: str, frame_count: int) -> Dict[str, Any]:
            """Push frame_count frames for one router into a fresh buffer."""
            buffer = MockStreamBuffer()
            accepted = 0

            for index in range(frame_count):
                frame = StreamFrame(
                    frame_id=f"{stream_id}_frame_{index}",
                    timestamp=datetime.utcnow(),
                    router_id=stream_id,
                    pose_data={},
                    processing_time_ms=20.0,
                    quality_score=0.9,
                )
                if await buffer.put_frame(frame):
                    accepted += 1
                await asyncio.sleep(0.01)  # pace roughly like a real frame rate

            return {
                "stream_id": stream_id,
                "processed_frames": accepted,
                "total_frames": frame_count,
            }

        # Drive three streams concurrently.
        streams = ["router_001", "router_002", "router_003"]
        results = await asyncio.gather(
            *(process_stream(stream_id, 20) for stream_id in streams)
        )

        assert len(results) == 3
        for outcome in results:
            # Nothing may be dropped, and ids must round-trip.
            assert outcome["processed_frames"] == outcome["total_frames"]
            assert outcome["stream_id"] in streams
|
||||
|
||||
|
||||
class TestStreamingResilience:
    """Test streaming pipeline resilience."""

    @pytest.mark.asyncio
    async def test_client_disconnection_handling_should_fail_initially(self):
        """Broadcasts track the live client set as clients come and go."""
        manager = MockWebSocketManager()

        # Register five clients.
        for client_id in (f"client_{index:03d}" for index in range(5)):
            await manager.add_client(client_id, MagicMock())

        frame = StreamFrame(
            frame_id="disconnect_test_frame",
            timestamp=datetime.utcnow(),
            router_id="router_001",
            pose_data={},
            processing_time_ms=30.0,
            quality_score=0.9,
        )

        # All five clients receive the first broadcast.
        results = await manager.broadcast_frame(frame)
        assert len(results) == 5

        # Two clients drop off.
        await manager.remove_client("client_001")
        await manager.remove_client("client_003")

        # Subsequent broadcasts reach only the survivors.
        results = await manager.broadcast_frame(frame)
        assert len(results) == 3

        stats = manager.get_client_stats()
        assert stats["connected_clients"] == 3

    @pytest.mark.asyncio
    async def test_memory_pressure_handling_should_fail_initially(self):
        """An undersized buffer sheds load instead of growing unboundedly."""
        # A tiny buffer simulates memory pressure.
        tiny = MockStreamBuffer(max_size=5)

        produced = 0
        accepted = 0
        for index in range(20):
            frame = StreamFrame(
                frame_id=f"memory_pressure_frame_{index}",
                timestamp=datetime.utcnow(),
                router_id="router_001",
                pose_data={},
                processing_time_ms=25.0,
                quality_score=0.85,
            )
            produced += 1
            if await tiny.put_frame(frame):
                accepted += 1

        # Every attempt is counted; the overflow shows up as drops.
        stats = tiny.get_stats()
        assert stats["total_frames"] == produced
        assert stats["dropped_frames"] > 0  # some frames must be dropped
        assert accepted <= tiny.max_size

        # With 20 frames into a 5-slot buffer, most must be dropped.
        assert stats["drop_rate"] > 0.5
|
||||
# ===== tests/integration/test_websocket_streaming.py (new file, 419 lines) =====
|
||||
"""
|
||||
Integration tests for WebSocket streaming functionality.
|
||||
|
||||
Tests WebSocket connections, message handling, and real-time data streaming.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import websockets
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class MockWebSocket:
    """Mock WebSocket for testing.

    Records outbound traffic in ``messages_sent`` and serves queued inbound
    text from ``messages_received``, simulating a disconnect once the inbound
    queue runs dry.
    """

    def __init__(self):
        self.messages_sent = []       # everything passed to send_json/send_text
        self.messages_received = []   # FIFO of texts receive_text will return
        self.closed = False
        self.accept_called = False

    async def accept(self):
        """Mock accept method."""
        self.accept_called = True

    async def send_json(self, data: Dict[str, Any]):
        """Mock send_json method."""
        self.messages_sent.append(data)

    async def send_text(self, text: str):
        """Mock send_text method."""
        self.messages_sent.append(text)

    async def receive_text(self) -> str:
        """Return the next queued message, or simulate a client disconnect."""
        if not self.messages_received:
            # No more queued input: behave like a dropped connection.
            from fastapi import WebSocketDisconnect
            raise WebSocketDisconnect()
        return self.messages_received.pop(0)

    async def close(self):
        """Mock close method."""
        self.closed = True

    def add_received_message(self, message: str):
        """Queue a message for a later receive_text call."""
        self.messages_received.append(message)
|
||||
|
||||
|
||||
class TestWebSocketStreaming:
    """Integration tests for WebSocket streaming."""

    @pytest.fixture
    def mock_websocket(self):
        """Create mock WebSocket."""
        return MockWebSocket()

    @pytest.fixture
    def mock_connection_manager(self):
        """Mock connection manager with canned return values."""
        manager = AsyncMock()
        manager.connect.return_value = "client-001"
        manager.disconnect.return_value = True
        manager.get_connection_stats.return_value = {
            "total_clients": 1,
            "active_streams": ["pose"],
        }
        manager.broadcast.return_value = 1
        return manager

    @pytest.fixture
    def mock_stream_service(self):
        """Mock stream service reporting an active, hour-old stream."""
        service = AsyncMock()
        service.get_status.return_value = {
            "is_active": True,
            "active_streams": [],
            "uptime_seconds": 3600.0,
        }
        service.is_active.return_value = True
        service.start.return_value = None
        service.stop.return_value = None
        return service

    @pytest.mark.asyncio
    async def test_websocket_pose_connection_should_fail_initially(self, mock_websocket, mock_connection_manager):
        """Test WebSocket pose connection establishment - should fail initially."""
        # Connection parameters a real client would supply.
        zone_ids = "zone1,zone2"
        min_confidence = 0.7
        max_fps = 30

        async def mock_websocket_handler(websocket, zone_ids, min_confidence, max_fps):
            """Accept, register, and confirm a pose-stream connection."""
            await websocket.accept()

            # Comma-separated zone filter -> cleaned list.
            zone_list = [item.strip() for item in zone_ids.split(",") if item.strip()]

            client_id = await mock_connection_manager.connect(
                websocket=websocket,
                stream_type="pose",
                zone_ids=zone_list,
                min_confidence=min_confidence,
                max_fps=max_fps,
            )

            # Confirm the connection back to the client.
            await websocket.send_json({
                "type": "connection_established",
                "client_id": client_id,
                "timestamp": datetime.utcnow().isoformat(),
                "config": {
                    "zone_ids": zone_list,
                    "min_confidence": min_confidence,
                    "max_fps": max_fps,
                },
            })
            return client_id

        client_id = await mock_websocket_handler(mock_websocket, zone_ids, min_confidence, max_fps)

        assert mock_websocket.accept_called
        assert len(mock_websocket.messages_sent) == 1
        confirmation = mock_websocket.messages_sent[0]
        assert confirmation["type"] == "connection_established"
        assert confirmation["client_id"] == "client-001"
        assert "config" in confirmation

    @pytest.mark.asyncio
    async def test_websocket_message_handling_should_fail_initially(self, mock_websocket):
        """Test WebSocket message handling - should fail initially."""

        async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket):
            """Dispatch one inbound message by its 'type' field."""
            message_type = data.get("type")

            if message_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": datetime.utcnow().isoformat(),
                })
            elif message_type == "update_config":
                await websocket.send_json({
                    "type": "config_updated",
                    "timestamp": datetime.utcnow().isoformat(),
                    "config": data.get("config", {}),
                })
            else:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Unknown message type: {message_type}",
                })

        # A ping must be answered with a pong.
        await handle_websocket_message("client-001", {"type": "ping"}, mock_websocket)
        assert len(mock_websocket.messages_sent) == 1
        assert mock_websocket.messages_sent[0]["type"] == "pong"

        # A config update must be echoed back as confirmed.
        mock_websocket.messages_sent.clear()
        await handle_websocket_message(
            "client-001",
            {"type": "update_config", "config": {"min_confidence": 0.8, "max_fps": 15}},
            mock_websocket,
        )
        assert len(mock_websocket.messages_sent) == 1
        assert mock_websocket.messages_sent[0]["type"] == "config_updated"
        assert mock_websocket.messages_sent[0]["config"]["min_confidence"] == 0.8

    @pytest.mark.asyncio
    async def test_websocket_events_stream_should_fail_initially(self, mock_websocket, mock_connection_manager):
        """Test WebSocket events stream - should fail initially."""

        async def mock_events_handler(websocket, event_types, zone_ids):
            """Accept, register, and confirm an event-stream connection."""
            await websocket.accept()

            # Empty/None filter strings mean "no filtering".
            event_list = (
                [item.strip() for item in event_types.split(",") if item.strip()]
                if event_types else None
            )
            zone_list = (
                [item.strip() for item in zone_ids.split(",") if item.strip()]
                if zone_ids else None
            )

            client_id = await mock_connection_manager.connect(
                websocket=websocket,
                stream_type="events",
                zone_ids=zone_list,
                event_types=event_list,
            )

            # Confirm the connection back to the client.
            await websocket.send_json({
                "type": "connection_established",
                "client_id": client_id,
                "timestamp": datetime.utcnow().isoformat(),
                "config": {
                    "event_types": event_list,
                    "zone_ids": zone_list,
                },
            })
            return client_id

        client_id = await mock_events_handler(mock_websocket, "fall_detection,intrusion", "zone1")

        assert mock_websocket.accept_called
        assert len(mock_websocket.messages_sent) == 1
        confirmation = mock_websocket.messages_sent[0]
        assert confirmation["type"] == "connection_established"
        assert confirmation["config"]["event_types"] == ["fall_detection", "intrusion"]

    @pytest.mark.asyncio
    async def test_websocket_disconnect_handling_should_fail_initially(self, mock_websocket, mock_connection_manager):
        """Test WebSocket disconnect handling - should fail initially."""
        client_id = "client-001"

        disconnect_result = await mock_connection_manager.disconnect(client_id)

        assert disconnect_result is True
        mock_connection_manager.disconnect.assert_called_once_with(client_id)
|
||||
|
||||
|
||||
class TestWebSocketConnectionManager:
    """Test WebSocket connection management."""

    @pytest.fixture
    def mock_websocket(self):
        """Create mock WebSocket.

        FIX: this fixture was previously defined only on the sibling class
        TestWebSocketStreaming; pytest does not share class-scoped fixtures
        across sibling classes, so the connect/disconnect tests below would
        error with "fixture 'mock_websocket' not found".
        """
        return MockWebSocket()

    @pytest.fixture
    def connection_manager(self):
        """Create connection manager for testing."""

        class MockConnectionManager:
            """In-memory stand-in for the real connection manager."""

            def __init__(self):
                self.connections = {}    # client_id -> connection record
                self.client_counter = 0  # monotonically increasing id source

            async def connect(self, websocket, stream_type, zone_ids=None, **kwargs):
                """Register a client; returns its generated id."""
                self.client_counter += 1
                client_id = f"client-{self.client_counter:03d}"
                self.connections[client_id] = {
                    "websocket": websocket,
                    "stream_type": stream_type,
                    "zone_ids": zone_ids or [],
                    "connected_at": datetime.utcnow(),
                    **kwargs,
                }
                return client_id

            async def disconnect(self, client_id):
                """Remove a client; True if it was connected."""
                if client_id in self.connections:
                    del self.connections[client_id]
                    return True
                return False

            async def get_connected_clients(self):
                """Ids of all currently connected clients."""
                return list(self.connections.keys())

            async def get_connection_stats(self):
                """Aggregate counts over the live connection table."""
                return {
                    "total_clients": len(self.connections),
                    "active_streams": list(
                        set(conn["stream_type"] for conn in self.connections.values())
                    ),
                }

            async def broadcast(self, data, stream_type=None, zone_ids=None):
                """Count clients matching the stream-type and zone filters."""
                sent_count = 0
                for client_id, conn in self.connections.items():
                    if stream_type and conn["stream_type"] != stream_type:
                        continue
                    if zone_ids and not any(zone in conn["zone_ids"] for zone in zone_ids):
                        continue

                    # Mock sending data
                    sent_count += 1

                return sent_count

        return MockConnectionManager()

    @pytest.mark.asyncio
    async def test_connection_manager_connect_should_fail_initially(self, connection_manager, mock_websocket):
        """Test connection manager connect functionality - should fail initially."""
        client_id = await connection_manager.connect(
            websocket=mock_websocket,
            stream_type="pose",
            zone_ids=["zone1", "zone2"],
            min_confidence=0.7,
        )

        # First client gets the first generated id and a full record.
        assert client_id == "client-001"
        assert client_id in connection_manager.connections
        assert connection_manager.connections[client_id]["stream_type"] == "pose"
        assert connection_manager.connections[client_id]["zone_ids"] == ["zone1", "zone2"]

    @pytest.mark.asyncio
    async def test_connection_manager_disconnect_should_fail_initially(self, connection_manager, mock_websocket):
        """Test connection manager disconnect functionality - should fail initially."""
        # Connect first
        client_id = await connection_manager.connect(
            websocket=mock_websocket,
            stream_type="pose",
        )

        # Disconnect removes the record and reports success.
        result = await connection_manager.disconnect(client_id)
        assert result is True
        assert client_id not in connection_manager.connections

    @pytest.mark.asyncio
    async def test_connection_manager_broadcast_should_fail_initially(self, connection_manager):
        """Test connection manager broadcast functionality - should fail initially."""
        # Connect two clients on different streams and zones.
        ws1 = MockWebSocket()
        ws2 = MockWebSocket()

        client1 = await connection_manager.connect(ws1, "pose", zone_ids=["zone1"])
        client2 = await connection_manager.connect(ws2, "events", zone_ids=["zone2"])

        # Stream-type filter: only the pose client matches.
        sent_count = await connection_manager.broadcast(
            data={"type": "pose_data", "data": {}},
            stream_type="pose",
        )
        assert sent_count == 1

        # Zone filter: only the zone1 client matches.
        sent_count = await connection_manager.broadcast(
            data={"type": "zone_event", "data": {}},
            zone_ids=["zone1"],
        )
        assert sent_count == 1
|
||||
|
||||
|
||||
class TestWebSocketPerformance:
    """Test WebSocket performance characteristics."""

    @pytest.mark.asyncio
    async def test_multiple_concurrent_connections_should_fail_initially(self):
        """Ten simultaneous connections all complete their handshakes."""
        connection_count = 10
        sockets = [MockWebSocket() for _ in range(connection_count)]

        async def simulate_connection(websocket, client_id):
            """Accept and confirm a single client connection."""
            await websocket.accept()
            await websocket.send_json({
                "type": "connection_established",
                "client_id": client_id,
            })
            return True

        # Run every handshake concurrently.
        outcomes = await asyncio.gather(*(
            simulate_connection(ws, f"client-{index:03d}")
            for index, ws in enumerate(sockets)
        ))

        assert len(outcomes) == connection_count
        assert all(outcomes)
        assert all(ws.accept_called for ws in sockets)

    @pytest.mark.asyncio
    async def test_websocket_message_throughput_should_fail_initially(self):
        """The mock socket absorbs 100 messages in well under a second."""
        socket = MockWebSocket()
        message_count = 100

        # Blast a burst of high-frequency messages and time it.
        started = datetime.utcnow()
        for index in range(message_count):
            await socket.send_json({
                "type": "pose_data",
                "frame_id": f"frame-{index:04d}",
                "timestamp": datetime.utcnow().isoformat(),
            })
        elapsed = (datetime.utcnow() - started).total_seconds()

        assert len(socket.messages_sent) == message_count
        assert elapsed < 1.0  # Should handle 100 messages in under 1 second

        # Guard against a zero-duration clock reading before dividing.
        throughput = message_count / elapsed if elapsed > 0 else float("inf")
        assert throughput > 100  # Should handle at least 100 messages per second
|
||||
# ===== tests/mocks/hardware_mocks.py (new file, 712 lines) =====
|
||||
"""
|
||||
Hardware simulation mocks for testing.
|
||||
|
||||
Provides realistic hardware behavior simulation for routers and sensors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Callable, AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import json
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RouterStatus(Enum):
    """Lifecycle states a mock router can be in."""

    OFFLINE = "offline"          # not connected; the only state connect() accepts
    CONNECTING = "connecting"    # transient state while a connection is attempted
    ONLINE = "online"            # connected and able to stream
    ERROR = "error"              # a simulated failure occurred
    MAINTENANCE = "maintenance"  # administratively out of service
|
||||
|
||||
|
||||
class SignalQuality(Enum):
    """Coarse signal-quality buckets reported by the mock hardware."""

    POOR = "poor"
    FAIR = "fair"
    GOOD = "good"
    EXCELLENT = "excellent"
|
||||
|
||||
|
||||
@dataclass
class RouterConfig:
    """Static radio and placement parameters for a mock router."""

    router_id: str
    frequency: float = 5.8e9   # carrier frequency, Hz (5.8 GHz band)
    bandwidth: float = 80e6    # channel width, Hz (80 MHz)
    num_antennas: int = 4
    num_subcarriers: int = 64
    tx_power: float = 20.0     # transmit power, dBm
    # default_factory keeps each instance's location dict independent.
    location: Dict[str, float] = field(default_factory=lambda: {"x": 0, "y": 0, "z": 0})
    firmware_version: str = "1.2.3"
|
||||
|
||||
|
||||
class MockWiFiRouter:
    """Mock WiFi router with CSI capabilities.

    Simulates a CSI-capable access point for tests: a connect/disconnect
    lifecycle, a once-per-second heartbeat task, and a CSI streaming task
    that generates synthetic amplitude/phase samples and fans them out to
    registered callbacks.  Failure behavior (5% connect failures, packet
    corruption at ``error_rate``, overheating above 75 C) is randomized to
    exercise error paths.
    """

    def __init__(self, config: RouterConfig):
        self.config = config
        self.status = RouterStatus.OFFLINE
        self.signal_quality = SignalQuality.GOOD
        self.is_streaming = False
        self.connected_devices = []   # devices associated with this AP (count reported by get_status)
        self.csi_data_buffer = []     # most recent CSI samples; capped at 1000 entries
        self.error_rate = 0.01  # 1% error rate
        self.latency_ms = 5.0
        self.throughput_mbps = 100.0
        self.temperature_celsius = 45.0
        self.uptime_seconds = 0       # incremented once per heartbeat tick (~1 s)
        self.last_heartbeat = None    # datetime of most recent heartbeat, None until connected
        # Event name -> list of callables (sync or async) invoked by _notify_* helpers.
        self.callbacks = {
            "on_status_change": [],
            "on_csi_data": [],
            "on_error": []
        }
        self._streaming_task = None   # asyncio.Task running _csi_streaming_loop
        self._heartbeat_task = None   # asyncio.Task running _heartbeat_loop

    async def connect(self) -> bool:
        """Connect to router.

        Returns False if the router is not OFFLINE or if the simulated
        handshake randomly fails (5% chance, status becomes ERROR).  On
        success the router is ONLINE and the heartbeat task is started.
        """
        if self.status != RouterStatus.OFFLINE:
            return False

        self.status = RouterStatus.CONNECTING
        await self._notify_status_change()

        # Simulate connection delay
        await asyncio.sleep(0.1)

        # Simulate occasional connection failures
        if random.random() < 0.05:  # 5% failure rate
            self.status = RouterStatus.ERROR
            await self._notify_error("Connection failed")
            return False

        self.status = RouterStatus.ONLINE
        self.last_heartbeat = datetime.utcnow()
        await self._notify_status_change()

        # Start heartbeat
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

        return True

    async def disconnect(self):
        """Disconnect from router.

        Stops CSI streaming first, then cancels the heartbeat task and
        waits for it to finish before reporting OFFLINE, so no background
        work survives the disconnect.
        """
        if self.status == RouterStatus.OFFLINE:
            return

        # Stop streaming if active
        if self.is_streaming:
            await self.stop_csi_streaming()

        # Stop heartbeat
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        self.status = RouterStatus.OFFLINE
        await self._notify_status_change()

    async def start_csi_streaming(self, sample_rate: int = 1000) -> bool:
        """Start CSI data streaming.

        Requires the router to be ONLINE and not already streaming.
        ``sample_rate`` is the target number of samples per second.
        """
        if self.status != RouterStatus.ONLINE:
            return False

        if self.is_streaming:
            return False

        self.is_streaming = True
        self._streaming_task = asyncio.create_task(self._csi_streaming_loop(sample_rate))

        return True

    async def stop_csi_streaming(self):
        """Stop CSI data streaming.

        Clears the streaming flag, then cancels and awaits the streaming
        task so the loop has fully exited when this returns.
        """
        if not self.is_streaming:
            return

        self.is_streaming = False

        if self._streaming_task:
            self._streaming_task.cancel()
            try:
                await self._streaming_task
            except asyncio.CancelledError:
                pass

    async def _csi_streaming_loop(self, sample_rate: int):
        """CSI data streaming loop.

        Generates one sample per iteration, buffers it (keeping only the
        last 1000), notifies callbacks, then sleeps roughly one sample
        interval with +/-10% jitter.  Exits when ``is_streaming`` is
        cleared or the task is cancelled.
        """
        interval = 1.0 / sample_rate

        try:
            while self.is_streaming:
                # Generate CSI data
                csi_data = self._generate_csi_sample()

                # Add to buffer
                self.csi_data_buffer.append(csi_data)

                # Keep buffer size manageable
                if len(self.csi_data_buffer) > 1000:
                    self.csi_data_buffer = self.csi_data_buffer[-1000:]

                # Notify callbacks
                await self._notify_csi_data(csi_data)

                # Simulate processing delay and jitter
                actual_interval = interval * random.uniform(0.9, 1.1)
                await asyncio.sleep(actual_interval)

        except asyncio.CancelledError:
            pass

    async def _heartbeat_loop(self):
        """Heartbeat loop to maintain connection.

        Ticks once per second while ONLINE: refreshes ``last_heartbeat``,
        bumps uptime, drifts the temperature within [30, 80] C, and
        degrades signal quality with an error notification above 75 C.
        """
        try:
            while self.status == RouterStatus.ONLINE:
                self.last_heartbeat = datetime.utcnow()
                self.uptime_seconds += 1

                # Simulate temperature variations
                self.temperature_celsius += random.uniform(-1, 1)
                self.temperature_celsius = max(30, min(80, self.temperature_celsius))

                # Check for overheating
                if self.temperature_celsius > 75:
                    self.signal_quality = SignalQuality.POOR
                    await self._notify_error("High temperature warning")

                await asyncio.sleep(1.0)

        except asyncio.CancelledError:
            pass

    def _generate_csi_sample(self) -> Dict[str, Any]:
        """Generate realistic CSI sample.

        Builds (num_antennas x num_subcarriers) amplitude/phase matrices,
        adds Gaussian noise scaled by the current signal quality, clips
        amplitude to [0, 1] and wraps phase to [-pi, pi), and with
        probability ``error_rate`` zeroes ~10% of entries to mimic packet
        corruption.  Returns a JSON-serializable dict.
        """
        # Base amplitude and phase matrices
        amplitude = np.random.uniform(0.2, 0.8, (self.config.num_antennas, self.config.num_subcarriers))
        phase = np.random.uniform(-np.pi, np.pi, (self.config.num_antennas, self.config.num_subcarriers))

        # Add signal quality effects: worse quality -> larger noise std-dev.
        if self.signal_quality == SignalQuality.POOR:
            noise_level = 0.3
        elif self.signal_quality == SignalQuality.FAIR:
            noise_level = 0.2
        elif self.signal_quality == SignalQuality.GOOD:
            noise_level = 0.1
        else:  # EXCELLENT
            noise_level = 0.05

        # Add noise
        amplitude += np.random.normal(0, noise_level, amplitude.shape)
        phase += np.random.normal(0, noise_level * np.pi, phase.shape)

        # Clip values
        amplitude = np.clip(amplitude, 0, 1)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi

        # Simulate packet errors
        if random.random() < self.error_rate:
            # Corrupt some data
            corruption_mask = np.random.random(amplitude.shape) < 0.1
            amplitude[corruption_mask] = 0
            phase[corruption_mask] = 0

        return {
            "timestamp": datetime.utcnow().isoformat(),
            "router_id": self.config.router_id,
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.config.frequency,
            "bandwidth": self.config.bandwidth,
            "num_antennas": self.config.num_antennas,
            "num_subcarriers": self.config.num_subcarriers,
            "signal_quality": self.signal_quality.value,
            "temperature": self.temperature_celsius,
            "tx_power": self.config.tx_power,
            # Buffer length doubles as a sequence counter; resets if the buffer is cleared.
            "sequence_number": len(self.csi_data_buffer)
        }

    def register_callback(self, event: str, callback: Callable):
        """Register event callback.

        ``event`` must be one of the keys in ``self.callbacks``; unknown
        events are silently ignored.
        """
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def unregister_callback(self, event: str, callback: Callable):
        """Unregister event callback (no-op if not registered)."""
        if event in self.callbacks and callback in self.callbacks[event]:
            self.callbacks[event].remove(callback)

    async def _notify_status_change(self):
        """Notify status change callbacks.

        Supports both sync and async callbacks; callback exceptions are
        swallowed so one bad listener cannot break the router simulation.
        """
        for callback in self.callbacks["on_status_change"]:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(self.status)
                else:
                    callback(self.status)
            except Exception:
                pass  # Ignore callback errors

    async def _notify_csi_data(self, data: Dict[str, Any]):
        """Notify CSI data callbacks (sync or async; errors swallowed)."""
        for callback in self.callbacks["on_csi_data"]:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(data)
                else:
                    callback(data)
            except Exception:
                pass

    async def _notify_error(self, error_message: str):
        """Notify error callbacks (sync or async; errors swallowed)."""
        for callback in self.callbacks["on_error"]:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(error_message)
                else:
                    callback(error_message)
            except Exception:
                pass

    def get_status(self) -> Dict[str, Any]:
        """Get router status information as a JSON-serializable snapshot."""
        return {
            "router_id": self.config.router_id,
            "status": self.status.value,
            "signal_quality": self.signal_quality.value,
            "is_streaming": self.is_streaming,
            "connected_devices": len(self.connected_devices),
            "temperature": self.temperature_celsius,
            "uptime_seconds": self.uptime_seconds,
            "last_heartbeat": self.last_heartbeat.isoformat() if self.last_heartbeat else None,
            "error_rate": self.error_rate,
            "latency_ms": self.latency_ms,
            "throughput_mbps": self.throughput_mbps,
            "firmware_version": self.config.firmware_version,
            "location": self.config.location
        }

    def set_signal_quality(self, quality: SignalQuality):
        """Set signal quality for testing."""
        self.signal_quality = quality

    def set_error_rate(self, error_rate: float):
        """Set error rate for testing, clamped to [0, 1]."""
        self.error_rate = max(0, min(1, error_rate))

    def simulate_interference(self, duration_seconds: float = 5.0):
        """Simulate interference for testing.

        Temporarily drops signal quality to POOR, restoring the previous
        quality after ``duration_seconds``.

        NOTE(review): must be called with a running event loop (uses
        asyncio.create_task from sync code), and the task object is not
        retained, so it could be garbage-collected before completing —
        TODO confirm acceptable for tests.
        """
        async def interference_task():
            original_quality = self.signal_quality
            self.signal_quality = SignalQuality.POOR
            await asyncio.sleep(duration_seconds)
            self.signal_quality = original_quality

        asyncio.create_task(interference_task())

    def get_csi_buffer(self) -> List[Dict[str, Any]]:
        """Get a shallow copy of the CSI data buffer."""
        return self.csi_data_buffer.copy()

    def clear_csi_buffer(self):
        """Clear CSI data buffer (also resets the derived sequence numbers)."""
        self.csi_data_buffer.clear()
|
||||
|
||||
|
||||
class MockRouterNetwork:
    """Mock network of WiFi routers.

    Owns a collection of MockWiFiRouter instances keyed by router_id and
    provides fleet-wide operations (connect/disconnect all, start/stop
    streaming), plus test hooks for network partitions and interference
    sources that degrade nearby routers' signal quality.
    """

    def __init__(self):
        self.routers = {}               # router_id -> MockWiFiRouter
        self.network_topology = {}      # reserved; not populated by any method in this class
        self.interference_sources = []  # dicts created by add_interference_source
        # Network-level event name -> list of callbacks.
        self.global_callbacks = {
            "on_router_added": [],
            "on_router_removed": [],
            "on_network_event": []
        }

    def add_router(self, config: RouterConfig) -> MockWiFiRouter:
        """Add router to network.

        Creates the router, wires its status/error events into the
        network's event stream, and notifies on_router_added listeners.

        Raises:
            ValueError: if a router with the same id already exists.
        """
        if config.router_id in self.routers:
            raise ValueError(f"Router {config.router_id} already exists")

        router = MockWiFiRouter(config)
        self.routers[config.router_id] = router

        # Register for router events
        router.register_callback("on_status_change", self._on_router_status_change)
        router.register_callback("on_error", self._on_router_error)

        # Notify callbacks
        for callback in self.global_callbacks["on_router_added"]:
            callback(router)

        return router

    def remove_router(self, router_id: str) -> bool:
        """Remove router from network.

        Returns False if the id is unknown.  NOTE(review): disconnect is
        fired via asyncio.create_task from sync code — requires a running
        event loop, and the task reference is not retained; removal does
        not wait for the disconnect to complete.
        """
        if router_id not in self.routers:
            return False

        router = self.routers[router_id]

        # Disconnect if connected
        if router.status != RouterStatus.OFFLINE:
            asyncio.create_task(router.disconnect())

        del self.routers[router_id]

        # Notify callbacks
        for callback in self.global_callbacks["on_router_removed"]:
            callback(router_id)

        return True

    def get_router(self, router_id: str) -> Optional[MockWiFiRouter]:
        """Get router by ID, or None if not present."""
        return self.routers.get(router_id)

    def get_all_routers(self) -> Dict[str, MockWiFiRouter]:
        """Get a shallow copy of the router mapping."""
        return self.routers.copy()

    async def connect_all_routers(self) -> Dict[str, bool]:
        """Connect all routers concurrently.

        Returns a router_id -> success mapping; a router that raises
        during connect is recorded as False rather than propagating.
        """
        results = {}
        tasks = []

        # Launch every connect first so they run concurrently...
        for router_id, router in self.routers.items():
            task = asyncio.create_task(router.connect())
            tasks.append((router_id, task))

        # ...then collect each result in turn.
        for router_id, task in tasks:
            try:
                result = await task
                results[router_id] = result
            except Exception:
                results[router_id] = False

        return results

    async def disconnect_all_routers(self):
        """Disconnect all non-offline routers concurrently."""
        tasks = []

        for router in self.routers.values():
            if router.status != RouterStatus.OFFLINE:
                task = asyncio.create_task(router.disconnect())
                tasks.append(task)

        if tasks:
            # return_exceptions=True: one failing disconnect must not abort the rest.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def start_all_streaming(self, sample_rate: int = 1000) -> Dict[str, bool]:
        """Start CSI streaming on all routers.

        Routers that are not ONLINE are skipped and recorded as False.
        Streaming is started sequentially, one router at a time.
        """
        results = {}

        for router_id, router in self.routers.items():
            if router.status == RouterStatus.ONLINE:
                result = await router.start_csi_streaming(sample_rate)
                results[router_id] = result
            else:
                results[router_id] = False

        return results

    async def stop_all_streaming(self):
        """Stop CSI streaming on all streaming routers concurrently."""
        tasks = []

        for router in self.routers.values():
            if router.is_streaming:
                task = asyncio.create_task(router.stop_csi_streaming())
                tasks.append(task)

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def get_network_status(self) -> Dict[str, Any]:
        """Get overall network status.

        network_health is the fraction of routers ONLINE (0.0 for an
        empty network thanks to the max(total, 1) guard).
        """
        total_routers = len(self.routers)
        online_routers = sum(1 for r in self.routers.values() if r.status == RouterStatus.ONLINE)
        streaming_routers = sum(1 for r in self.routers.values() if r.is_streaming)

        return {
            "total_routers": total_routers,
            "online_routers": online_routers,
            "streaming_routers": streaming_routers,
            "network_health": online_routers / max(total_routers, 1),
            "interference_sources": len(self.interference_sources),
            "timestamp": datetime.utcnow().isoformat()
        }

    def simulate_network_partition(self, router_ids: List[str], duration_seconds: float = 10.0):
        """Simulate network partition for testing.

        Forces the named ONLINE routers into ERROR, waits, then attempts
        to reconnect any still in ERROR.  Unknown ids are ignored.
        NOTE(review): fire-and-forget task (not retained) — requires a
        running event loop; TODO confirm acceptable for tests.
        """
        async def partition_task():
            # Disconnect specified routers
            affected_routers = [self.routers[rid] for rid in router_ids if rid in self.routers]

            for router in affected_routers:
                if router.status == RouterStatus.ONLINE:
                    router.status = RouterStatus.ERROR
                    await router._notify_status_change()

            await asyncio.sleep(duration_seconds)

            # Reconnect routers
            for router in affected_routers:
                if router.status == RouterStatus.ERROR:
                    await router.connect()

        asyncio.create_task(partition_task())

    def add_interference_source(self, location: Dict[str, float], strength: float, frequency: float):
        """Add interference source.

        Registers the source and immediately degrades signal quality of
        routers within 50 m: POOR above strength 0.5, FAIR above 0.3.
        The degradation is applied once and never reverted here.
        """
        interference = {
            "id": f"interference_{len(self.interference_sources)}",
            "location": location,
            "strength": strength,
            "frequency": frequency,
            "active": True
        }

        self.interference_sources.append(interference)

        # Affect nearby routers
        for router in self.routers.values():
            distance = self._calculate_distance(router.config.location, location)
            if distance < 50:  # Within 50 meters
                if strength > 0.5:
                    router.set_signal_quality(SignalQuality.POOR)
                elif strength > 0.3:
                    router.set_signal_quality(SignalQuality.FAIR)

    def _calculate_distance(self, loc1: Dict[str, float], loc2: Dict[str, float]) -> float:
        """Euclidean distance between two {x, y, z} locations (missing keys default to 0)."""
        dx = loc1.get("x", 0) - loc2.get("x", 0)
        dy = loc1.get("y", 0) - loc2.get("y", 0)
        dz = loc1.get("z", 0) - loc2.get("z", 0)
        return np.sqrt(dx**2 + dy**2 + dz**2)

    async def _on_router_status_change(self, status: RouterStatus):
        """Forward a router status change to on_network_event listeners.

        NOTE(review): unlike MockWiFiRouter's notifiers, callbacks here
        are awaited unconditionally — registering a non-coroutine
        on_network_event callback would raise TypeError.
        """
        for callback in self.global_callbacks["on_network_event"]:
            await callback("router_status_change", {"status": status})

    async def _on_router_error(self, error_message: str):
        """Forward a router error to on_network_event listeners (awaited unconditionally)."""
        for callback in self.global_callbacks["on_network_event"]:
            await callback("router_error", {"error": error_message})

    def register_global_callback(self, event: str, callback: Callable):
        """Register global network callback (unknown events silently ignored)."""
        if event in self.global_callbacks:
            self.global_callbacks[event].append(callback)
|
||||
|
||||
|
||||
class MockSensorArray:
    """Mock sensor array for environmental monitoring.

    Simulates six co-located sensors (temperature, humidity, pressure,
    light, motion, sound).  A background loop generates readings that
    drift randomly within each sensor's configured range, stores them in
    a capped history, and fans them out to registered callbacks.
    """

    def __init__(self, sensor_id: str, location: Dict[str, float]):
        self.sensor_id = sensor_id
        self.location = location
        self.is_active = False      # monitoring-loop run flag
        # Sensor name -> current value, display unit, and allowed (min, max) range.
        self.sensors = {
            "temperature": {"value": 22.0, "unit": "celsius", "range": (15, 35)},
            "humidity": {"value": 45.0, "unit": "percent", "range": (30, 70)},
            "pressure": {"value": 1013.25, "unit": "hPa", "range": (980, 1050)},
            "light": {"value": 300.0, "unit": "lux", "range": (0, 1000)},
            "motion": {"value": False, "unit": "boolean", "range": (False, True)},
            "sound": {"value": 35.0, "unit": "dB", "range": (20, 80)}
        }
        self.reading_history = []   # most recent readings; capped at 1000
        self.callbacks = []         # sync or async callables invoked per reading

    async def start_monitoring(self, interval_seconds: float = 1.0):
        """Start sensor monitoring.

        Returns False if already active, True after launching the loop.
        NOTE(review): the created task is not retained — the loop stops
        only via the is_active flag (see stop_monitoring), and the task
        could in principle be garbage-collected; TODO confirm acceptable.
        """
        if self.is_active:
            return False

        self.is_active = True
        asyncio.create_task(self._monitoring_loop(interval_seconds))
        return True

    def stop_monitoring(self):
        """Stop sensor monitoring.

        Only clears the flag; the loop exits on its next wake-up rather
        than immediately.
        """
        self.is_active = False

    async def _monitoring_loop(self, interval: float):
        """Sensor monitoring loop.

        Each tick: generate a reading, append it to history (keeping the
        last 1000), notify callbacks (errors swallowed), then sleep
        ``interval`` seconds.  Exits when is_active is cleared or the
        task is cancelled.
        """
        try:
            while self.is_active:
                reading = self._generate_sensor_reading()
                self.reading_history.append(reading)

                # Keep history manageable
                if len(self.reading_history) > 1000:
                    self.reading_history = self.reading_history[-1000:]

                # Notify callbacks
                for callback in self.callbacks:
                    try:
                        if asyncio.iscoroutinefunction(callback):
                            await callback(reading)
                        else:
                            callback(reading)
                    except Exception:
                        pass

                await asyncio.sleep(interval)

        except asyncio.CancelledError:
            pass

    def _generate_sensor_reading(self) -> Dict[str, Any]:
        """Generate realistic sensor reading.

        Motion is a bare boolean (10% chance of True); every other sensor
        drifts by up to +/-10% of its range per call, clamped to range.
        Note the drift mutates ``self.sensors`` in place, so successive
        readings form a random walk.
        """
        reading = {
            "sensor_id": self.sensor_id,
            "timestamp": datetime.utcnow().isoformat(),
            "location": self.location,
            "readings": {}
        }

        for sensor_name, config in self.sensors.items():
            if sensor_name == "motion":
                # Motion detection with some randomness
                reading["readings"][sensor_name] = random.random() < 0.1  # 10% chance of motion
            else:
                # Continuous sensors with drift
                current_value = config["value"]
                min_val, max_val = config["range"]

                # Add small random drift
                drift = random.uniform(-0.1, 0.1) * (max_val - min_val)
                new_value = current_value + drift

                # Keep within range
                new_value = max(min_val, min(max_val, new_value))

                # Persist the drifted value so the next reading continues the walk.
                config["value"] = new_value
                reading["readings"][sensor_name] = {
                    "value": round(new_value, 2),
                    "unit": config["unit"]
                }

        return reading

    def register_callback(self, callback: Callable):
        """Register sensor callback (sync or async)."""
        self.callbacks.append(callback)

    def unregister_callback(self, callback: Callable):
        """Unregister sensor callback (no-op if not registered)."""
        if callback in self.callbacks:
            self.callbacks.remove(callback)

    def get_latest_reading(self) -> Optional[Dict[str, Any]]:
        """Get latest sensor reading, or None before the first tick."""
        return self.reading_history[-1] if self.reading_history else None

    def get_reading_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get up to ``limit`` most recent sensor readings (oldest first)."""
        return self.reading_history[-limit:]

    def simulate_event(self, event_type: str, duration_seconds: float = 5.0):
        """Simulate environmental event.

        Supported event types: "motion_detected", "temperature_spike"
        (+10 C, capped at 35), "loud_noise" (+20 dB, capped at 80).  The
        original value is restored after ``duration_seconds``; unknown
        types are silently ignored.  NOTE(review): fire-and-forget task
        (not retained) — requires a running event loop.
        """
        async def event_task():
            if event_type == "motion_detected":
                self.sensors["motion"]["value"] = True
                await asyncio.sleep(duration_seconds)
                self.sensors["motion"]["value"] = False

            elif event_type == "temperature_spike":
                original_temp = self.sensors["temperature"]["value"]
                self.sensors["temperature"]["value"] = min(35, original_temp + 10)
                await asyncio.sleep(duration_seconds)
                self.sensors["temperature"]["value"] = original_temp

            elif event_type == "loud_noise":
                original_sound = self.sensors["sound"]["value"]
                self.sensors["sound"]["value"] = min(80, original_sound + 20)
                await asyncio.sleep(duration_seconds)
                self.sensors["sound"]["value"] = original_sound

        asyncio.create_task(event_task())
|
||||
|
||||
|
||||
# Utility functions for creating test hardware setups
|
||||
def create_test_router_network(num_routers: int = 3) -> MockRouterNetwork:
    """Build a MockRouterNetwork populated with evenly spaced test routers.

    Routers are named router_000, router_001, ... and placed 10 m apart
    along the x-axis at ceiling height (z = 2.5 m).
    """
    network = MockRouterNetwork()
    for index in range(num_routers):
        network.add_router(
            RouterConfig(
                router_id=f"router_{index:03d}",
                location={"x": index * 10, "y": 0, "z": 2.5},
            )
        )
    return network
|
||||
|
||||
|
||||
def create_test_sensor_array(num_sensors: int = 2) -> List[MockSensorArray]:
    """Build a list of MockSensorArray instances for tests.

    Sensors are named sensor_000, sensor_001, ... spaced 5 m apart along
    the x-axis at y = 5, mounted 1 m above the floor.
    """
    return [
        MockSensorArray(
            sensor_id=f"sensor_{index:03d}",
            location={"x": index * 5, "y": 5, "z": 1.0},
        )
        for index in range(num_sensors)
    ]
|
||||
|
||||
|
||||
async def setup_test_hardware_environment() -> Dict[str, Any]:
    """Create and start a complete mock hardware environment.

    Builds a three-router network and two sensor arrays, connects every
    router, starts sensor monitoring, and returns the device handles plus
    per-device connect/start results so tests can assert on them.
    """
    router_network = create_test_router_network(3)
    sensor_arrays = create_test_sensor_array(2)

    # Bring the routers online first.
    router_results = await router_network.connect_all_routers()

    # Start all sensor monitors concurrently and collect their results.
    start_tasks = [
        asyncio.create_task(sensor.start_monitoring(1.0))
        for sensor in sensor_arrays
    ]
    sensor_results = await asyncio.gather(*start_tasks)

    return {
        "router_network": router_network,
        "sensor_arrays": sensor_arrays,
        "router_connection_results": router_results,
        "sensor_start_results": sensor_results,
        "setup_timestamp": datetime.utcnow().isoformat(),
    }
|
||||
|
||||
|
||||
async def teardown_test_hardware_environment(environment: Dict[str, Any]):
    """Shut down every device created by setup_test_hardware_environment.

    Sensors stop via a flag flip (synchronous); routers need an async
    teardown to cancel their streaming and heartbeat tasks.
    """
    for sensor_array in environment["sensor_arrays"]:
        sensor_array.stop_monitoring()

    await environment["router_network"].disconnect_all_routers()
|
||||
649
tests/performance/test_api_throughput.py
Normal file
649
tests/performance/test_api_throughput.py
Normal file
@@ -0,0 +1,649 @@
|
||||
"""
|
||||
Performance tests for API throughput and load testing.
|
||||
|
||||
Tests API endpoint performance under various load conditions.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import statistics
|
||||
|
||||
|
||||
class MockAPIServer:
    """Mock API server for load testing.

    Simulates per-endpoint processing latency, optional rate limiting,
    and canned responses, while tracking request counts, response times,
    concurrency high-water mark, and error counts for later analysis via
    get_performance_stats().
    """

    def __init__(self):
        self.request_count = 0
        self.response_times = []        # per-request latency in ms (successful requests only)
        self.error_count = 0            # 429s and 500s combined
        self.concurrent_requests = 0    # requests currently in flight
        self.max_concurrent = 0         # high-water mark of concurrent_requests
        self.is_running = False
        self.rate_limit_enabled = False
        self.rate_limit_per_second = 100
        # Arrival times (time.time()) of every request; used for RPS and rate limiting.
        # NOTE(review): grows without bound across a long run — reset_stats() is the
        # only trim; TODO confirm acceptable for test durations.
        self.request_timestamps = []

    async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
        """Handle API request.

        Returns a dict with "status" (200/429/500), either "data" or
        "error", and "response_time_ms".  Updates all performance
        counters; concurrency is decremented in the finally block so it
        stays accurate even on errors.
        """
        start_time = time.time()
        self.concurrent_requests += 1
        self.max_concurrent = max(self.max_concurrent, self.concurrent_requests)
        self.request_count += 1
        self.request_timestamps.append(start_time)

        try:
            # Check rate limiting: count arrivals within the trailing 1-second window.
            if self.rate_limit_enabled:
                recent_requests = [
                    ts for ts in self.request_timestamps
                    if start_time - ts <= 1.0
                ]
                if len(recent_requests) > self.rate_limit_per_second:
                    self.error_count += 1
                    # Rate-limited requests report a fixed nominal latency and are
                    # NOT added to response_times.
                    return {
                        "status": 429,
                        "error": "Rate limit exceeded",
                        "response_time_ms": 1.0
                    }

            # Simulate processing time based on endpoint
            processing_time = self._get_processing_time(endpoint, method)
            await asyncio.sleep(processing_time)

            # Generate response
            response = self._generate_response(endpoint, method, data)

            end_time = time.time()
            response_time = (end_time - start_time) * 1000
            self.response_times.append(response_time)

            return {
                "status": 200,
                "data": response,
                "response_time_ms": response_time
            }

        except Exception as e:
            self.error_count += 1
            return {
                "status": 500,
                "error": str(e),
                "response_time_ms": (time.time() - start_time) * 1000
            }
        finally:
            self.concurrent_requests -= 1

    def _get_processing_time(self, endpoint: str, method: str) -> float:
        """Get simulated processing time (seconds) for an endpoint.

        Unknown endpoints default to 10 ms; +/-20% uniform jitter is
        applied.  ``method`` is currently unused in the lookup.
        """
        processing_times = {
            "/health": 0.001,
            "/pose/detect": 0.05,
            "/pose/stream": 0.02,
            "/auth/login": 0.01,
            "/auth/refresh": 0.005,
            "/config": 0.003
        }

        base_time = processing_times.get(endpoint, 0.01)

        # Add some variance
        return base_time * np.random.uniform(0.8, 1.2)

    def _generate_response(self, endpoint: str, method: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a canned response payload for the endpoint.

        Unrecognized endpoints get a generic success message echoing the
        endpoint and method.  ``data`` is accepted but not used.
        """
        if endpoint == "/health":
            return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}

        elif endpoint == "/pose/detect":
            return {
                "persons": [
                    {
                        "person_id": "person_1",
                        "confidence": 0.85,
                        "bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
                        # 17 synthetic keypoints: [i, i, 0.9] for i in 0..16.
                        "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
                    }
                ],
                "processing_time_ms": 45.2,
                "model_version": "v1.0"
            }

        elif endpoint == "/auth/login":
            return {
                "access_token": "mock_access_token",
                "refresh_token": "mock_refresh_token",
                "expires_in": 3600
            }

        else:
            return {"message": "Success", "endpoint": endpoint, "method": method}

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics.

        Returns a zeroed-out stats dict if no successful responses were
        recorded (avoids division by zero / empty-sequence errors in the
        statistics calls below).
        """
        if not self.response_times:
            return {
                "total_requests": self.request_count,
                "error_count": self.error_count,
                "error_rate": 0,
                "avg_response_time_ms": 0,
                "median_response_time_ms": 0,
                "p95_response_time_ms": 0,
                "p99_response_time_ms": 0,
                "max_concurrent_requests": self.max_concurrent,
                "requests_per_second": 0
            }

        return {
            "total_requests": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / self.request_count,
            "avg_response_time_ms": statistics.mean(self.response_times),
            "median_response_time_ms": statistics.median(self.response_times),
            "p95_response_time_ms": np.percentile(self.response_times, 95),
            "p99_response_time_ms": np.percentile(self.response_times, 99),
            "max_concurrent_requests": self.max_concurrent,
            "requests_per_second": self._calculate_rps()
        }

    def _calculate_rps(self) -> float:
        """Calculate requests per second over the observed arrival span.

        Uses first-to-last arrival time; the max(duration, 0.001) guard
        avoids division by zero when all requests arrive at once.
        """
        if len(self.request_timestamps) < 2:
            return 0

        duration = self.request_timestamps[-1] - self.request_timestamps[0]
        return len(self.request_timestamps) / max(duration, 0.001)

    def enable_rate_limiting(self, requests_per_second: int):
        """Enable rate limiting at the given requests-per-second threshold."""
        self.rate_limit_enabled = True
        self.rate_limit_per_second = requests_per_second

    def reset_stats(self):
        """Reset all performance statistics and counters."""
        self.request_count = 0
        self.response_times = []
        self.error_count = 0
        self.concurrent_requests = 0
        self.max_concurrent = 0
        self.request_timestamps = []
|
||||
|
||||
|
||||
class TestAPIThroughput:
    """Test API throughput under various conditions.

    All tests drive a fresh MockAPIServer; the "_should_fail_initially"
    suffix marks these as TDD-style tests written before the
    implementation exists.
    """

    @pytest.fixture
    def api_server(self):
        """Create mock API server."""
        return MockAPIServer()

    @pytest.mark.asyncio
    async def test_single_request_performance_should_fail_initially(self, api_server):
        """Test single request performance - should fail initially.

        A lone /health request must succeed within 50 ms (wall clock) and
        the server must record exactly one request with no errors.
        """
        start_time = time.time()
        response = await api_server.handle_request("/health")
        end_time = time.time()

        response_time = (end_time - start_time) * 1000

        # This will fail initially
        assert response["status"] == 200
        assert response_time < 50  # Should respond within 50ms
        assert response["response_time_ms"] > 0

        stats = api_server.get_performance_stats()
        assert stats["total_requests"] == 1
        assert stats["error_count"] == 0

    @pytest.mark.asyncio
    async def test_concurrent_request_handling_should_fail_initially(self, api_server):
        """Test concurrent request handling - should fail initially.

        Ten concurrent /health requests must all succeed and finish
        within 200 ms total (i.e. they actually overlap rather than
        serialize).
        """
        # Send multiple concurrent requests
        concurrent_requests = 10
        tasks = []

        for i in range(concurrent_requests):
            task = asyncio.create_task(api_server.handle_request("/health"))
            tasks.append(task)

        start_time = time.time()
        responses = await asyncio.gather(*tasks)
        end_time = time.time()

        total_time = (end_time - start_time) * 1000

        # This will fail initially
        assert len(responses) == concurrent_requests
        assert all(r["status"] == 200 for r in responses)

        # All requests should complete within reasonable time
        assert total_time < 200  # Should complete within 200ms

        stats = api_server.get_performance_stats()
        assert stats["total_requests"] == concurrent_requests
        assert stats["max_concurrent_requests"] <= concurrent_requests

    @pytest.mark.asyncio
    async def test_sustained_load_performance_should_fail_initially(self, api_server):
        """Test sustained load performance - should fail initially.

        Sends serial /health requests at a nominal 50 RPS for 3 seconds
        and checks the measured RPS stays within 80% of target with <5%
        errors and <100 ms average latency.  Note the sender is serial,
        so per-request processing time eats into the achievable rate.
        """
        duration_seconds = 3
        target_rps = 50  # 50 requests per second

        async def send_requests():
            """Send requests at target rate."""
            interval = 1.0 / target_rps
            end_time = time.time() + duration_seconds

            while time.time() < end_time:
                await api_server.handle_request("/health")
                await asyncio.sleep(interval)

        start_time = time.time()
        await send_requests()
        actual_duration = time.time() - start_time

        stats = api_server.get_performance_stats()
        actual_rps = stats["requests_per_second"]

        # This will fail initially
        assert actual_rps >= target_rps * 0.8  # Within 80% of target
        assert stats["error_rate"] < 0.05  # Less than 5% error rate
        assert stats["avg_response_time_ms"] < 100  # Average response time under 100ms

    @pytest.mark.asyncio
    async def test_different_endpoint_performance_should_fail_initially(self, api_server):
        """Test different endpoint performance - should fail initially.

        Samples each endpoint 10 times and checks relative ordering
        (/health faster than /pose/detect) plus absolute bounds (<200 ms
        average, <500 ms max) using the server-reported latencies.
        """
        endpoints = [
            "/health",
            "/pose/detect",
            "/auth/login",
            "/config"
        ]

        results = {}

        for endpoint in endpoints:
            # Test each endpoint multiple times
            response_times = []

            for _ in range(10):
                response = await api_server.handle_request(endpoint)
                response_times.append(response["response_time_ms"])

            results[endpoint] = {
                "avg_response_time": statistics.mean(response_times),
                "min_response_time": min(response_times),
                "max_response_time": max(response_times)
            }

        # This will fail initially
        # Health endpoint should be fastest
        assert results["/health"]["avg_response_time"] < results["/pose/detect"]["avg_response_time"]

        # All endpoints should respond within reasonable time
        for endpoint, metrics in results.items():
            assert metrics["avg_response_time"] < 200  # Less than 200ms average
            assert metrics["max_response_time"] < 500  # Less than 500ms max

    @pytest.mark.asyncio
    async def test_rate_limiting_behavior_should_fail_initially(self, api_server):
        """Test rate limiting behavior - should fail initially.

        With a 10 RPS limit, 20 near-simultaneous requests must split
        into some 200s and some 429s that together account for every
        request, and the 429s must be counted as errors.
        """
        # Enable rate limiting
        api_server.enable_rate_limiting(requests_per_second=10)

        # Send requests faster than rate limit
        rapid_requests = 20
        tasks = []

        for i in range(rapid_requests):
            task = asyncio.create_task(api_server.handle_request("/health"))
            tasks.append(task)

        responses = await asyncio.gather(*tasks)

        # This will fail initially
        # Some requests should be rate limited
        success_responses = [r for r in responses if r["status"] == 200]
        rate_limited_responses = [r for r in responses if r["status"] == 429]

        assert len(success_responses) > 0
        assert len(rate_limited_responses) > 0
        assert len(success_responses) + len(rate_limited_responses) == rapid_requests

        stats = api_server.get_performance_stats()
        assert stats["error_count"] > 0  # Should have rate limit errors
|
||||
|
||||
|
||||
class TestAPILoadTesting:
    """Test API under heavy load conditions."""

    @pytest.fixture
    def load_test_server(self):
        """Create server for load testing."""
        return MockAPIServer()

    @pytest.mark.asyncio
    async def test_high_concurrency_load_should_fail_initially(self, load_test_server):
        """Fan out 50 simulated users, 5 requests each, and check throughput."""
        user_count = 50
        per_user = 5

        async def simulate_user(user_id: int):
            """One user: sequential health checks with a short think time."""
            replies = []
            for _ in range(per_user):
                replies.append(await load_test_server.handle_request("/health"))
                await asyncio.sleep(0.01)  # brief pause between requests
            return replies

        started = time.time()
        sessions = await asyncio.gather(
            *(simulate_user(uid) for uid in range(user_count))
        )
        elapsed = time.time() - started  # wall clock for the whole fan-out
        expected_total = user_count * per_user

        # This will fail initially
        # Every simulated user session must complete.
        assert len(sessions) == user_count

        stats = load_test_server.get_performance_stats()
        assert stats["total_requests"] == expected_total
        assert stats["error_rate"] < 0.1            # under 10% errors
        assert stats["requests_per_second"] > 100   # at least 100 RPS

    @pytest.mark.asyncio
    async def test_mixed_endpoint_load_should_fail_initially(self, load_test_server):
        """Drive a realistic traffic mix across four endpoints."""
        # Traffic distribution; the weights sum to 1.0.
        traffic_profile = [
            ("/health", 0.4),        # 40% health checks
            ("/pose/detect", 0.3),   # 30% pose detection
            ("/auth/login", 0.1),    # 10% authentication
            ("/config", 0.2),        # 20% configuration
        ]
        request_total = 100

        def pick_endpoint(sample: float) -> str:
            """Map a uniform [0, 1) draw onto the traffic distribution."""
            running = 0
            for path, weight in traffic_profile:
                running += weight
                if sample <= running:
                    return path
            return traffic_profile[-1][0]  # unreachable: weights sum to 1.0

        async def fire_mixed_traffic():
            """Launch all requests concurrently, endpoint chosen per draw."""
            pending = [
                asyncio.create_task(
                    load_test_server.handle_request(pick_endpoint(np.random.random()))
                )
                for _ in range(request_total)
            ]
            return await asyncio.gather(*pending)

        started = time.time()
        replies = await fire_mixed_traffic()
        elapsed = time.time() - started  # wall clock for the whole mix

        # This will fail initially
        assert len(replies) == request_total

        # At least 90% of the mixed traffic must succeed.
        ok_replies = [r for r in replies if r["status"] == 200]
        assert len(ok_replies) >= request_total * 0.9

        stats = load_test_server.get_performance_stats()
        assert stats["requests_per_second"] > 50     # at least 50 RPS
        assert stats["avg_response_time_ms"] < 150   # average under 150 ms

    @pytest.mark.asyncio
    async def test_stress_testing_should_fail_initially(self, load_test_server):
        """Step the load through five levels and require graceful degradation."""
        step_levels = [10, 25, 50, 100, 200]
        measurements = {}

        for level in step_levels:
            load_test_server.reset_stats()  # measure each level in isolation

            started = time.time()
            await asyncio.gather(
                *(load_test_server.handle_request("/health") for _ in range(level))
            )
            elapsed = time.time() - started

            stats = load_test_server.get_performance_stats()
            measurements[level] = {
                "duration": elapsed,
                "rps": stats["requests_per_second"],
                "error_rate": stats["error_rate"],
                "avg_response_time": stats["avg_response_time_ms"],
                "p95_response_time": stats["p95_response_time_ms"],
            }

        # This will fail initially
        # Degradation must stay graceful at every load level.
        for level in step_levels:
            assert measurements[level]["error_rate"] < 0.2        # under 20% errors
            assert measurements[level]["avg_response_time"] < 1000  # under 1 s average

        # Latency should not improve as the load grows 20x.
        assert measurements[10]["avg_response_time"] <= measurements[200]["avg_response_time"]

    @pytest.mark.asyncio
    async def test_memory_usage_under_load_should_fail_initially(self, load_test_server):
        """Hold ~100 RPS for five seconds and bound the RSS growth."""
        import psutil
        import os

        this_process = psutil.Process(os.getpid())
        rss_before = this_process.memory_info().rss

        run_for = 5     # seconds of sustained traffic
        pace_rps = 100  # target request rate

        async def drive_load():
            """Issue pose-detect requests at a fixed pace until time is up."""
            gap = 1.0 / pace_rps
            deadline = time.time() + run_for
            while time.time() < deadline:
                await load_test_server.handle_request("/pose/detect")
                await asyncio.sleep(gap)

        await drive_load()

        rss_after = this_process.memory_info().rss
        rss_growth = rss_after - rss_before

        # This will fail initially
        # RSS growth must stay under 100 MB for a short sustained run.
        assert rss_growth < 100 * 1024 * 1024

        stats = load_test_server.get_performance_stats()
        # Should complete at least 80% of the nominal request budget.
        assert stats["total_requests"] > run_for * pace_rps * 0.8
|
||||
|
||||
|
||||
class TestAPIPerformanceOptimization:
    """Test API performance optimization techniques."""

    @pytest.mark.asyncio
    async def test_response_caching_effect_should_fail_initially(self):
        """A cached GET must return faster than the original cold request."""

        class CachedAPIServer(MockAPIServer):
            """MockAPIServer variant with a naive per-(method, path) cache."""

            def __init__(self):
                super().__init__()
                self.cache = {}
                self.cache_hits = 0
                self.cache_misses = 0

            async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
                key = f"{method}:{endpoint}"

                hit = self.cache.get(key)
                if hit is not None:
                    self.cache_hits += 1
                    served = hit.copy()
                    served["response_time_ms"] = 1.0  # cache hits are near-instant
                    return served

                self.cache_misses += 1
                fresh = await super().handle_request(endpoint, method, data)

                # Only successful responses are worth caching.
                if fresh["status"] == 200:
                    self.cache[key] = fresh.copy()
                return fresh

        server = CachedAPIServer()

        cold = await server.handle_request("/health")  # miss: populates cache
        warm = await server.handle_request("/health")  # hit: served from cache

        # This will fail initially
        assert cold["status"] == 200
        assert warm["status"] == 200
        assert warm["response_time_ms"] < cold["response_time_ms"]
        assert server.cache_hits == 1
        assert server.cache_misses == 1

    @pytest.mark.asyncio
    async def test_connection_pooling_effect_should_fail_initially(self):
        """Pooled connections should absorb a burst larger than the pool."""

        class ConnectionPoolServer(MockAPIServer):
            """MockAPIServer variant that charges a setup cost per new connection."""

            def __init__(self, pool_size: int = 10):
                super().__init__()
                self.pool_size = pool_size
                self.active_connections = 0
                self.connection_overhead = 0.01  # 10 ms to open a connection

            async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
                # Pay the setup cost only while the pool is still filling.
                if self.active_connections < self.pool_size:
                    await asyncio.sleep(self.connection_overhead)
                    self.active_connections += 1

                try:
                    return await super().handle_request(endpoint, method, data)
                finally:
                    pass  # connection returns to the pool, never closed

        server = ConnectionPoolServer(pool_size=5)

        burst = 10  # deliberately larger than the pool
        started = time.time()
        replies = await asyncio.gather(
            *(server.handle_request("/health") for _ in range(burst))
        )
        wall_ms = (time.time() - started) * 1000

        # This will fail initially
        assert len(replies) == burst
        assert all(r["status"] == 200 for r in replies)

        # Reusing pooled connections should keep the whole burst fast.
        assert wall_ms < 500

    @pytest.mark.asyncio
    async def test_request_batching_performance_should_fail_initially(self):
        """One batched call should beat the same work done as individual calls."""

        class BatchingServer(MockAPIServer):
            """MockAPIServer variant that can process a list of requests at once."""

            def __init__(self):
                super().__init__()
                self.batch_size = 5
                self.pending_requests = []
                self.batch_processing = False

            async def handle_batch_request(self, requests: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
                """Handle a whole batch with a single fixed overhead."""
                await asyncio.sleep(0.01)  # one 10 ms overhead for the entire batch

                out = []
                for item in requests:
                    # Per-item work is modelled as half the standalone cost.
                    cost = self._get_processing_time(item["endpoint"], item["method"]) * 0.5
                    await asyncio.sleep(cost)

                    payload = self._generate_response(item["endpoint"], item["method"], item.get("data"))
                    out.append({
                        "status": 200,
                        "data": payload,
                        "response_time_ms": cost * 1000,
                    })
                return out

        server = BatchingServer()
        n = 5

        # Path A: five standalone requests issued concurrently.
        started = time.time()
        solo_replies = await asyncio.gather(
            *(server.handle_request("/health") for _ in range(n))
        )
        solo_ms = (time.time() - started) * 1000

        # Path B: the same five requests submitted as one batch.
        batch_spec = [
            {"endpoint": "/health", "method": "GET"}
            for _ in range(n)
        ]
        started = time.time()
        batch_replies = await server.handle_batch_request(batch_spec)
        batch_ms = (time.time() - started) * 1000

        # This will fail initially
        assert len(solo_replies) == n
        assert len(batch_replies) == n

        # Batching amortizes overhead, so it must win.
        assert batch_ms < solo_ms
        assert all(r["status"] == 200 for r in batch_replies)
|
||||
507
tests/performance/test_inference_speed.py
Normal file
507
tests/performance/test_inference_speed.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Performance tests for ML model inference speed.
|
||||
|
||||
Tests pose estimation model performance, throughput, and optimization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import psutil
|
||||
import os
|
||||
|
||||
|
||||
class MockPoseModel:
    """Mock pose estimation model for performance testing.

    Simulates load and inference latency for three complexity tiers
    ("lightweight", "standard", "high_accuracy") and accumulates
    throughput statistics across ``predict`` calls.
    """

    # Simulated per-sample inference cost, seconds, by complexity tier.
    _INFERENCE_COST = {
        "lightweight": 0.02,    # 20ms
        "standard": 0.05,       # 50ms
        "high_accuracy": 0.15,  # 150ms
    }

    # Simulated one-time model load cost, seconds, by complexity tier.
    _LOAD_COST = {
        "lightweight": 0.5,
        "standard": 2.0,
        "high_accuracy": 5.0,
    }

    def __init__(self, model_complexity: str = "standard"):
        self.model_complexity = model_complexity
        self.is_loaded = False
        self.inference_count = 0         # total samples inferred so far
        self.total_inference_time = 0.0  # cumulative wall time spent in predict()
        self.batch_size = 1
        # Unknown tiers fall back to the "standard" cost.
        self.base_inference_time = self._INFERENCE_COST.get(model_complexity, 0.05)

    async def load_model(self):
        """Load the model (simulated by a complexity-dependent sleep)."""
        await asyncio.sleep(self._LOAD_COST.get(self.model_complexity, 2.0))
        self.is_loaded = True

    async def predict(self, features: np.ndarray) -> Dict[str, Any]:
        """Run simulated inference on ``features``.

        A 3-D input is treated as a batch along axis 0; anything else is a
        single sample. Raises RuntimeError if the model was never loaded.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        t0 = time.time()

        n_samples = features.shape[0] if len(features.shape) > 2 else 1
        # Cost scales with batch size, with +/-20% jitter for realism.
        simulated = self.base_inference_time * n_samples
        simulated *= np.random.uniform(0.8, 1.2)
        await asyncio.sleep(simulated)

        spent = time.time() - t0
        self.inference_count += n_samples
        self.total_inference_time += spent

        # One synthetic detection per sample in the batch.
        predictions = [
            {
                "person_id": f"person_{idx}",
                "confidence": np.random.uniform(0.5, 0.95),
                "keypoints": np.random.rand(17, 3).tolist(),  # 17 keypoints with x,y,confidence
                "bounding_box": {
                    "x": np.random.uniform(0, 640),
                    "y": np.random.uniform(0, 480),
                    "width": np.random.uniform(50, 200),
                    "height": np.random.uniform(100, 300),
                },
            }
            for idx in range(n_samples)
        ]

        return {
            "predictions": predictions,
            "inference_time_ms": spent * 1000,
            "model_complexity": self.model_complexity,
            "batch_size": n_samples,
        }

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return throughput statistics accumulated across predict() calls."""
        if self.inference_count > 0:
            mean_latency = self.total_inference_time / self.inference_count
        else:
            mean_latency = 0

        return {
            "total_inferences": self.inference_count,
            "total_time_seconds": self.total_inference_time,
            "average_inference_time_ms": mean_latency * 1000,
            "throughput_fps": 1.0 / mean_latency if mean_latency > 0 else 0,
            "model_complexity": self.model_complexity,
        }
|
||||
|
||||
|
||||
class TestInferenceSpeed:
    """Test inference speed for different model configurations."""

    @pytest.fixture
    def lightweight_model(self):
        """Create lightweight model."""
        return MockPoseModel("lightweight")

    @pytest.fixture
    def standard_model(self):
        """Create standard model."""
        return MockPoseModel("standard")

    @pytest.fixture
    def high_accuracy_model(self):
        """Create high accuracy model."""
        return MockPoseModel("high_accuracy")

    @pytest.fixture
    def sample_features(self):
        """Create sample feature data."""
        return np.random.rand(64, 32)  # 64x32 feature matrix

    @pytest.mark.asyncio
    async def test_single_inference_speed_should_fail_initially(self, standard_model, sample_features):
        """One standard-tier inference must finish inside 100 ms."""
        await standard_model.load_model()

        t0 = time.time()
        outcome = await standard_model.predict(sample_features)
        wall_ms = (time.time() - t0) * 1000

        # This will fail initially
        assert wall_ms < 100  # under 100 ms end to end
        assert outcome["inference_time_ms"] > 0
        assert len(outcome["predictions"]) > 0
        assert outcome["model_complexity"] == "standard"

    @pytest.mark.asyncio
    async def test_model_complexity_comparison_should_fail_initially(self, sample_features):
        """Latency must rank lightweight < standard < high_accuracy."""
        tiers = ("lightweight", "standard", "high_accuracy")
        models = {tier: MockPoseModel(tier) for tier in tiers}

        for model in models.values():
            await model.load_model()

        timings = {}
        for tier, model in models.items():
            t0 = time.time()
            outcome = await model.predict(sample_features)
            timings[tier] = {
                "inference_time_ms": (time.time() - t0) * 1000,
                "result": outcome,
            }

        # This will fail initially
        # Cheaper tiers must be strictly faster.
        assert timings["lightweight"]["inference_time_ms"] < timings["standard"]["inference_time_ms"]
        assert timings["standard"]["inference_time_ms"] < timings["high_accuracy"]["inference_time_ms"]

        # Every tier must stay under half a second.
        for tier in tiers:
            assert timings[tier]["inference_time_ms"] < 500

    @pytest.mark.asyncio
    async def test_batch_inference_performance_should_fail_initially(self, standard_model):
        """Per-sample cost should shrink as the batch grows."""
        await standard_model.load_model()

        profile = {}
        for size in (1, 4, 8, 16):
            batch = np.random.rand(size, 64, 32)

            t0 = time.time()
            outcome = await standard_model.predict(batch)
            wall_ms = (time.time() - t0) * 1000
            per_sample = wall_ms / size

            profile[size] = {
                "total_time_ms": wall_ms,
                "per_sample_time_ms": per_sample,
                "throughput_fps": 1000 / per_sample,
                "predictions": len(outcome["predictions"]),
            }

        # This will fail initially
        # Larger batches must amortize better per sample.
        assert profile[1]["per_sample_time_ms"] > profile[4]["per_sample_time_ms"]
        assert profile[4]["per_sample_time_ms"] > profile[8]["per_sample_time_ms"]

        # Exactly one prediction per sample in every batch.
        for size, metrics in profile.items():
            assert metrics["predictions"] == size

    @pytest.mark.asyncio
    async def test_sustained_inference_performance_should_fail_initially(self, standard_model, sample_features):
        """Fifty back-to-back inferences must stay fast and consistent."""
        await standard_model.load_model()

        runs = 50
        latencies = []
        for _ in range(runs):
            t0 = time.time()
            await standard_model.predict(sample_features)
            latencies.append((time.time() - t0) * 1000)

        # This will fail initially
        mean_ms = np.mean(latencies)
        spread_ms = np.std(latencies)
        worst_ms = np.max(latencies)

        assert mean_ms < 100            # average under 100 ms
        assert spread_ms < 20           # low variance = consistent performance
        assert worst_ms < mean_ms * 2   # no outlier beyond 2x the mean

        stats = standard_model.get_performance_stats()
        assert stats["total_inferences"] == runs
        assert stats["throughput_fps"] > 10  # at least 10 FPS sustained
|
||||
|
||||
|
||||
class TestInferenceOptimization:
    """Test inference optimization techniques."""

    # FIX: the tests below request the ``standard_model`` and
    # ``sample_features`` fixtures, which were previously defined only on the
    # sibling class TestInferenceSpeed.  pytest does not share class-scoped
    # fixtures across sibling classes, so collection failed with
    # "fixture not found" (assuming no conftest.py provides them — confirm).
    # Defining them locally makes the class self-contained.
    @pytest.fixture
    def standard_model(self):
        """Create standard model."""
        return MockPoseModel("standard")

    @pytest.fixture
    def sample_features(self):
        """Create sample feature data."""
        return np.random.rand(64, 32)  # 64x32 feature matrix

    @pytest.mark.asyncio
    async def test_model_warmup_effect_should_fail_initially(self, standard_model, sample_features):
        """Warmed-up inferences should be no slower than the cold start."""
        await standard_model.load_model()

        # First inference (cold start).
        start_time = time.time()
        await standard_model.predict(sample_features)
        cold_start_time = (time.time() - start_time) * 1000

        # Subsequent inferences (warmed up).
        warm_times = []
        for _ in range(5):
            start_time = time.time()
            await standard_model.predict(sample_features)
            warm_times.append((time.time() - start_time) * 1000)

        avg_warm_time = np.mean(warm_times)

        # This will fail initially
        # Warm inferences should be faster than (or equal to) cold start.
        assert avg_warm_time <= cold_start_time
        assert cold_start_time > 0
        assert avg_warm_time > 0

    @pytest.mark.asyncio
    async def test_concurrent_inference_performance_should_fail_initially(self, sample_features):
        """Three models running concurrently must each stay under 200 ms."""
        # Create and load multiple independent model instances.
        models = [MockPoseModel("standard") for _ in range(3)]
        for model in models:
            await model.load_model()

        async def run_inference(model, features):
            """Time a single predict() call, returning milliseconds."""
            start_time = time.time()
            await model.predict(features)
            return (time.time() - start_time) * 1000

        tasks = [run_inference(model, sample_features) for model in models]
        inference_times = await asyncio.gather(*tasks)

        # This will fail initially
        # All inferences should complete with a positive duration.
        # (Loop variable renamed from ``time`` to avoid shadowing the module.)
        assert len(inference_times) == 3
        assert all(elapsed > 0 for elapsed in inference_times)

        # Concurrent execution shouldn't be much slower than sequential.
        avg_concurrent_time = np.mean(inference_times)
        assert avg_concurrent_time < 200  # each within 200 ms on average

    @pytest.mark.asyncio
    async def test_memory_usage_during_inference_should_fail_initially(self, standard_model, sample_features):
        """RSS growth over 20 inferences must stay bounded."""
        process = psutil.Process(os.getpid())

        await standard_model.load_model()
        initial_memory = process.memory_info().rss

        # Run multiple inferences, sampling memory every 5 iterations.
        for i in range(20):
            await standard_model.predict(sample_features)

            if i % 5 == 0:
                current_memory = process.memory_info().rss
                memory_increase = current_memory - initial_memory

                # This will fail initially
                # Intermediate growth must stay under 50 MB.
                assert memory_increase < 50 * 1024 * 1024

        final_memory = process.memory_info().rss
        total_increase = final_memory - initial_memory

        # Total memory increase should be reasonable.
        assert total_increase < 100 * 1024 * 1024  # less than 100 MB
|
||||
|
||||
|
||||
class TestInferenceAccuracy:
    """Test inference accuracy and quality metrics."""

    # FIX: the tests below request the ``standard_model`` and
    # ``sample_features`` fixtures, which were previously defined only on the
    # sibling class TestInferenceSpeed.  pytest does not share class-scoped
    # fixtures across sibling classes, so collection failed with
    # "fixture not found" (assuming no conftest.py provides them — confirm).
    @pytest.fixture
    def standard_model(self):
        """Create standard model."""
        return MockPoseModel("standard")

    @pytest.fixture
    def sample_features(self):
        """Create sample feature data."""
        return np.random.rand(64, 32)  # 64x32 feature matrix

    @pytest.mark.asyncio
    async def test_prediction_consistency_should_fail_initially(self, standard_model, sample_features):
        """Repeated inference on the same input must be structurally stable."""
        await standard_model.load_model()

        # Run the same inference multiple times.
        results = []
        for _ in range(5):
            result = await standard_model.predict(sample_features)
            results.append(result)

        # This will fail initially
        # All results should have the same structure and non-empty output.
        for result in results:
            assert "predictions" in result
            assert "inference_time_ms" in result
            assert len(result["predictions"]) > 0

        # Inference times should be consistent across runs.
        inference_times = [r["inference_time_ms"] for r in results]
        avg_time = np.mean(inference_times)
        std_time = np.std(inference_times)

        assert std_time < avg_time * 0.5  # stddev below 50% of the mean

    @pytest.mark.asyncio
    async def test_confidence_score_distribution_should_fail_initially(self, standard_model, sample_features):
        """Confidence scores must be valid probabilities with a sane mean."""
        await standard_model.load_model()

        # Collect confidence scores from multiple inferences.
        all_confidences = []
        for _ in range(20):
            result = await standard_model.predict(sample_features)
            for prediction in result["predictions"]:
                all_confidences.append(prediction["confidence"])

        # This will fail initially
        if all_confidences:  # only test if we actually got predictions
            # Confidence scores must lie in [0, 1].
            assert all(0.0 <= conf <= 1.0 for conf in all_confidences)

            # Mean confidence should land in a plausible band.
            avg_confidence = np.mean(all_confidences)
            assert 0.3 <= avg_confidence <= 0.95

    @pytest.mark.asyncio
    async def test_keypoint_detection_quality_should_fail_initially(self, standard_model, sample_features):
        """Each prediction must carry 17 well-formed (x, y, conf) keypoints."""
        await standard_model.load_model()

        result = await standard_model.predict(sample_features)

        # This will fail initially
        for prediction in result["predictions"]:
            keypoints = prediction["keypoints"]

            # Standard pose topology has 17 keypoints.
            assert len(keypoints) == 17

            # Each keypoint is an (x, y, confidence) triple.
            for keypoint in keypoints:
                assert len(keypoint) == 3
                x, y, conf = keypoint
                assert isinstance(x, (int, float))
                assert isinstance(y, (int, float))
                assert 0.0 <= conf <= 1.0
|
||||
|
||||
|
||||
class TestInferenceScaling:
    """Test inference scaling characteristics."""

    # FIX: the tests below request the ``standard_model`` and
    # ``sample_features`` fixtures, which were previously defined only on the
    # sibling class TestInferenceSpeed.  pytest does not share class-scoped
    # fixtures across sibling classes, so collection failed with
    # "fixture not found" (assuming no conftest.py provides them — confirm).
    @pytest.fixture
    def standard_model(self):
        """Create standard model."""
        return MockPoseModel("standard")

    @pytest.fixture
    def sample_features(self):
        """Create sample feature data."""
        return np.random.rand(64, 32)  # 64x32 feature matrix

    @pytest.mark.asyncio
    async def test_input_size_scaling_should_fail_initially(self, standard_model):
        """Inference time should not shrink sharply as input size grows."""
        await standard_model.load_model()

        # Test a range of input dimensions.
        input_sizes = [(32, 16), (64, 32), (128, 64), (256, 128)]
        results = {}

        for height, width in input_sizes:
            features = np.random.rand(height, width)

            start_time = time.time()
            result = await standard_model.predict(features)
            end_time = time.time()

            inference_time = (end_time - start_time) * 1000
            input_size = height * width  # element count used as the result key

            results[input_size] = {
                "inference_time_ms": inference_time,
                "dimensions": (height, width),
                "predictions": len(result["predictions"]),
            }

        # This will fail initially
        # Larger inputs should generally take at least as long (with slack).
        sizes = sorted(results.keys())
        for i in range(len(sizes) - 1):
            current_size = sizes[i]
            next_size = sizes[i + 1]

            # Allow variance, but the next size shouldn't be much faster.
            time_ratio = results[next_size]["inference_time_ms"] / results[current_size]["inference_time_ms"]
            assert time_ratio >= 0.8

    @pytest.mark.asyncio
    async def test_throughput_under_load_should_fail_initially(self, standard_model, sample_features):
        """Sustained sequential load must keep at least ~5 FPS for 5 seconds."""
        await standard_model.load_model()

        # Run back-to-back inferences until the time budget expires.
        duration_seconds = 5
        start_time = time.time()
        inference_count = 0

        while time.time() - start_time < duration_seconds:
            await standard_model.predict(sample_features)
            inference_count += 1

        actual_duration = time.time() - start_time
        throughput = inference_count / actual_duration

        # This will fail initially
        # Should maintain reasonable throughput under load.
        assert throughput > 5          # at least 5 FPS
        assert inference_count > 20    # at least 20 inferences in 5 seconds

        # The model's own accounting must agree with what we observed.
        stats = standard_model.get_performance_stats()
        assert stats["total_inferences"] >= inference_count
        assert stats["throughput_fps"] > 0
|
||||
|
||||
|
||||
@pytest.mark.benchmark
class TestInferenceBenchmarks:
    """Benchmark tests for inference performance."""

    @pytest.mark.asyncio
    async def test_benchmark_lightweight_model_should_fail_initially(self, benchmark):
        """A single lightweight inference must stay under 50 ms."""
        # NOTE(review): the ``benchmark`` fixture is requested but never used;
        # presumably pytest-benchmark integration is planned — confirm.
        model = MockPoseModel("lightweight")
        await model.load_model()
        feats = np.random.rand(64, 32)

        async def run_once():
            return await model.predict(feats)

        # This will fail initially
        outcome = await run_once()
        assert outcome["inference_time_ms"] < 50  # under 50 ms

    @pytest.mark.asyncio
    async def test_benchmark_batch_processing_should_fail_initially(self, benchmark):
        """A batch of 8 standard inferences must stay under 200 ms total."""
        model = MockPoseModel("standard")
        await model.load_model()
        feats = np.random.rand(8, 64, 32)  # batch of 8

        async def run_batch():
            return await model.predict(feats)

        # This will fail initially
        outcome = await run_batch()
        assert len(outcome["predictions"]) == 8
        assert outcome["inference_time_ms"] < 200  # batching should be efficient
|
||||
Reference in New Issue
Block a user