This commit is contained in:
rUv
2025-06-07 11:44:19 +00:00
parent 43e92c5494
commit c378b705ca
95 changed files with 43677 additions and 0 deletions

View File

@@ -0,0 +1,736 @@
"""
End-to-end tests for healthcare fall detection scenario.
Tests complete workflow from CSI data collection to fall alert generation.
"""
import pytest
import asyncio
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import json
from dataclasses import dataclass
from enum import Enum
class AlertSeverity(Enum):
"""Alert severity levels."""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
@dataclass
class HealthcareAlert:
"""Healthcare alert data structure."""
alert_id: str
timestamp: datetime
alert_type: str
severity: AlertSeverity
patient_id: str
location: str
confidence: float
description: str
metadata: Dict[str, Any]
class MockPatientMonitor:
"""Mock patient monitoring system."""
def __init__(self, patient_id: str, room_id: str):
self.patient_id = patient_id
self.room_id = room_id
self.is_monitoring = False
self.baseline_activity = None
self.activity_history = []
self.alerts_generated = []
self.fall_detection_enabled = True
self.sensitivity_level = "medium"
async def start_monitoring(self) -> bool:
"""Start patient monitoring."""
if self.is_monitoring:
return False
self.is_monitoring = True
return True
async def stop_monitoring(self) -> bool:
"""Stop patient monitoring."""
if not self.is_monitoring:
return False
self.is_monitoring = False
return True
async def process_pose_data(self, pose_data: Dict[str, Any]) -> Optional[HealthcareAlert]:
"""Process pose data and detect potential issues."""
if not self.is_monitoring:
return None
# Extract activity metrics
activity_metrics = self._extract_activity_metrics(pose_data)
self.activity_history.append(activity_metrics)
# Keep only recent history
if len(self.activity_history) > 100:
self.activity_history = self.activity_history[-100:]
# Detect anomalies
alert = await self._detect_anomalies(activity_metrics, pose_data)
if alert:
self.alerts_generated.append(alert)
return alert
def _extract_activity_metrics(self, pose_data: Dict[str, Any]) -> Dict[str, Any]:
"""Extract activity metrics from pose data."""
persons = pose_data.get("persons", [])
if not persons:
return {
"person_count": 0,
"activity_level": 0.0,
"posture": "unknown",
"movement_speed": 0.0,
"stability_score": 1.0
}
# Analyze first person (primary patient)
person = persons[0]
# Extract posture from activity field or bounding box analysis
posture = person.get("activity", "standing")
# If no activity specified, analyze bounding box for fall detection
if posture == "standing" and "bounding_box" in person:
bbox = person["bounding_box"]
width = bbox.get("width", 80)
height = bbox.get("height", 180)
# Fall detection: if width > height, likely fallen
if width > height * 1.5:
posture = "fallen"
# Calculate activity metrics based on posture
if posture == "fallen":
activity_level = 0.1
movement_speed = 0.0
stability_score = 0.2
elif posture == "walking":
activity_level = 0.8
movement_speed = 1.5
stability_score = 0.7
elif posture == "sitting":
activity_level = 0.3
movement_speed = 0.1
stability_score = 0.9
else: # standing or other
activity_level = 0.5
movement_speed = 0.2
stability_score = 0.8
return {
"person_count": len(persons),
"activity_level": activity_level,
"posture": posture,
"movement_speed": movement_speed,
"stability_score": stability_score,
"confidence": person.get("confidence", 0.0)
}
async def _detect_anomalies(self, current_metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> Optional[HealthcareAlert]:
"""Detect health-related anomalies."""
# Fall detection
if current_metrics["posture"] == "fallen":
return await self._generate_fall_alert(current_metrics, pose_data)
# Prolonged inactivity detection
if len(self.activity_history) >= 10:
recent_activity = [m["activity_level"] for m in self.activity_history[-10:]]
avg_activity = np.mean(recent_activity)
if avg_activity < 0.1: # Very low activity
return await self._generate_inactivity_alert(current_metrics, pose_data)
# Unusual movement patterns
if current_metrics["stability_score"] < 0.4:
return await self._generate_instability_alert(current_metrics, pose_data)
return None
async def _generate_fall_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> HealthcareAlert:
"""Generate fall detection alert."""
return HealthcareAlert(
alert_id=f"fall_{self.patient_id}_{int(datetime.utcnow().timestamp())}",
timestamp=datetime.utcnow(),
alert_type="fall_detected",
severity=AlertSeverity.CRITICAL,
patient_id=self.patient_id,
location=self.room_id,
confidence=metrics["confidence"],
description=f"Fall detected for patient {self.patient_id} in {self.room_id}",
metadata={
"posture": metrics["posture"],
"stability_score": metrics["stability_score"],
"pose_data": pose_data
}
)
async def _generate_inactivity_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> HealthcareAlert:
"""Generate prolonged inactivity alert."""
return HealthcareAlert(
alert_id=f"inactivity_{self.patient_id}_{int(datetime.utcnow().timestamp())}",
timestamp=datetime.utcnow(),
alert_type="prolonged_inactivity",
severity=AlertSeverity.MEDIUM,
patient_id=self.patient_id,
location=self.room_id,
confidence=metrics["confidence"],
description=f"Prolonged inactivity detected for patient {self.patient_id}",
metadata={
"activity_level": metrics["activity_level"],
"duration_minutes": 10,
"pose_data": pose_data
}
)
async def _generate_instability_alert(self, metrics: Dict[str, Any], pose_data: Dict[str, Any]) -> HealthcareAlert:
"""Generate movement instability alert."""
return HealthcareAlert(
alert_id=f"instability_{self.patient_id}_{int(datetime.utcnow().timestamp())}",
timestamp=datetime.utcnow(),
alert_type="movement_instability",
severity=AlertSeverity.HIGH,
patient_id=self.patient_id,
location=self.room_id,
confidence=metrics["confidence"],
description=f"Movement instability detected for patient {self.patient_id}",
metadata={
"stability_score": metrics["stability_score"],
"movement_speed": metrics["movement_speed"],
"pose_data": pose_data
}
)
def get_monitoring_stats(self) -> Dict[str, Any]:
"""Get monitoring statistics."""
return {
"patient_id": self.patient_id,
"room_id": self.room_id,
"is_monitoring": self.is_monitoring,
"total_alerts": len(self.alerts_generated),
"alert_types": {
alert.alert_type: len([a for a in self.alerts_generated if a.alert_type == alert.alert_type])
for alert in self.alerts_generated
},
"activity_samples": len(self.activity_history),
"fall_detection_enabled": self.fall_detection_enabled
}
class MockHealthcareNotificationSystem:
"""Mock healthcare notification system."""
def __init__(self):
self.notifications_sent = []
self.notification_channels = {
"nurse_station": True,
"mobile_app": True,
"email": True,
"sms": False
}
self.escalation_rules = {
AlertSeverity.CRITICAL: ["nurse_station", "mobile_app", "sms"],
AlertSeverity.HIGH: ["nurse_station", "mobile_app"],
AlertSeverity.MEDIUM: ["nurse_station"],
AlertSeverity.LOW: ["mobile_app"]
}
async def send_alert_notification(self, alert: HealthcareAlert) -> Dict[str, bool]:
"""Send alert notification through appropriate channels."""
channels_to_notify = self.escalation_rules.get(alert.severity, ["nurse_station"])
results = {}
for channel in channels_to_notify:
if self.notification_channels.get(channel, False):
success = await self._send_to_channel(channel, alert)
results[channel] = success
if success:
self.notifications_sent.append({
"alert_id": alert.alert_id,
"channel": channel,
"timestamp": datetime.utcnow(),
"severity": alert.severity.value
})
return results
async def _send_to_channel(self, channel: str, alert: HealthcareAlert) -> bool:
"""Send notification to specific channel."""
# Simulate network delay
await asyncio.sleep(0.01)
# Simulate occasional failures
if np.random.random() < 0.05: # 5% failure rate
return False
return True
def get_notification_stats(self) -> Dict[str, Any]:
"""Get notification statistics."""
return {
"total_notifications": len(self.notifications_sent),
"notifications_by_channel": {
channel: len([n for n in self.notifications_sent if n["channel"] == channel])
for channel in self.notification_channels.keys()
},
"notifications_by_severity": {
severity.value: len([n for n in self.notifications_sent if n["severity"] == severity.value])
for severity in AlertSeverity
}
}
class TestHealthcareFallDetection:
"""Test healthcare fall detection workflow."""
@pytest.fixture
def patient_monitor(self):
"""Create patient monitor."""
return MockPatientMonitor("patient_001", "room_101")
@pytest.fixture
def notification_system(self):
"""Create notification system."""
return MockHealthcareNotificationSystem()
@pytest.fixture
def fall_pose_data(self):
"""Create pose data indicating a fall."""
return {
"persons": [
{
"person_id": "patient_001",
"confidence": 0.92,
"bounding_box": {"x": 200, "y": 400, "width": 150, "height": 80}, # Horizontal position
"activity": "fallen",
"keypoints": [[x, y, 0.8] for x, y in zip(range(17), range(17))]
}
],
"zone_summary": {"room_101": 1},
"timestamp": datetime.utcnow().isoformat()
}
@pytest.fixture
def normal_pose_data(self):
"""Create normal pose data."""
return {
"persons": [
{
"person_id": "patient_001",
"confidence": 0.88,
"bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
"activity": "standing",
"keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
}
],
"zone_summary": {"room_101": 1},
"timestamp": datetime.utcnow().isoformat()
}
@pytest.mark.asyncio
async def test_fall_detection_workflow_should_fail_initially(self, patient_monitor, notification_system, fall_pose_data):
"""Test fall detection workflow - should fail initially."""
# Start monitoring
result = await patient_monitor.start_monitoring()
# This will fail initially
assert result is True
assert patient_monitor.is_monitoring is True
# Process fall pose data
alert = await patient_monitor.process_pose_data(fall_pose_data)
# Should generate fall alert
assert alert is not None
assert alert.alert_type == "fall_detected"
assert alert.severity == AlertSeverity.CRITICAL
assert alert.patient_id == "patient_001"
# Send notification
notification_results = await notification_system.send_alert_notification(alert)
# Should notify appropriate channels
assert len(notification_results) > 0
assert any(notification_results.values()) # At least one channel should succeed
# Check statistics
monitor_stats = patient_monitor.get_monitoring_stats()
assert monitor_stats["total_alerts"] == 1
notification_stats = notification_system.get_notification_stats()
assert notification_stats["total_notifications"] > 0
@pytest.mark.asyncio
async def test_normal_activity_monitoring_should_fail_initially(self, patient_monitor, normal_pose_data):
"""Test normal activity monitoring - should fail initially."""
await patient_monitor.start_monitoring()
# Process multiple normal pose data samples
alerts_generated = []
for i in range(10):
alert = await patient_monitor.process_pose_data(normal_pose_data)
if alert:
alerts_generated.append(alert)
# This will fail initially
# Should not generate alerts for normal activity
assert len(alerts_generated) == 0
# Should have activity history
stats = patient_monitor.get_monitoring_stats()
assert stats["activity_samples"] == 10
assert stats["is_monitoring"] is True
@pytest.mark.asyncio
async def test_prolonged_inactivity_detection_should_fail_initially(self, patient_monitor):
"""Test prolonged inactivity detection - should fail initially."""
await patient_monitor.start_monitoring()
# Simulate prolonged inactivity
inactive_pose_data = {
"persons": [], # No person detected
"zone_summary": {"room_101": 0},
"timestamp": datetime.utcnow().isoformat()
}
alerts_generated = []
# Process multiple inactive samples
for i in range(15):
alert = await patient_monitor.process_pose_data(inactive_pose_data)
if alert:
alerts_generated.append(alert)
# This will fail initially
# Should generate inactivity alert after sufficient samples
inactivity_alerts = [a for a in alerts_generated if a.alert_type == "prolonged_inactivity"]
assert len(inactivity_alerts) > 0
# Check alert properties
alert = inactivity_alerts[0]
assert alert.severity == AlertSeverity.MEDIUM
assert alert.patient_id == "patient_001"
@pytest.mark.asyncio
async def test_movement_instability_detection_should_fail_initially(self, patient_monitor):
"""Test movement instability detection - should fail initially."""
await patient_monitor.start_monitoring()
# Simulate unstable movement
unstable_pose_data = {
"persons": [
{
"person_id": "patient_001",
"confidence": 0.65, # Lower confidence indicates instability
"bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
"activity": "walking",
"keypoints": [[x, y, 0.5] for x, y in zip(range(17), range(17))] # Low keypoint confidence
}
],
"zone_summary": {"room_101": 1},
"timestamp": datetime.utcnow().isoformat()
}
# Process unstable pose data
alert = await patient_monitor.process_pose_data(unstable_pose_data)
# This will fail initially
# May generate instability alert based on stability score
if alert and alert.alert_type == "movement_instability":
assert alert.severity == AlertSeverity.HIGH
assert alert.patient_id == "patient_001"
assert "stability_score" in alert.metadata
class TestHealthcareMultiPatientMonitoring:
"""Test multi-patient monitoring scenarios."""
@pytest.fixture
def multi_patient_setup(self):
"""Create multi-patient monitoring setup."""
patients = {
"patient_001": MockPatientMonitor("patient_001", "room_101"),
"patient_002": MockPatientMonitor("patient_002", "room_102"),
"patient_003": MockPatientMonitor("patient_003", "room_103")
}
notification_system = MockHealthcareNotificationSystem()
return patients, notification_system
@pytest.mark.asyncio
async def test_concurrent_patient_monitoring_should_fail_initially(self, multi_patient_setup):
"""Test concurrent patient monitoring - should fail initially."""
patients, notification_system = multi_patient_setup
# Start monitoring for all patients
start_results = []
for patient_id, monitor in patients.items():
result = await monitor.start_monitoring()
start_results.append(result)
# This will fail initially
assert all(start_results)
assert all(monitor.is_monitoring for monitor in patients.values())
# Simulate concurrent pose data processing
pose_data_samples = [
{
"persons": [
{
"person_id": patient_id,
"confidence": 0.85,
"bounding_box": {"x": 200, "y": 150, "width": 80, "height": 180},
"activity": "standing"
}
],
"zone_summary": {f"room_{101 + i}": 1},
"timestamp": datetime.utcnow().isoformat()
}
for i, patient_id in enumerate(patients.keys())
]
# Process data for all patients concurrently
tasks = []
for (patient_id, monitor), pose_data in zip(patients.items(), pose_data_samples):
task = asyncio.create_task(monitor.process_pose_data(pose_data))
tasks.append(task)
alerts = await asyncio.gather(*tasks)
# Check results
assert len(alerts) == len(patients)
# Get statistics for all patients
all_stats = {}
for patient_id, monitor in patients.items():
all_stats[patient_id] = monitor.get_monitoring_stats()
assert len(all_stats) == 3
assert all(stats["is_monitoring"] for stats in all_stats.values())
@pytest.mark.asyncio
async def test_alert_prioritization_should_fail_initially(self, multi_patient_setup):
"""Test alert prioritization across patients - should fail initially."""
patients, notification_system = multi_patient_setup
# Start monitoring
for monitor in patients.values():
await monitor.start_monitoring()
# Generate different severity alerts
alert_scenarios = [
("patient_001", "fall_detected", AlertSeverity.CRITICAL),
("patient_002", "prolonged_inactivity", AlertSeverity.MEDIUM),
("patient_003", "movement_instability", AlertSeverity.HIGH)
]
generated_alerts = []
for patient_id, alert_type, expected_severity in alert_scenarios:
# Create appropriate pose data for each scenario
if alert_type == "fall_detected":
pose_data = {
"persons": [{"person_id": patient_id, "confidence": 0.9, "activity": "fallen"}],
"zone_summary": {f"room_{patients[patient_id].room_id}": 1}
}
else:
pose_data = {
"persons": [{"person_id": patient_id, "confidence": 0.7, "activity": "standing"}],
"zone_summary": {f"room_{patients[patient_id].room_id}": 1}
}
alert = await patients[patient_id].process_pose_data(pose_data)
if alert:
generated_alerts.append(alert)
# This will fail initially
# Should have generated alerts
assert len(generated_alerts) > 0
# Send notifications for all alerts
notification_tasks = [
notification_system.send_alert_notification(alert)
for alert in generated_alerts
]
notification_results = await asyncio.gather(*notification_tasks)
# Check notification prioritization
notification_stats = notification_system.get_notification_stats()
assert notification_stats["total_notifications"] > 0
# Critical alerts should use more channels
critical_notifications = [
n for n in notification_system.notifications_sent
if n["severity"] == "critical"
]
if critical_notifications:
# Critical alerts should be sent to multiple channels
critical_channels = set(n["channel"] for n in critical_notifications)
assert len(critical_channels) >= 1
class TestHealthcareSystemIntegration:
"""Test healthcare system integration scenarios."""
@pytest.mark.asyncio
async def test_end_to_end_healthcare_workflow_should_fail_initially(self):
"""Test complete end-to-end healthcare workflow - should fail initially."""
# Setup complete healthcare monitoring system
class HealthcareMonitoringSystem:
def __init__(self):
self.patient_monitors = {}
self.notification_system = MockHealthcareNotificationSystem()
self.alert_history = []
self.system_status = "operational"
async def add_patient(self, patient_id: str, room_id: str) -> bool:
"""Add patient to monitoring system."""
if patient_id in self.patient_monitors:
return False
monitor = MockPatientMonitor(patient_id, room_id)
self.patient_monitors[patient_id] = monitor
return await monitor.start_monitoring()
async def process_pose_update(self, room_id: str, pose_data: Dict[str, Any]) -> List[HealthcareAlert]:
"""Process pose update for room."""
alerts = []
# Find patients in this room
room_patients = [
(patient_id, monitor) for patient_id, monitor in self.patient_monitors.items()
if monitor.room_id == room_id
]
for patient_id, monitor in room_patients:
alert = await monitor.process_pose_data(pose_data)
if alert:
alerts.append(alert)
self.alert_history.append(alert)
# Send notification
await self.notification_system.send_alert_notification(alert)
return alerts
def get_system_status(self) -> Dict[str, Any]:
"""Get overall system status."""
return {
"system_status": self.system_status,
"total_patients": len(self.patient_monitors),
"active_monitors": sum(1 for m in self.patient_monitors.values() if m.is_monitoring),
"total_alerts": len(self.alert_history),
"notification_stats": self.notification_system.get_notification_stats()
}
healthcare_system = HealthcareMonitoringSystem()
# Add patients to system
patients = [
("patient_001", "room_101"),
("patient_002", "room_102"),
("patient_003", "room_103")
]
for patient_id, room_id in patients:
result = await healthcare_system.add_patient(patient_id, room_id)
assert result is True
# Simulate pose data updates for different rooms
pose_updates = [
("room_101", {
"persons": [{"person_id": "patient_001", "confidence": 0.9, "activity": "fallen"}],
"zone_summary": {"room_101": 1}
}),
("room_102", {
"persons": [{"person_id": "patient_002", "confidence": 0.8, "activity": "standing"}],
"zone_summary": {"room_102": 1}
}),
("room_103", {
"persons": [], # No person detected
"zone_summary": {"room_103": 0}
})
]
all_alerts = []
for room_id, pose_data in pose_updates:
alerts = await healthcare_system.process_pose_update(room_id, pose_data)
all_alerts.extend(alerts)
# This will fail initially
# Should have processed all updates
assert len(pose_updates) == 3
# Check system status
system_status = healthcare_system.get_system_status()
assert system_status["total_patients"] == 3
assert system_status["active_monitors"] == 3
assert system_status["system_status"] == "operational"
# Should have generated some alerts
if all_alerts:
assert len(all_alerts) > 0
assert system_status["total_alerts"] > 0
@pytest.mark.asyncio
async def test_healthcare_system_resilience_should_fail_initially(self):
"""Test healthcare system resilience - should fail initially."""
patient_monitor = MockPatientMonitor("patient_001", "room_101")
notification_system = MockHealthcareNotificationSystem()
await patient_monitor.start_monitoring()
# Simulate system stress with rapid pose updates
rapid_updates = 50
alerts_generated = []
for i in range(rapid_updates):
# Alternate between normal and concerning pose data
if i % 10 == 0: # Every 10th update is concerning
pose_data = {
"persons": [{"person_id": "patient_001", "confidence": 0.9, "activity": "fallen"}],
"zone_summary": {"room_101": 1}
}
else:
pose_data = {
"persons": [{"person_id": "patient_001", "confidence": 0.85, "activity": "standing"}],
"zone_summary": {"room_101": 1}
}
alert = await patient_monitor.process_pose_data(pose_data)
if alert:
alerts_generated.append(alert)
await notification_system.send_alert_notification(alert)
# This will fail initially
# System should handle rapid updates gracefully
stats = patient_monitor.get_monitoring_stats()
assert stats["activity_samples"] == rapid_updates
assert stats["is_monitoring"] is True
# Should have generated some alerts but not excessive
assert len(alerts_generated) <= rapid_updates / 5 # At most 20% alert rate
notification_stats = notification_system.get_notification_stats()
assert notification_stats["total_notifications"] >= len(alerts_generated)

661
tests/fixtures/api_client.py vendored Normal file
View File

@@ -0,0 +1,661 @@
"""
Test client utilities for API testing.
Provides mock and real API clients for comprehensive testing.
"""
import asyncio
import aiohttp
import json
import time
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Union, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import websockets
import jwt
from dataclasses import dataclass, asdict
from enum import Enum
class AuthenticationError(Exception):
"""Authentication related errors."""
pass
class APIError(Exception):
"""General API errors."""
pass
class RateLimitError(Exception):
"""Rate limiting errors."""
pass
@dataclass
class APIResponse:
"""API response wrapper."""
status_code: int
data: Dict[str, Any]
headers: Dict[str, str]
response_time_ms: float
timestamp: datetime
class MockAPIClient:
"""Mock API client for testing."""
def __init__(self, base_url: str = "http://localhost:8000"):
self.base_url = base_url
self.session = None
self.auth_token = None
self.refresh_token = None
self.token_expires_at = None
self.request_history = []
self.response_delays = {}
self.error_simulation = {}
self.rate_limit_config = {
"enabled": False,
"requests_per_minute": 60,
"current_count": 0,
"window_start": time.time()
}
async def __aenter__(self):
"""Async context manager entry."""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.disconnect()
async def connect(self):
"""Initialize connection."""
self.session = aiohttp.ClientSession()
async def disconnect(self):
"""Close connection."""
if self.session:
await self.session.close()
def set_response_delay(self, endpoint: str, delay_ms: float):
"""Set artificial delay for endpoint."""
self.response_delays[endpoint] = delay_ms
def simulate_error(self, endpoint: str, error_type: str, probability: float = 1.0):
"""Simulate errors for endpoint."""
self.error_simulation[endpoint] = {
"type": error_type,
"probability": probability
}
def enable_rate_limiting(self, requests_per_minute: int = 60):
"""Enable rate limiting simulation."""
self.rate_limit_config.update({
"enabled": True,
"requests_per_minute": requests_per_minute,
"current_count": 0,
"window_start": time.time()
})
async def _check_rate_limit(self):
"""Check rate limiting."""
if not self.rate_limit_config["enabled"]:
return
current_time = time.time()
window_duration = 60 # 1 minute
# Reset window if needed
if current_time - self.rate_limit_config["window_start"] > window_duration:
self.rate_limit_config["current_count"] = 0
self.rate_limit_config["window_start"] = current_time
# Check limit
if self.rate_limit_config["current_count"] >= self.rate_limit_config["requests_per_minute"]:
raise RateLimitError("Rate limit exceeded")
self.rate_limit_config["current_count"] += 1
async def _simulate_network_delay(self, endpoint: str):
"""Simulate network delay."""
delay = self.response_delays.get(endpoint, 0)
if delay > 0:
await asyncio.sleep(delay / 1000) # Convert ms to seconds
async def _check_error_simulation(self, endpoint: str):
"""Check if error should be simulated."""
if endpoint in self.error_simulation:
config = self.error_simulation[endpoint]
if random.random() < config["probability"]:
error_type = config["type"]
if error_type == "timeout":
raise asyncio.TimeoutError("Simulated timeout")
elif error_type == "connection":
raise aiohttp.ClientConnectionError("Simulated connection error")
elif error_type == "server_error":
raise APIError("Simulated server error")
async def _make_request(self, method: str, endpoint: str, **kwargs) -> APIResponse:
"""Make HTTP request with simulation."""
start_time = time.time()
# Check rate limiting
await self._check_rate_limit()
# Simulate network delay
await self._simulate_network_delay(endpoint)
# Check error simulation
await self._check_error_simulation(endpoint)
# Record request
request_record = {
"method": method,
"endpoint": endpoint,
"timestamp": datetime.utcnow(),
"kwargs": kwargs
}
self.request_history.append(request_record)
# Generate mock response
response_data = await self._generate_mock_response(method, endpoint, kwargs)
end_time = time.time()
response_time = (end_time - start_time) * 1000
return APIResponse(
status_code=response_data["status_code"],
data=response_data["data"],
headers=response_data.get("headers", {}),
response_time_ms=response_time,
timestamp=datetime.utcnow()
)
async def _generate_mock_response(self, method: str, endpoint: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Generate mock response based on endpoint."""
if endpoint == "/health":
return {
"status_code": 200,
"data": {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"version": "1.0.0"
}
}
elif endpoint == "/auth/login":
if method == "POST":
# Generate mock JWT tokens
payload = {
"user_id": "test_user",
"exp": datetime.utcnow() + timedelta(hours=1)
}
access_token = jwt.encode(payload, "secret", algorithm="HS256")
refresh_token = jwt.encode({"user_id": "test_user"}, "secret", algorithm="HS256")
self.auth_token = access_token
self.refresh_token = refresh_token
self.token_expires_at = payload["exp"]
return {
"status_code": 200,
"data": {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"expires_in": 3600
}
}
elif endpoint == "/auth/refresh":
if method == "POST" and self.refresh_token:
# Generate new access token
payload = {
"user_id": "test_user",
"exp": datetime.utcnow() + timedelta(hours=1)
}
access_token = jwt.encode(payload, "secret", algorithm="HS256")
self.auth_token = access_token
self.token_expires_at = payload["exp"]
return {
"status_code": 200,
"data": {
"access_token": access_token,
"token_type": "bearer",
"expires_in": 3600
}
}
elif endpoint == "/pose/detect":
if method == "POST":
return {
"status_code": 200,
"data": {
"persons": [
{
"person_id": "person_1",
"confidence": 0.85,
"bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
"keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))],
"activity": "standing"
}
],
"processing_time_ms": 45.2,
"model_version": "v1.0",
"timestamp": datetime.utcnow().isoformat()
}
}
elif endpoint == "/config":
if method == "GET":
return {
"status_code": 200,
"data": {
"model_config": {
"confidence_threshold": 0.7,
"nms_threshold": 0.5,
"max_persons": 10
},
"processing_config": {
"batch_size": 1,
"use_gpu": True,
"preprocessing": "standard"
}
}
}
# Default response
return {
"status_code": 404,
"data": {"error": "Endpoint not found"}
}
async def get(self, endpoint: str, **kwargs) -> APIResponse:
"""Make GET request."""
return await self._make_request("GET", endpoint, **kwargs)
async def post(self, endpoint: str, **kwargs) -> APIResponse:
"""Make POST request."""
return await self._make_request("POST", endpoint, **kwargs)
async def put(self, endpoint: str, **kwargs) -> APIResponse:
"""Make PUT request."""
return await self._make_request("PUT", endpoint, **kwargs)
async def delete(self, endpoint: str, **kwargs) -> APIResponse:
"""Make DELETE request."""
return await self._make_request("DELETE", endpoint, **kwargs)
async def login(self, username: str, password: str) -> bool:
"""Authenticate with API."""
response = await self.post("/auth/login", json={
"username": username,
"password": password
})
if response.status_code == 200:
return True
else:
raise AuthenticationError("Login failed")
async def refresh_auth_token(self) -> bool:
"""Refresh authentication token."""
if not self.refresh_token:
raise AuthenticationError("No refresh token available")
response = await self.post("/auth/refresh", json={
"refresh_token": self.refresh_token
})
if response.status_code == 200:
return True
else:
raise AuthenticationError("Token refresh failed")
def is_authenticated(self) -> bool:
"""Check if client is authenticated."""
if not self.auth_token or not self.token_expires_at:
return False
return datetime.utcnow() < self.token_expires_at
def get_request_history(self) -> List[Dict[str, Any]]:
"""Get request history."""
return self.request_history.copy()
def clear_request_history(self):
"""Clear request history."""
self.request_history.clear()
class MockWebSocketClient:
"""Mock WebSocket client for testing."""
def __init__(self, uri: str = "ws://localhost:8000/ws"):
self.uri = uri
self.websocket = None
self.is_connected = False
self.messages_received = []
self.messages_sent = []
self.connection_errors = []
self.auto_respond = True
self.response_delay = 0.01 # 10ms default delay
async def connect(self) -> bool:
"""Connect to WebSocket."""
try:
# Simulate connection
await asyncio.sleep(0.01)
self.is_connected = True
return True
except Exception as e:
self.connection_errors.append(str(e))
return False
async def disconnect(self):
"""Disconnect from WebSocket."""
self.is_connected = False
self.websocket = None
async def send_message(self, message: Dict[str, Any]) -> bool:
"""Send message to WebSocket."""
if not self.is_connected:
raise ConnectionError("WebSocket not connected")
# Record sent message
self.messages_sent.append({
"message": message,
"timestamp": datetime.utcnow()
})
# Auto-respond if enabled
if self.auto_respond:
await asyncio.sleep(self.response_delay)
response = await self._generate_auto_response(message)
if response:
self.messages_received.append({
"message": response,
"timestamp": datetime.utcnow()
})
return True
async def receive_message(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
"""Receive message from WebSocket."""
if not self.is_connected:
raise ConnectionError("WebSocket not connected")
# Wait for message or timeout
start_time = time.time()
while time.time() - start_time < timeout:
if self.messages_received:
return self.messages_received.pop(0)["message"]
await asyncio.sleep(0.01)
return None
async def _generate_auto_response(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Generate automatic response to message."""
message_type = message.get("type")
if message_type == "subscribe":
return {
"type": "subscription_confirmed",
"channel": message.get("channel"),
"timestamp": datetime.utcnow().isoformat()
}
elif message_type == "pose_request":
return {
"type": "pose_data",
"data": {
"persons": [
{
"person_id": "person_1",
"confidence": 0.88,
"bounding_box": {"x": 150, "y": 200, "width": 80, "height": 180},
"keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
}
],
"timestamp": datetime.utcnow().isoformat()
},
"request_id": message.get("request_id")
}
elif message_type == "ping":
return {
"type": "pong",
"timestamp": datetime.utcnow().isoformat()
}
return None
def set_auto_respond(self, enabled: bool, delay_ms: float = 10):
"""Configure auto-response behavior."""
self.auto_respond = enabled
self.response_delay = delay_ms / 1000
def inject_message(self, message: Dict[str, Any]):
"""Inject message as if received from server."""
self.messages_received.append({
"message": message,
"timestamp": datetime.utcnow()
})
def get_sent_messages(self) -> List[Dict[str, Any]]:
"""Get all sent messages."""
return self.messages_sent.copy()
def get_received_messages(self) -> List[Dict[str, Any]]:
"""Get all received messages."""
return self.messages_received.copy()
def clear_message_history(self):
"""Clear message history."""
self.messages_sent.clear()
self.messages_received.clear()
class APITestClient:
"""High-level test client combining HTTP and WebSocket."""
def __init__(self, base_url: str = "http://localhost:8000"):
self.base_url = base_url
self.ws_url = base_url.replace("http", "ws") + "/ws"
self.http_client = MockAPIClient(base_url)
self.ws_client = MockWebSocketClient(self.ws_url)
self.test_session_id = None
async def __aenter__(self):
"""Async context manager entry."""
await self.setup()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.teardown()
async def setup(self):
"""Setup test client."""
await self.http_client.connect()
await self.ws_client.connect()
self.test_session_id = f"test_session_{int(time.time())}"
async def teardown(self):
"""Teardown test client."""
await self.ws_client.disconnect()
await self.http_client.disconnect()
async def authenticate(self, username: str = "test_user", password: str = "test_pass") -> bool:
"""Authenticate with API."""
return await self.http_client.login(username, password)
async def test_health_endpoint(self) -> APIResponse:
"""Test health endpoint."""
return await self.http_client.get("/health")
async def test_pose_detection(self, csi_data: Dict[str, Any]) -> APIResponse:
"""Test pose detection endpoint."""
return await self.http_client.post("/pose/detect", json=csi_data)
async def test_websocket_streaming(self, duration_seconds: int = 5) -> List[Dict[str, Any]]:
"""Test WebSocket streaming."""
# Subscribe to pose stream
await self.ws_client.send_message({
"type": "subscribe",
"channel": "pose_stream",
"session_id": self.test_session_id
})
# Collect messages for specified duration
messages = []
end_time = time.time() + duration_seconds
while time.time() < end_time:
message = await self.ws_client.receive_message(timeout=0.1)
if message:
messages.append(message)
return messages
async def simulate_concurrent_requests(self, num_requests: int = 10) -> List[APIResponse]:
"""Simulate concurrent HTTP requests."""
tasks = []
for i in range(num_requests):
task = asyncio.create_task(self.http_client.get("/health"))
tasks.append(task)
responses = await asyncio.gather(*tasks, return_exceptions=True)
return responses
async def simulate_websocket_load(self, num_connections: int = 5, duration_seconds: int = 3) -> Dict[str, Any]:
"""Simulate WebSocket load testing."""
# Create multiple WebSocket clients
ws_clients = []
for i in range(num_connections):
client = MockWebSocketClient(self.ws_url)
await client.connect()
ws_clients.append(client)
# Send messages from all clients
message_counts = []
try:
tasks = []
for i, client in enumerate(ws_clients):
task = asyncio.create_task(self._send_messages_for_duration(client, duration_seconds, i))
tasks.append(task)
results = await asyncio.gather(*tasks)
message_counts = results
finally:
# Cleanup
for client in ws_clients:
await client.disconnect()
return {
"num_connections": num_connections,
"duration_seconds": duration_seconds,
"messages_per_connection": message_counts,
"total_messages": sum(message_counts)
}
async def _send_messages_for_duration(self, client: MockWebSocketClient, duration: int, client_id: int) -> int:
"""Send messages for specified duration."""
message_count = 0
end_time = time.time() + duration
while time.time() < end_time:
await client.send_message({
"type": "ping",
"client_id": client_id,
"message_id": message_count
})
message_count += 1
await asyncio.sleep(0.1) # 10 messages per second
return message_count
def configure_error_simulation(self, endpoint: str, error_type: str, probability: float = 0.1):
"""Configure error simulation for testing."""
self.http_client.simulate_error(endpoint, error_type, probability)
def configure_rate_limiting(self, requests_per_minute: int = 60):
"""Configure rate limiting for testing."""
self.http_client.enable_rate_limiting(requests_per_minute)
def get_performance_metrics(self) -> Dict[str, Any]:
"""Get performance metrics from test session."""
http_history = self.http_client.get_request_history()
ws_sent = self.ws_client.get_sent_messages()
ws_received = self.ws_client.get_received_messages()
# Calculate HTTP metrics
if http_history:
response_times = [r.get("response_time_ms", 0) for r in http_history]
http_metrics = {
"total_requests": len(http_history),
"avg_response_time_ms": sum(response_times) / len(response_times),
"min_response_time_ms": min(response_times),
"max_response_time_ms": max(response_times)
}
else:
http_metrics = {"total_requests": 0}
# Calculate WebSocket metrics
ws_metrics = {
"messages_sent": len(ws_sent),
"messages_received": len(ws_received),
"connection_active": self.ws_client.is_connected
}
return {
"session_id": self.test_session_id,
"http_metrics": http_metrics,
"websocket_metrics": ws_metrics,
"timestamp": datetime.utcnow().isoformat()
}
# Utility functions for test data generation
def generate_test_csi_data() -> Dict[str, Any]:
"""Generate test CSI data for API testing."""
import numpy as np
return {
"timestamp": datetime.utcnow().isoformat(),
"router_id": "test_router_001",
"amplitude": np.random.uniform(0, 1, (4, 64)).tolist(),
"phase": np.random.uniform(-np.pi, np.pi, (4, 64)).tolist(),
"frequency": 5.8e9,
"bandwidth": 80e6,
"num_antennas": 4,
"num_subcarriers": 64
}
def create_test_user_credentials() -> Dict[str, str]:
"""Create test user credentials."""
return {
"username": "test_user",
"password": "test_password_123",
"email": "test@example.com"
}
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1) -> bool:
"""Wait for condition to become true."""
end_time = time.time() + timeout
while time.time() < end_time:
if await condition_func() if asyncio.iscoroutinefunction(condition_func) else condition_func():
return True
await asyncio.sleep(interval)
return False

487
tests/fixtures/csi_data.py vendored Normal file
View File

@@ -0,0 +1,487 @@
"""
Test data generation utilities for CSI data.
Provides realistic CSI data samples for testing pose estimation pipeline.
"""
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple
import json
import random
class CSIDataGenerator:
"""Generate realistic CSI data for testing."""
def __init__(self,
frequency: float = 5.8e9,
bandwidth: float = 80e6,
num_antennas: int = 4,
num_subcarriers: int = 64):
self.frequency = frequency
self.bandwidth = bandwidth
self.num_antennas = num_antennas
self.num_subcarriers = num_subcarriers
self.sample_rate = 1000 # Hz
self.noise_level = 0.1
# Pre-computed patterns for different scenarios
self._initialize_patterns()
def _initialize_patterns(self):
"""Initialize CSI patterns for different scenarios."""
# Empty room pattern (baseline)
self.empty_room_pattern = {
"amplitude_mean": 0.3,
"amplitude_std": 0.05,
"phase_variance": 0.1,
"temporal_stability": 0.95
}
# Single person patterns
self.single_person_patterns = {
"standing": {
"amplitude_mean": 0.5,
"amplitude_std": 0.08,
"phase_variance": 0.2,
"temporal_stability": 0.85,
"movement_frequency": 0.1
},
"walking": {
"amplitude_mean": 0.6,
"amplitude_std": 0.15,
"phase_variance": 0.4,
"temporal_stability": 0.6,
"movement_frequency": 2.0
},
"sitting": {
"amplitude_mean": 0.4,
"amplitude_std": 0.06,
"phase_variance": 0.15,
"temporal_stability": 0.9,
"movement_frequency": 0.05
},
"fallen": {
"amplitude_mean": 0.35,
"amplitude_std": 0.04,
"phase_variance": 0.08,
"temporal_stability": 0.95,
"movement_frequency": 0.02
}
}
# Multi-person patterns
self.multi_person_patterns = {
2: {"amplitude_multiplier": 1.4, "phase_complexity": 1.6},
3: {"amplitude_multiplier": 1.7, "phase_complexity": 2.1},
4: {"amplitude_multiplier": 2.0, "phase_complexity": 2.8}
}
def generate_empty_room_sample(self, timestamp: Optional[datetime] = None) -> Dict[str, Any]:
"""Generate CSI sample for empty room."""
if timestamp is None:
timestamp = datetime.utcnow()
pattern = self.empty_room_pattern
# Generate amplitude matrix
amplitude = np.random.normal(
pattern["amplitude_mean"],
pattern["amplitude_std"],
(self.num_antennas, self.num_subcarriers)
)
amplitude = np.clip(amplitude, 0, 1)
# Generate phase matrix
phase = np.random.uniform(
-np.pi, np.pi,
(self.num_antennas, self.num_subcarriers)
)
# Add temporal stability
if hasattr(self, '_last_empty_sample'):
stability = pattern["temporal_stability"]
amplitude = stability * self._last_empty_sample["amplitude"] + (1 - stability) * amplitude
phase = stability * self._last_empty_sample["phase"] + (1 - stability) * phase
sample = {
"timestamp": timestamp.isoformat(),
"router_id": "router_001",
"amplitude": amplitude.tolist(),
"phase": phase.tolist(),
"frequency": self.frequency,
"bandwidth": self.bandwidth,
"num_antennas": self.num_antennas,
"num_subcarriers": self.num_subcarriers,
"sample_rate": self.sample_rate,
"scenario": "empty_room",
"signal_quality": np.random.uniform(0.85, 0.95)
}
self._last_empty_sample = {
"amplitude": amplitude,
"phase": phase
}
return sample
def generate_single_person_sample(self,
activity: str = "standing",
timestamp: Optional[datetime] = None) -> Dict[str, Any]:
"""Generate CSI sample for single person activity."""
if timestamp is None:
timestamp = datetime.utcnow()
if activity not in self.single_person_patterns:
raise ValueError(f"Unknown activity: {activity}")
pattern = self.single_person_patterns[activity]
# Generate base amplitude
amplitude = np.random.normal(
pattern["amplitude_mean"],
pattern["amplitude_std"],
(self.num_antennas, self.num_subcarriers)
)
# Add movement-induced variations
movement_freq = pattern["movement_frequency"]
time_factor = timestamp.timestamp()
movement_modulation = 0.1 * np.sin(2 * np.pi * movement_freq * time_factor)
amplitude += movement_modulation
amplitude = np.clip(amplitude, 0, 1)
# Generate phase with activity-specific variance
phase_base = np.random.uniform(-np.pi, np.pi, (self.num_antennas, self.num_subcarriers))
phase_variance = pattern["phase_variance"]
phase_noise = np.random.normal(0, phase_variance, (self.num_antennas, self.num_subcarriers))
phase = phase_base + phase_noise
phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi # Wrap to [-π, π]
# Add temporal correlation
if hasattr(self, f'_last_{activity}_sample'):
stability = pattern["temporal_stability"]
last_sample = getattr(self, f'_last_{activity}_sample')
amplitude = stability * last_sample["amplitude"] + (1 - stability) * amplitude
phase = stability * last_sample["phase"] + (1 - stability) * phase
sample = {
"timestamp": timestamp.isoformat(),
"router_id": "router_001",
"amplitude": amplitude.tolist(),
"phase": phase.tolist(),
"frequency": self.frequency,
"bandwidth": self.bandwidth,
"num_antennas": self.num_antennas,
"num_subcarriers": self.num_subcarriers,
"sample_rate": self.sample_rate,
"scenario": f"single_person_{activity}",
"signal_quality": np.random.uniform(0.7, 0.9),
"activity": activity
}
setattr(self, f'_last_{activity}_sample', {
"amplitude": amplitude,
"phase": phase
})
return sample
def generate_multi_person_sample(self,
num_persons: int = 2,
activities: Optional[List[str]] = None,
timestamp: Optional[datetime] = None) -> Dict[str, Any]:
"""Generate CSI sample for multiple persons."""
if timestamp is None:
timestamp = datetime.utcnow()
if num_persons < 2 or num_persons > 4:
raise ValueError("Number of persons must be between 2 and 4")
if activities is None:
activities = random.choices(list(self.single_person_patterns.keys()), k=num_persons)
if len(activities) != num_persons:
raise ValueError("Number of activities must match number of persons")
# Start with empty room baseline
amplitude = np.random.normal(
self.empty_room_pattern["amplitude_mean"],
self.empty_room_pattern["amplitude_std"],
(self.num_antennas, self.num_subcarriers)
)
phase = np.random.uniform(
-np.pi, np.pi,
(self.num_antennas, self.num_subcarriers)
)
# Add contribution from each person
for i, activity in enumerate(activities):
person_pattern = self.single_person_patterns[activity]
# Generate person-specific contribution
person_amplitude = np.random.normal(
person_pattern["amplitude_mean"] * 0.7, # Reduced for multi-person
person_pattern["amplitude_std"],
(self.num_antennas, self.num_subcarriers)
)
# Add spatial variation (different persons at different locations)
spatial_offset = i * self.num_subcarriers // num_persons
person_amplitude = np.roll(person_amplitude, spatial_offset, axis=1)
# Add movement modulation
movement_freq = person_pattern["movement_frequency"]
time_factor = timestamp.timestamp() + i * 0.5 # Phase offset between persons
movement_modulation = 0.05 * np.sin(2 * np.pi * movement_freq * time_factor)
person_amplitude += movement_modulation
amplitude += person_amplitude
# Add phase contribution
person_phase = np.random.normal(0, person_pattern["phase_variance"],
(self.num_antennas, self.num_subcarriers))
person_phase = np.roll(person_phase, spatial_offset, axis=1)
phase += person_phase
# Apply multi-person complexity
pattern = self.multi_person_patterns[num_persons]
amplitude *= pattern["amplitude_multiplier"]
phase *= pattern["phase_complexity"]
# Clip and normalize
amplitude = np.clip(amplitude, 0, 1)
phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi
sample = {
"timestamp": timestamp.isoformat(),
"router_id": "router_001",
"amplitude": amplitude.tolist(),
"phase": phase.tolist(),
"frequency": self.frequency,
"bandwidth": self.bandwidth,
"num_antennas": self.num_antennas,
"num_subcarriers": self.num_subcarriers,
"sample_rate": self.sample_rate,
"scenario": f"multi_person_{num_persons}",
"signal_quality": np.random.uniform(0.6, 0.8),
"num_persons": num_persons,
"activities": activities
}
return sample
def generate_time_series(self,
duration_seconds: int = 10,
scenario: str = "single_person_walking",
**kwargs) -> List[Dict[str, Any]]:
"""Generate time series of CSI samples."""
samples = []
start_time = datetime.utcnow()
for i in range(duration_seconds * self.sample_rate):
timestamp = start_time + timedelta(seconds=i / self.sample_rate)
if scenario == "empty_room":
sample = self.generate_empty_room_sample(timestamp)
elif scenario.startswith("single_person_"):
activity = scenario.replace("single_person_", "")
sample = self.generate_single_person_sample(activity, timestamp)
elif scenario.startswith("multi_person_"):
num_persons = int(scenario.split("_")[-1])
sample = self.generate_multi_person_sample(num_persons, timestamp=timestamp, **kwargs)
else:
raise ValueError(f"Unknown scenario: {scenario}")
samples.append(sample)
return samples
def add_noise(self, sample: Dict[str, Any], noise_level: Optional[float] = None) -> Dict[str, Any]:
"""Add noise to CSI sample."""
if noise_level is None:
noise_level = self.noise_level
noisy_sample = sample.copy()
# Add amplitude noise
amplitude = np.array(sample["amplitude"])
amplitude_noise = np.random.normal(0, noise_level, amplitude.shape)
noisy_amplitude = amplitude + amplitude_noise
noisy_amplitude = np.clip(noisy_amplitude, 0, 1)
noisy_sample["amplitude"] = noisy_amplitude.tolist()
# Add phase noise
phase = np.array(sample["phase"])
phase_noise = np.random.normal(0, noise_level * np.pi, phase.shape)
noisy_phase = phase + phase_noise
noisy_phase = np.mod(noisy_phase + np.pi, 2 * np.pi) - np.pi
noisy_sample["phase"] = noisy_phase.tolist()
# Reduce signal quality
noisy_sample["signal_quality"] *= (1 - noise_level)
return noisy_sample
def simulate_hardware_artifacts(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Simulate hardware-specific artifacts."""
artifact_sample = sample.copy()
amplitude = np.array(sample["amplitude"])
phase = np.array(sample["phase"])
# Simulate antenna coupling
coupling_matrix = np.random.uniform(0.95, 1.05, (self.num_antennas, self.num_antennas))
amplitude = coupling_matrix @ amplitude
# Simulate frequency-dependent gain variations
freq_response = 1 + 0.1 * np.sin(np.linspace(0, 2*np.pi, self.num_subcarriers))
amplitude *= freq_response[np.newaxis, :]
# Simulate phase drift
phase_drift = np.random.uniform(-0.1, 0.1) * np.arange(self.num_subcarriers)
phase += phase_drift[np.newaxis, :]
# Clip and wrap
amplitude = np.clip(amplitude, 0, 1)
phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi
artifact_sample["amplitude"] = amplitude.tolist()
artifact_sample["phase"] = phase.tolist()
return artifact_sample
# Convenience functions for common test scenarios
def generate_fall_detection_sequence() -> List[Dict[str, Any]]:
"""Generate CSI sequence showing fall detection scenario."""
generator = CSIDataGenerator()
sequence = []
# Normal standing (5 seconds)
sequence.extend(generator.generate_time_series(5, "single_person_standing"))
# Walking (3 seconds)
sequence.extend(generator.generate_time_series(3, "single_person_walking"))
# Fall event (1 second transition)
sequence.extend(generator.generate_time_series(1, "single_person_fallen"))
# Fallen state (3 seconds)
sequence.extend(generator.generate_time_series(3, "single_person_fallen"))
return sequence
def generate_multi_person_scenario() -> List[Dict[str, Any]]:
"""Generate CSI sequence for multi-person scenario."""
generator = CSIDataGenerator()
sequence = []
# Start with empty room
sequence.extend(generator.generate_time_series(2, "empty_room"))
# One person enters
sequence.extend(generator.generate_time_series(3, "single_person_walking"))
# Second person enters
sequence.extend(generator.generate_time_series(5, "multi_person_2",
activities=["standing", "walking"]))
# Third person enters
sequence.extend(generator.generate_time_series(4, "multi_person_3",
activities=["standing", "walking", "sitting"]))
return sequence
def generate_noisy_environment_data() -> List[Dict[str, Any]]:
"""Generate CSI data with various noise levels."""
generator = CSIDataGenerator()
# Generate clean data
clean_samples = generator.generate_time_series(5, "single_person_walking")
# Add different noise levels
noisy_samples = []
noise_levels = [0.05, 0.1, 0.2, 0.3]
for noise_level in noise_levels:
for sample in clean_samples[:10]: # Take first 10 samples
noisy_sample = generator.add_noise(sample, noise_level)
noisy_samples.append(noisy_sample)
return noisy_samples
def generate_hardware_test_data() -> List[Dict[str, Any]]:
"""Generate CSI data with hardware artifacts."""
generator = CSIDataGenerator()
# Generate base samples
base_samples = generator.generate_time_series(3, "single_person_standing")
# Add hardware artifacts
artifact_samples = []
for sample in base_samples:
artifact_sample = generator.simulate_hardware_artifacts(sample)
artifact_samples.append(artifact_sample)
return artifact_samples
# Test data validation utilities
def validate_csi_sample(sample: Dict[str, Any]) -> bool:
"""Validate CSI sample structure and data ranges."""
required_fields = [
"timestamp", "router_id", "amplitude", "phase",
"frequency", "bandwidth", "num_antennas", "num_subcarriers"
]
# Check required fields
for field in required_fields:
if field not in sample:
return False
# Validate data types and ranges
amplitude = np.array(sample["amplitude"])
phase = np.array(sample["phase"])
# Check shapes
expected_shape = (sample["num_antennas"], sample["num_subcarriers"])
if amplitude.shape != expected_shape or phase.shape != expected_shape:
return False
# Check value ranges
if not (0 <= amplitude.min() and amplitude.max() <= 1):
return False
if not (-np.pi <= phase.min() and phase.max() <= np.pi):
return False
return True
def extract_features_from_csi(sample: Dict[str, Any]) -> Dict[str, Any]:
"""Extract features from CSI sample for testing."""
amplitude = np.array(sample["amplitude"])
phase = np.array(sample["phase"])
features = {
"amplitude_mean": float(np.mean(amplitude)),
"amplitude_std": float(np.std(amplitude)),
"amplitude_max": float(np.max(amplitude)),
"amplitude_min": float(np.min(amplitude)),
"phase_variance": float(np.var(phase)),
"phase_range": float(np.max(phase) - np.min(phase)),
"signal_energy": float(np.sum(amplitude ** 2)),
"phase_coherence": float(np.abs(np.mean(np.exp(1j * phase)))),
"spatial_correlation": float(np.mean(np.corrcoef(amplitude))),
"frequency_diversity": float(np.std(np.mean(amplitude, axis=0)))
}
return features

View File

@@ -0,0 +1,338 @@
"""
Integration tests for WiFi-DensePose API endpoints.
Tests all REST API endpoints with real service dependencies.
"""
import pytest
import asyncio
from datetime import datetime, timedelta
from typing import Dict, Any
from unittest.mock import AsyncMock, MagicMock
from fastapi.testclient import TestClient
from fastapi import FastAPI
import httpx
from src.api.dependencies import (
get_pose_service,
get_stream_service,
get_hardware_service,
get_current_user
)
from src.api.routers.health import router as health_router
from src.api.routers.pose import router as pose_router
from src.api.routers.stream import router as stream_router
class TestAPIEndpoints:
"""Integration tests for API endpoints."""
@pytest.fixture
def app(self):
"""Create FastAPI app with test dependencies."""
app = FastAPI()
app.include_router(health_router, prefix="/health", tags=["health"])
app.include_router(pose_router, prefix="/pose", tags=["pose"])
app.include_router(stream_router, prefix="/stream", tags=["stream"])
return app
@pytest.fixture
def mock_pose_service(self):
"""Mock pose service."""
service = AsyncMock()
service.health_check.return_value = {
"status": "healthy",
"message": "Service operational",
"uptime_seconds": 3600.0,
"metrics": {"processed_frames": 1000}
}
service.is_ready.return_value = True
service.estimate_poses.return_value = {
"timestamp": datetime.utcnow(),
"frame_id": "test-frame-001",
"persons": [],
"zone_summary": {"zone1": 0},
"processing_time_ms": 50.0,
"metadata": {}
}
return service
@pytest.fixture
def mock_stream_service(self):
"""Mock stream service."""
service = AsyncMock()
service.health_check.return_value = {
"status": "healthy",
"message": "Stream service operational",
"uptime_seconds": 1800.0
}
service.is_ready.return_value = True
service.get_status.return_value = {
"is_active": True,
"active_streams": [],
"uptime_seconds": 1800.0
}
service.is_active.return_value = True
return service
@pytest.fixture
def mock_hardware_service(self):
"""Mock hardware service."""
service = AsyncMock()
service.health_check.return_value = {
"status": "healthy",
"message": "Hardware connected",
"uptime_seconds": 7200.0,
"metrics": {"connected_routers": 3}
}
service.is_ready.return_value = True
return service
@pytest.fixture
def mock_user(self):
"""Mock authenticated user."""
return {
"id": "test-user-001",
"username": "testuser",
"email": "test@example.com",
"is_admin": False,
"is_active": True,
"permissions": ["read", "write"]
}
@pytest.fixture
def client(self, app, mock_pose_service, mock_stream_service, mock_hardware_service, mock_user):
"""Create test client with mocked dependencies."""
app.dependency_overrides[get_pose_service] = lambda: mock_pose_service
app.dependency_overrides[get_stream_service] = lambda: mock_stream_service
app.dependency_overrides[get_hardware_service] = lambda: mock_hardware_service
app.dependency_overrides[get_current_user] = lambda: mock_user
with TestClient(app) as client:
yield client
def test_health_check_endpoint_should_fail_initially(self, client):
"""Test health check endpoint - should fail initially."""
# This test should fail because we haven't implemented the endpoint properly
response = client.get("/health/health")
# This assertion will fail initially, driving us to implement the endpoint
assert response.status_code == 200
assert "status" in response.json()
assert "components" in response.json()
assert "system_metrics" in response.json()
def test_readiness_check_endpoint_should_fail_initially(self, client):
"""Test readiness check endpoint - should fail initially."""
response = client.get("/health/ready")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "ready" in data
assert "checks" in data
assert isinstance(data["checks"], dict)
def test_liveness_check_endpoint_should_fail_initially(self, client):
"""Test liveness check endpoint - should fail initially."""
response = client.get("/health/live")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "status" in data
assert data["status"] == "alive"
def test_version_info_endpoint_should_fail_initially(self, client):
"""Test version info endpoint - should fail initially."""
response = client.get("/health/version")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "name" in data
assert "version" in data
assert "environment" in data
def test_pose_current_endpoint_should_fail_initially(self, client):
"""Test current pose estimation endpoint - should fail initially."""
response = client.get("/pose/current")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "timestamp" in data
assert "frame_id" in data
assert "persons" in data
assert "zone_summary" in data
def test_pose_analyze_endpoint_should_fail_initially(self, client):
"""Test pose analysis endpoint - should fail initially."""
request_data = {
"zone_ids": ["zone1", "zone2"],
"confidence_threshold": 0.7,
"max_persons": 10,
"include_keypoints": True,
"include_segmentation": False
}
response = client.post("/pose/analyze", json=request_data)
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "timestamp" in data
assert "persons" in data
def test_zone_occupancy_endpoint_should_fail_initially(self, client):
"""Test zone occupancy endpoint - should fail initially."""
response = client.get("/pose/zones/zone1/occupancy")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "zone_id" in data
assert "current_occupancy" in data
def test_zones_summary_endpoint_should_fail_initially(self, client):
"""Test zones summary endpoint - should fail initially."""
response = client.get("/pose/zones/summary")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "total_persons" in data
assert "zones" in data
def test_stream_status_endpoint_should_fail_initially(self, client):
"""Test stream status endpoint - should fail initially."""
response = client.get("/stream/status")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "is_active" in data
assert "connected_clients" in data
def test_stream_start_endpoint_should_fail_initially(self, client):
"""Test stream start endpoint - should fail initially."""
response = client.post("/stream/start")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "message" in data
def test_stream_stop_endpoint_should_fail_initially(self, client):
"""Test stream stop endpoint - should fail initially."""
response = client.post("/stream/stop")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert "message" in data
class TestAPIErrorHandling:
"""Test API error handling scenarios."""
@pytest.fixture
def app_with_failing_services(self):
"""Create app with failing service dependencies."""
app = FastAPI()
app.include_router(health_router, prefix="/health", tags=["health"])
app.include_router(pose_router, prefix="/pose", tags=["pose"])
# Mock failing services
failing_pose_service = AsyncMock()
failing_pose_service.health_check.side_effect = Exception("Service unavailable")
app.dependency_overrides[get_pose_service] = lambda: failing_pose_service
return app
def test_health_check_with_failing_service_should_fail_initially(self, app_with_failing_services):
"""Test health check with failing service - should fail initially."""
with TestClient(app_with_failing_services) as client:
response = client.get("/health/health")
# This will fail initially
assert response.status_code == 200
data = response.json()
assert data["status"] == "unhealthy"
assert "hardware" in data["components"]
assert data["components"]["pose"]["status"] == "unhealthy"
class TestAPIAuthentication:
"""Test API authentication scenarios."""
@pytest.fixture
def app_with_auth(self):
"""Create app with authentication enabled."""
app = FastAPI()
app.include_router(pose_router, prefix="/pose", tags=["pose"])
# Mock authenticated user dependency
def get_authenticated_user():
return {
"id": "auth-user-001",
"username": "authuser",
"is_admin": True,
"permissions": ["read", "write", "admin"]
}
app.dependency_overrides[get_current_user] = get_authenticated_user
return app
def test_authenticated_endpoint_access_should_fail_initially(self, app_with_auth):
"""Test authenticated endpoint access - should fail initially."""
with TestClient(app_with_auth) as client:
response = client.post("/pose/analyze", json={
"confidence_threshold": 0.8,
"include_keypoints": True
})
# This will fail initially
assert response.status_code == 200
class TestAPIValidation:
"""Test API request validation."""
@pytest.fixture
def validation_app(self):
"""Create app for validation testing."""
app = FastAPI()
app.include_router(pose_router, prefix="/pose", tags=["pose"])
# Mock service
mock_service = AsyncMock()
app.dependency_overrides[get_pose_service] = lambda: mock_service
return app
def test_invalid_confidence_threshold_should_fail_initially(self, validation_app):
"""Test invalid confidence threshold validation - should fail initially."""
with TestClient(validation_app) as client:
response = client.post("/pose/analyze", json={
"confidence_threshold": 1.5, # Invalid: > 1.0
"include_keypoints": True
})
# This will fail initially
assert response.status_code == 422
assert "validation error" in response.json()["detail"][0]["msg"].lower()
def test_invalid_max_persons_should_fail_initially(self, validation_app):
"""Test invalid max_persons validation - should fail initially."""
with TestClient(validation_app) as client:
response = client.post("/pose/analyze", json={
"max_persons": 0, # Invalid: < 1
"include_keypoints": True
})
# This will fail initially
assert response.status_code == 422

View File

@@ -0,0 +1,571 @@
"""
Integration tests for authentication and authorization.
Tests JWT authentication flow, user permissions, and access control.
"""
import pytest
import asyncio
from datetime import datetime, timedelta
from typing import Dict, Any, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import jwt
import json
from fastapi import HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials
class MockJWTToken:
"""Mock JWT token for testing."""
def __init__(self, payload: Dict[str, Any], secret: str = "test-secret"):
self.payload = payload
self.secret = secret
self.token = jwt.encode(payload, secret, algorithm="HS256")
def decode(self, token: str, secret: str) -> Dict[str, Any]:
"""Decode JWT token."""
return jwt.decode(token, secret, algorithms=["HS256"])
class TestJWTAuthentication:
"""Test JWT authentication functionality."""
@pytest.fixture
def valid_user_payload(self):
"""Valid user payload for JWT token."""
return {
"sub": "user-001",
"username": "testuser",
"email": "test@example.com",
"is_admin": False,
"is_active": True,
"permissions": ["read", "write"],
"exp": datetime.utcnow() + timedelta(hours=1),
"iat": datetime.utcnow()
}
@pytest.fixture
def admin_user_payload(self):
"""Admin user payload for JWT token."""
return {
"sub": "admin-001",
"username": "admin",
"email": "admin@example.com",
"is_admin": True,
"is_active": True,
"permissions": ["read", "write", "admin"],
"exp": datetime.utcnow() + timedelta(hours=1),
"iat": datetime.utcnow()
}
@pytest.fixture
def expired_user_payload(self):
"""Expired user payload for JWT token."""
return {
"sub": "user-002",
"username": "expireduser",
"email": "expired@example.com",
"is_admin": False,
"is_active": True,
"permissions": ["read"],
"exp": datetime.utcnow() - timedelta(hours=1), # Expired
"iat": datetime.utcnow() - timedelta(hours=2)
}
@pytest.fixture
def mock_jwt_service(self):
"""Mock JWT service."""
class MockJWTService:
def __init__(self):
self.secret = "test-secret-key"
self.algorithm = "HS256"
def create_token(self, user_data: Dict[str, Any]) -> str:
"""Create JWT token."""
payload = {
**user_data,
"exp": datetime.utcnow() + timedelta(hours=1),
"iat": datetime.utcnow()
}
return jwt.encode(payload, self.secret, algorithm=self.algorithm)
def verify_token(self, token: str) -> Dict[str, Any]:
"""Verify JWT token."""
try:
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has expired"
)
except jwt.InvalidTokenError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token"
)
def refresh_token(self, token: str) -> str:
"""Refresh JWT token."""
payload = self.verify_token(token)
# Remove exp and iat for new token
payload.pop("exp", None)
payload.pop("iat", None)
return self.create_token(payload)
return MockJWTService()
def test_jwt_token_creation_should_fail_initially(self, mock_jwt_service, valid_user_payload):
"""Test JWT token creation - should fail initially."""
token = mock_jwt_service.create_token(valid_user_payload)
# This will fail initially
assert isinstance(token, str)
assert len(token) > 0
# Verify token can be decoded
decoded = mock_jwt_service.verify_token(token)
assert decoded["sub"] == valid_user_payload["sub"]
assert decoded["username"] == valid_user_payload["username"]
def test_jwt_token_verification_should_fail_initially(self, mock_jwt_service, valid_user_payload):
"""Test JWT token verification - should fail initially."""
token = mock_jwt_service.create_token(valid_user_payload)
decoded = mock_jwt_service.verify_token(token)
# This will fail initially
assert decoded["sub"] == valid_user_payload["sub"]
assert decoded["is_admin"] == valid_user_payload["is_admin"]
assert "exp" in decoded
assert "iat" in decoded
def test_expired_token_rejection_should_fail_initially(self, mock_jwt_service, expired_user_payload):
"""Test expired token rejection - should fail initially."""
# Create token with expired payload
token = jwt.encode(expired_user_payload, mock_jwt_service.secret, algorithm=mock_jwt_service.algorithm)
# This should fail initially
with pytest.raises(HTTPException) as exc_info:
mock_jwt_service.verify_token(token)
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert "expired" in exc_info.value.detail.lower()
def test_invalid_token_rejection_should_fail_initially(self, mock_jwt_service):
"""Test invalid token rejection - should fail initially."""
invalid_token = "invalid.jwt.token"
# This should fail initially
with pytest.raises(HTTPException) as exc_info:
mock_jwt_service.verify_token(invalid_token)
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert "invalid" in exc_info.value.detail.lower()
def test_token_refresh_should_fail_initially(self, mock_jwt_service, valid_user_payload):
"""Test token refresh functionality - should fail initially."""
original_token = mock_jwt_service.create_token(valid_user_payload)
# Wait a moment to ensure different timestamps
import time
time.sleep(0.1)
refreshed_token = mock_jwt_service.refresh_token(original_token)
# This will fail initially
assert refreshed_token != original_token
# Verify both tokens are valid but have different timestamps
original_payload = mock_jwt_service.verify_token(original_token)
refreshed_payload = mock_jwt_service.verify_token(refreshed_token)
assert original_payload["sub"] == refreshed_payload["sub"]
assert original_payload["iat"] != refreshed_payload["iat"]
class TestUserAuthentication:
"""Test user authentication scenarios."""
@pytest.fixture
def mock_user_service(self):
"""Mock user service."""
class MockUserService:
def __init__(self):
self.users = {
"testuser": {
"id": "user-001",
"username": "testuser",
"email": "test@example.com",
"password_hash": "hashed_password",
"is_admin": False,
"is_active": True,
"permissions": ["read", "write"],
"zones": ["zone1", "zone2"],
"created_at": datetime.utcnow()
},
"admin": {
"id": "admin-001",
"username": "admin",
"email": "admin@example.com",
"password_hash": "admin_hashed_password",
"is_admin": True,
"is_active": True,
"permissions": ["read", "write", "admin"],
"zones": [], # Admin has access to all zones
"created_at": datetime.utcnow()
}
}
async def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
"""Authenticate user with username and password."""
user = self.users.get(username)
if not user:
return None
# Mock password verification
if password == "correct_password":
return user
return None
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user by ID."""
for user in self.users.values():
if user["id"] == user_id:
return user
return None
async def update_user_activity(self, user_id: str):
"""Update user last activity."""
user = await self.get_user_by_id(user_id)
if user:
user["last_activity"] = datetime.utcnow()
return MockUserService()
@pytest.mark.asyncio
async def test_user_authentication_success_should_fail_initially(self, mock_user_service):
"""Test successful user authentication - should fail initially."""
user = await mock_user_service.authenticate_user("testuser", "correct_password")
# This will fail initially
assert user is not None
assert user["username"] == "testuser"
assert user["is_active"] is True
assert "read" in user["permissions"]
@pytest.mark.asyncio
async def test_user_authentication_failure_should_fail_initially(self, mock_user_service):
"""Test failed user authentication - should fail initially."""
user = await mock_user_service.authenticate_user("testuser", "wrong_password")
# This will fail initially
assert user is None
# Test with non-existent user
user = await mock_user_service.authenticate_user("nonexistent", "any_password")
assert user is None
@pytest.mark.asyncio
async def test_admin_user_authentication_should_fail_initially(self, mock_user_service):
"""Test admin user authentication - should fail initially."""
admin = await mock_user_service.authenticate_user("admin", "correct_password")
# This will fail initially
assert admin is not None
assert admin["is_admin"] is True
assert "admin" in admin["permissions"]
assert admin["zones"] == [] # Admin has access to all zones
class TestAuthorizationDependencies:
"""Test authorization dependency functions."""
@pytest.fixture
def mock_request(self):
"""Mock FastAPI request."""
class MockRequest:
def __init__(self):
self.state = MagicMock()
self.state.user = None
return MockRequest()
@pytest.fixture
def mock_credentials(self):
"""Mock HTTP authorization credentials."""
def create_credentials(token: str):
return HTTPAuthorizationCredentials(
scheme="Bearer",
credentials=token
)
return create_credentials
@pytest.mark.asyncio
async def test_get_current_user_with_valid_token_should_fail_initially(self, mock_request, mock_credentials):
"""Test get_current_user with valid token - should fail initially."""
# Mock the get_current_user dependency
async def mock_get_current_user(request, credentials):
if not credentials:
return None
# Mock token validation
if credentials.credentials == "valid_token":
return {
"id": "user-001",
"username": "testuser",
"is_admin": False,
"is_active": True,
"permissions": ["read", "write"]
}
return None
credentials = mock_credentials("valid_token")
user = await mock_get_current_user(mock_request, credentials)
# This will fail initially
assert user is not None
assert user["username"] == "testuser"
assert user["is_active"] is True
@pytest.mark.asyncio
async def test_get_current_user_without_credentials_should_fail_initially(self, mock_request):
"""Test get_current_user without credentials - should fail initially."""
async def mock_get_current_user(request, credentials):
if not credentials:
return None
return {"id": "user-001"}
user = await mock_get_current_user(mock_request, None)
# This will fail initially
assert user is None
@pytest.mark.asyncio
async def test_require_active_user_should_fail_initially(self):
"""Test require active user dependency - should fail initially."""
async def mock_get_current_active_user(current_user):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
if not current_user.get("is_active", True):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
# Test with active user
active_user = {"id": "user-001", "is_active": True}
result = await mock_get_current_active_user(active_user)
# This will fail initially
assert result == active_user
# Test with inactive user
inactive_user = {"id": "user-002", "is_active": False}
with pytest.raises(HTTPException) as exc_info:
await mock_get_current_active_user(inactive_user)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
# Test with no user
with pytest.raises(HTTPException) as exc_info:
await mock_get_current_active_user(None)
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio
async def test_require_admin_user_should_fail_initially(self):
"""Test require admin user dependency - should fail initially."""
async def mock_get_admin_user(current_user):
if not current_user.get("is_admin", False):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required"
)
return current_user
# Test with admin user
admin_user = {"id": "admin-001", "is_admin": True}
result = await mock_get_admin_user(admin_user)
# This will fail initially
assert result == admin_user
# Test with regular user
regular_user = {"id": "user-001", "is_admin": False}
with pytest.raises(HTTPException) as exc_info:
await mock_get_admin_user(regular_user)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_permission_checking_should_fail_initially(self):
"""Test permission checking functionality - should fail initially."""
def require_permission(permission: str):
async def check_permission(current_user):
user_permissions = current_user.get("permissions", [])
# Admin users have all permissions
if current_user.get("is_admin", False):
return current_user
# Check specific permission
if permission not in user_permissions:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{permission}' required"
)
return current_user
return check_permission
# Test with user having required permission
user_with_permission = {
"id": "user-001",
"permissions": ["read", "write"],
"is_admin": False
}
check_read = require_permission("read")
result = await check_read(user_with_permission)
# This will fail initially
assert result == user_with_permission
# Test with user missing permission
check_admin = require_permission("admin")
with pytest.raises(HTTPException) as exc_info:
await check_admin(user_with_permission)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert "admin" in exc_info.value.detail
# Test with admin user (should have all permissions)
admin_user = {"id": "admin-001", "is_admin": True, "permissions": ["read"]}
result = await check_admin(admin_user)
assert result == admin_user
class TestZoneAndRouterAccess:
"""Test zone and router access control."""
@pytest.fixture
def mock_domain_config(self):
"""Mock domain configuration."""
class MockDomainConfig:
def __init__(self):
self.zones = {
"zone1": {"id": "zone1", "name": "Zone 1", "enabled": True},
"zone2": {"id": "zone2", "name": "Zone 2", "enabled": True},
"zone3": {"id": "zone3", "name": "Zone 3", "enabled": False}
}
self.routers = {
"router1": {"id": "router1", "name": "Router 1", "enabled": True},
"router2": {"id": "router2", "name": "Router 2", "enabled": False}
}
def get_zone(self, zone_id: str):
return self.zones.get(zone_id)
def get_router(self, router_id: str):
return self.routers.get(router_id)
return MockDomainConfig()
@pytest.mark.asyncio
async def test_zone_access_validation_should_fail_initially(self, mock_domain_config):
"""Test zone access validation - should fail initially."""
async def validate_zone_access(zone_id: str, current_user=None):
zone = mock_domain_config.get_zone(zone_id)
if not zone:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Zone '{zone_id}' not found"
)
if not zone["enabled"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Zone '{zone_id}' is disabled"
)
if current_user:
if current_user.get("is_admin", False):
return zone_id
user_zones = current_user.get("zones", [])
if user_zones and zone_id not in user_zones:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to zone '{zone_id}'"
)
return zone_id
# Test valid zone access
result = await validate_zone_access("zone1")
# This will fail initially
assert result == "zone1"
# Test invalid zone
with pytest.raises(HTTPException) as exc_info:
await validate_zone_access("nonexistent")
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
# Test disabled zone
with pytest.raises(HTTPException) as exc_info:
await validate_zone_access("zone3")
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
# Test user with zone access
user_with_access = {"id": "user-001", "zones": ["zone1", "zone2"]}
result = await validate_zone_access("zone1", user_with_access)
assert result == "zone1"
# Test user without zone access
with pytest.raises(HTTPException) as exc_info:
await validate_zone_access("zone2", {"id": "user-002", "zones": ["zone1"]})
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_router_access_validation_should_fail_initially(self, mock_domain_config):
"""Test router access validation - should fail initially."""
async def validate_router_access(router_id: str, current_user=None):
router = mock_domain_config.get_router(router_id)
if not router:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Router '{router_id}' not found"
)
if not router["enabled"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Router '{router_id}' is disabled"
)
return router_id
# Test valid router access
result = await validate_router_access("router1")
# This will fail initially
assert result == "router1"
# Test disabled router
with pytest.raises(HTTPException) as exc_info:
await validate_router_access("router2")
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

View File

@@ -0,0 +1,447 @@
"""
Full system integration tests for WiFi-DensePose API
Tests the complete integration of all components working together.
"""
import asyncio
import pytest
import httpx
import json
import time
from pathlib import Path
from typing import Dict, Any
from unittest.mock import AsyncMock, MagicMock, patch
from src.config.settings import get_settings
from src.app import app
from src.database.connection import get_database_manager
from src.services.orchestrator import get_service_orchestrator
from src.tasks.cleanup import get_cleanup_manager
from src.tasks.monitoring import get_monitoring_manager
from src.tasks.backup import get_backup_manager
class TestFullSystemIntegration:
"""Test complete system integration."""
@pytest.fixture
async def settings(self):
"""Get test settings."""
settings = get_settings()
settings.environment = "test"
settings.debug = True
settings.database_url = "sqlite+aiosqlite:///test_integration.db"
settings.redis_enabled = False
return settings
@pytest.fixture
async def db_manager(self, settings):
"""Get database manager for testing."""
manager = get_database_manager(settings)
await manager.initialize()
yield manager
await manager.close_all_connections()
@pytest.fixture
async def client(self, settings):
"""Get test HTTP client."""
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
yield client
@pytest.fixture
async def orchestrator(self, settings, db_manager):
"""Get service orchestrator for testing."""
orchestrator = get_service_orchestrator(settings)
await orchestrator.initialize()
yield orchestrator
await orchestrator.shutdown()
async def test_application_startup_and_shutdown(self, settings, db_manager):
"""Test complete application startup and shutdown sequence."""
# Test database initialization
await db_manager.test_connection()
stats = await db_manager.get_connection_stats()
assert stats["database"]["connected"] is True
# Test service orchestrator initialization
orchestrator = get_service_orchestrator(settings)
await orchestrator.initialize()
# Verify services are running
health_status = await orchestrator.get_health_status()
assert health_status["status"] in ["healthy", "warning"]
# Test graceful shutdown
await orchestrator.shutdown()
# Verify cleanup
final_stats = await db_manager.get_connection_stats()
assert final_stats is not None
async def test_api_endpoints_integration(self, client, settings, db_manager):
"""Test API endpoints work with database integration."""
# Test health endpoint
response = await client.get("/health")
assert response.status_code == 200
health_data = response.json()
assert "status" in health_data
assert "timestamp" in health_data
# Test metrics endpoint
response = await client.get("/metrics")
assert response.status_code == 200
# Test devices endpoint
response = await client.get("/api/v1/devices")
assert response.status_code == 200
devices_data = response.json()
assert "devices" in devices_data
assert isinstance(devices_data["devices"], list)
# Test sessions endpoint
response = await client.get("/api/v1/sessions")
assert response.status_code == 200
sessions_data = response.json()
assert "sessions" in sessions_data
assert isinstance(sessions_data["sessions"], list)
@patch('src.core.router_interface.RouterInterface')
@patch('src.core.csi_processor.CSIProcessor')
@patch('src.core.pose_estimator.PoseEstimator')
async def test_data_processing_pipeline(
self,
mock_pose_estimator,
mock_csi_processor,
mock_router_interface,
client,
settings,
db_manager
):
"""Test complete data processing pipeline integration."""
# Setup mocks
mock_router = MagicMock()
mock_router_interface.return_value = mock_router
mock_router.connect.return_value = True
mock_router.start_capture.return_value = True
mock_router.get_csi_data.return_value = {
"timestamp": time.time(),
"csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
"rssi": -45,
"noise_floor": -90
}
mock_processor = MagicMock()
mock_csi_processor.return_value = mock_processor
mock_processor.process_csi_data.return_value = {
"processed_csi": [[1.1, 2.1], [3.1, 4.1]],
"quality_score": 0.85,
"phase_sanitized": True
}
mock_estimator = MagicMock()
mock_pose_estimator.return_value = mock_estimator
mock_estimator.estimate_pose.return_value = {
"pose_data": {
"keypoints": [[100, 200], [150, 250]],
"confidence": 0.9
},
"processing_time": 0.05
}
# Test device registration
device_data = {
"name": "test_router",
"ip_address": "192.168.1.1",
"device_type": "router",
"model": "test_model"
}
response = await client.post("/api/v1/devices", json=device_data)
assert response.status_code == 201
device_response = response.json()
device_id = device_response["device"]["id"]
# Test session creation
session_data = {
"device_id": device_id,
"session_type": "pose_detection",
"configuration": {
"sampling_rate": 1000,
"duration": 60
}
}
response = await client.post("/api/v1/sessions", json=session_data)
assert response.status_code == 201
session_response = response.json()
session_id = session_response["session"]["id"]
# Test CSI data submission
csi_data = {
"session_id": session_id,
"timestamp": time.time(),
"csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
"rssi": -45,
"noise_floor": -90
}
response = await client.post("/api/v1/csi-data", json=csi_data)
assert response.status_code == 201
# Test pose detection retrieval
response = await client.get(f"/api/v1/sessions/{session_id}/pose-detections")
assert response.status_code == 200
# Test session completion
response = await client.patch(
f"/api/v1/sessions/{session_id}",
json={"status": "completed"}
)
assert response.status_code == 200
async def test_background_tasks_integration(self, settings, db_manager):
"""Test background tasks integration."""
# Test cleanup manager
cleanup_manager = get_cleanup_manager(settings)
cleanup_stats = cleanup_manager.get_stats()
assert "manager" in cleanup_stats
# Run cleanup task
cleanup_result = await cleanup_manager.run_all_tasks()
assert cleanup_result["success"] is True
# Test monitoring manager
monitoring_manager = get_monitoring_manager(settings)
monitoring_stats = monitoring_manager.get_stats()
assert "manager" in monitoring_stats
# Run monitoring task
monitoring_result = await monitoring_manager.run_all_tasks()
assert monitoring_result["success"] is True
# Test backup manager
backup_manager = get_backup_manager(settings)
backup_stats = backup_manager.get_stats()
assert "manager" in backup_stats
# Run backup task
backup_result = await backup_manager.run_all_tasks()
assert backup_result["success"] is True
async def test_error_handling_integration(self, client, settings, db_manager):
"""Test error handling across the system."""
# Test invalid device creation
invalid_device_data = {
"name": "", # Invalid empty name
"ip_address": "invalid_ip",
"device_type": "unknown_type"
}
response = await client.post("/api/v1/devices", json=invalid_device_data)
assert response.status_code == 422
error_response = response.json()
assert "detail" in error_response
# Test non-existent resource access
response = await client.get("/api/v1/devices/99999")
assert response.status_code == 404
# Test invalid session creation
invalid_session_data = {
"device_id": "invalid_uuid",
"session_type": "invalid_type"
}
response = await client.post("/api/v1/sessions", json=invalid_session_data)
assert response.status_code == 422
async def test_authentication_and_authorization(self, client, settings):
"""Test authentication and authorization integration."""
# Test protected endpoint without authentication
response = await client.get("/api/v1/admin/system-info")
assert response.status_code in [401, 403]
# Test with invalid token
headers = {"Authorization": "Bearer invalid_token"}
response = await client.get("/api/v1/admin/system-info", headers=headers)
assert response.status_code in [401, 403]
async def test_rate_limiting_integration(self, client, settings):
"""Test rate limiting integration."""
# Make multiple rapid requests to test rate limiting
responses = []
for i in range(10):
response = await client.get("/health")
responses.append(response.status_code)
# Should have at least some successful responses
assert 200 in responses
# Rate limiting might kick in for some requests
# This depends on the rate limiting configuration
async def test_monitoring_and_metrics_integration(self, client, settings, db_manager):
"""Test monitoring and metrics collection integration."""
# Test metrics endpoint
response = await client.get("/metrics")
assert response.status_code == 200
metrics_text = response.text
# Check for Prometheus format metrics
assert "# HELP" in metrics_text
assert "# TYPE" in metrics_text
# Test health check with detailed information
response = await client.get("/health?detailed=true")
assert response.status_code == 200
health_data = response.json()
assert "database" in health_data
assert "services" in health_data
assert "system" in health_data
async def test_configuration_management_integration(self, settings):
"""Test configuration management integration."""
# Test settings validation
assert settings.environment == "test"
assert settings.debug is True
# Test database URL configuration
assert "test_integration.db" in settings.database_url
# Test Redis configuration
assert settings.redis_enabled is False
# Test logging configuration
assert settings.log_level in ["DEBUG", "INFO", "WARNING", "ERROR"]
async def test_database_migration_integration(self, settings, db_manager):
"""Test database migration integration."""
# Test database connection
await db_manager.test_connection()
# Test table creation
async with db_manager.get_async_session() as session:
from sqlalchemy import text
# Check if tables exist
tables_query = text("""
SELECT name FROM sqlite_master
WHERE type='table' AND name NOT LIKE 'sqlite_%'
""")
result = await session.execute(tables_query)
tables = [row[0] for row in result.fetchall()]
# Should have our main tables
expected_tables = ["devices", "sessions", "csi_data", "pose_detections"]
for table in expected_tables:
assert table in tables
async def test_concurrent_operations_integration(self, client, settings, db_manager):
"""Test concurrent operations integration."""
async def create_device(name: str):
device_data = {
"name": f"test_device_{name}",
"ip_address": f"192.168.1.{name}",
"device_type": "router",
"model": "test_model"
}
response = await client.post("/api/v1/devices", json=device_data)
return response.status_code
# Create multiple devices concurrently
tasks = [create_device(str(i)) for i in range(5)]
results = await asyncio.gather(*tasks)
# All should succeed
assert all(status == 201 for status in results)
# Verify all devices were created
response = await client.get("/api/v1/devices")
assert response.status_code == 200
devices_data = response.json()
assert len(devices_data["devices"]) >= 5
async def test_system_resource_management(self, settings, db_manager, orchestrator):
"""Test system resource management integration."""
# Test connection pool management
stats = await db_manager.get_connection_stats()
assert "database" in stats
assert "pool_size" in stats["database"]
# Test service resource usage
health_status = await orchestrator.get_health_status()
assert "memory_usage" in health_status
assert "cpu_usage" in health_status
# Test cleanup of resources
await orchestrator.cleanup_resources()
# Verify resources are cleaned up
final_stats = await db_manager.get_connection_stats()
assert final_stats is not None
@pytest.mark.integration
class TestSystemPerformance:
"""Test system performance under load."""
async def test_api_response_times(self, client):
"""Test API response times under normal load."""
start_time = time.time()
response = await client.get("/health")
end_time = time.time()
assert response.status_code == 200
assert (end_time - start_time) < 1.0 # Should respond within 1 second
async def test_database_query_performance(self, db_manager):
"""Test database query performance."""
async with db_manager.get_async_session() as session:
from sqlalchemy import text
start_time = time.time()
result = await session.execute(text("SELECT 1"))
end_time = time.time()
assert result.scalar() == 1
assert (end_time - start_time) < 0.1 # Should complete within 100ms
async def test_memory_usage_stability(self, orchestrator):
"""Test memory usage remains stable."""
import psutil
import os
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
# Perform some operations
for _ in range(10):
health_status = await orchestrator.get_health_status()
assert health_status is not None
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
# Memory increase should be reasonable (less than 50MB)
assert memory_increase < 50 * 1024 * 1024
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,663 @@
"""
Integration tests for hardware integration and router communication.
Tests WiFi router communication, CSI data collection, and hardware management.
"""
import pytest
import asyncio
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import json
import socket
class MockRouterInterface:
"""Mock WiFi router interface for testing."""
def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
self.router_id = router_id
self.ip_address = ip_address
self.is_connected = False
self.is_authenticated = False
self.csi_streaming = False
self.connection_attempts = 0
self.last_heartbeat = None
self.firmware_version = "1.2.3"
self.capabilities = ["csi", "beamforming", "mimo"]
async def connect(self) -> bool:
"""Connect to the router."""
self.connection_attempts += 1
# Simulate connection failure for testing
if self.connection_attempts == 1:
return False
await asyncio.sleep(0.1) # Simulate connection time
self.is_connected = True
return True
async def authenticate(self, username: str, password: str) -> bool:
"""Authenticate with the router."""
if not self.is_connected:
return False
# Simulate authentication
if username == "admin" and password == "correct_password":
self.is_authenticated = True
return True
return False
async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
"""Start CSI data streaming."""
if not self.is_authenticated:
return False
# This should fail initially to test proper error handling
return False
async def stop_csi_streaming(self) -> bool:
"""Stop CSI data streaming."""
if self.csi_streaming:
self.csi_streaming = False
return True
return False
async def get_status(self) -> Dict[str, Any]:
"""Get router status."""
return {
"router_id": self.router_id,
"ip_address": self.ip_address,
"is_connected": self.is_connected,
"is_authenticated": self.is_authenticated,
"csi_streaming": self.csi_streaming,
"firmware_version": self.firmware_version,
"uptime_seconds": 3600,
"signal_strength": -45.2,
"temperature": 42.5,
"cpu_usage": 15.3
}
async def send_heartbeat(self) -> bool:
"""Send heartbeat to router."""
if not self.is_connected:
return False
self.last_heartbeat = datetime.utcnow()
return True
class TestRouterConnection:
"""Test router connection functionality."""
@pytest.fixture
def router_interface(self):
"""Create router interface for testing."""
return MockRouterInterface("router_001", "192.168.1.100")
@pytest.mark.asyncio
async def test_router_connection_should_fail_initially(self, router_interface):
"""Test router connection - should fail initially."""
# First connection attempt should fail
result = await router_interface.connect()
# This will fail initially because we designed the mock to fail first attempt
assert result is False
assert router_interface.is_connected is False
assert router_interface.connection_attempts == 1
# Second attempt should succeed
result = await router_interface.connect()
assert result is True
assert router_interface.is_connected is True
@pytest.mark.asyncio
async def test_router_authentication_should_fail_initially(self, router_interface):
"""Test router authentication - should fail initially."""
# Connect first
await router_interface.connect()
await router_interface.connect() # Second attempt succeeds
# Test wrong credentials
result = await router_interface.authenticate("admin", "wrong_password")
# This will fail initially
assert result is False
assert router_interface.is_authenticated is False
# Test correct credentials
result = await router_interface.authenticate("admin", "correct_password")
assert result is True
assert router_interface.is_authenticated is True
@pytest.mark.asyncio
async def test_csi_streaming_start_should_fail_initially(self, router_interface):
"""Test CSI streaming start - should fail initially."""
# Setup connection and authentication
await router_interface.connect()
await router_interface.connect() # Second attempt succeeds
await router_interface.authenticate("admin", "correct_password")
# Try to start CSI streaming
config = {
"frequency": 5.8e9,
"bandwidth": 80e6,
"sample_rate": 1000,
"antenna_config": "4x4_mimo"
}
result = await router_interface.start_csi_streaming(config)
# This will fail initially because the mock is designed to return False
assert result is False
assert router_interface.csi_streaming is False
@pytest.mark.asyncio
async def test_router_status_retrieval_should_fail_initially(self, router_interface):
"""Test router status retrieval - should fail initially."""
status = await router_interface.get_status()
# This will fail initially
assert isinstance(status, dict)
assert status["router_id"] == "router_001"
assert status["ip_address"] == "192.168.1.100"
assert "firmware_version" in status
assert "uptime_seconds" in status
assert "signal_strength" in status
assert "temperature" in status
assert "cpu_usage" in status
@pytest.mark.asyncio
async def test_heartbeat_mechanism_should_fail_initially(self, router_interface):
"""Test heartbeat mechanism - should fail initially."""
# Heartbeat without connection should fail
result = await router_interface.send_heartbeat()
# This will fail initially
assert result is False
assert router_interface.last_heartbeat is None
# Connect and try heartbeat
await router_interface.connect()
await router_interface.connect() # Second attempt succeeds
result = await router_interface.send_heartbeat()
assert result is True
assert router_interface.last_heartbeat is not None
class TestMultiRouterManagement:
"""Test management of multiple routers."""
@pytest.fixture
def router_manager(self):
"""Create router manager for testing."""
class RouterManager:
def __init__(self):
self.routers = {}
self.active_connections = 0
async def add_router(self, router_id: str, ip_address: str) -> bool:
"""Add a router to management."""
if router_id in self.routers:
return False
router = MockRouterInterface(router_id, ip_address)
self.routers[router_id] = router
return True
async def connect_router(self, router_id: str) -> bool:
"""Connect to a specific router."""
if router_id not in self.routers:
return False
router = self.routers[router_id]
# Try connecting twice (mock fails first time)
success = await router.connect()
if not success:
success = await router.connect()
if success:
self.active_connections += 1
return success
async def authenticate_router(self, router_id: str, username: str, password: str) -> bool:
"""Authenticate with a router."""
if router_id not in self.routers:
return False
router = self.routers[router_id]
return await router.authenticate(username, password)
async def get_all_status(self) -> Dict[str, Dict[str, Any]]:
"""Get status of all routers."""
status = {}
for router_id, router in self.routers.items():
status[router_id] = await router.get_status()
return status
async def start_all_csi_streaming(self, config: Dict[str, Any]) -> Dict[str, bool]:
"""Start CSI streaming on all authenticated routers."""
results = {}
for router_id, router in self.routers.items():
if router.is_authenticated:
results[router_id] = await router.start_csi_streaming(config)
else:
results[router_id] = False
return results
return RouterManager()
@pytest.mark.asyncio
async def test_multiple_router_addition_should_fail_initially(self, router_manager):
"""Test adding multiple routers - should fail initially."""
# Add first router
result1 = await router_manager.add_router("router_001", "192.168.1.100")
# This will fail initially
assert result1 is True
assert "router_001" in router_manager.routers
# Add second router
result2 = await router_manager.add_router("router_002", "192.168.1.101")
assert result2 is True
assert "router_002" in router_manager.routers
# Try to add duplicate router
result3 = await router_manager.add_router("router_001", "192.168.1.102")
assert result3 is False
assert len(router_manager.routers) == 2
@pytest.mark.asyncio
async def test_concurrent_router_connections_should_fail_initially(self, router_manager):
"""Test concurrent router connections - should fail initially."""
# Add multiple routers
await router_manager.add_router("router_001", "192.168.1.100")
await router_manager.add_router("router_002", "192.168.1.101")
await router_manager.add_router("router_003", "192.168.1.102")
# Connect to all routers concurrently
connection_tasks = [
router_manager.connect_router("router_001"),
router_manager.connect_router("router_002"),
router_manager.connect_router("router_003")
]
results = await asyncio.gather(*connection_tasks)
# This will fail initially
assert len(results) == 3
assert all(results) # All connections should succeed
assert router_manager.active_connections == 3
@pytest.mark.asyncio
async def test_router_status_aggregation_should_fail_initially(self, router_manager):
"""Test router status aggregation - should fail initially."""
# Add and connect routers
await router_manager.add_router("router_001", "192.168.1.100")
await router_manager.add_router("router_002", "192.168.1.101")
await router_manager.connect_router("router_001")
await router_manager.connect_router("router_002")
# Get all status
all_status = await router_manager.get_all_status()
# This will fail initially
assert isinstance(all_status, dict)
assert len(all_status) == 2
assert "router_001" in all_status
assert "router_002" in all_status
# Verify status structure
for router_id, status in all_status.items():
assert "router_id" in status
assert "ip_address" in status
assert "is_connected" in status
assert status["is_connected"] is True
class TestCSIDataCollection:
"""Test CSI data collection from routers."""
@pytest.fixture
def csi_collector(self):
"""Create CSI data collector."""
class CSICollector:
def __init__(self):
self.collected_data = []
self.is_collecting = False
self.collection_rate = 0
async def start_collection(self, router_interfaces: List[MockRouterInterface]) -> bool:
"""Start CSI data collection."""
# This should fail initially
return False
async def stop_collection(self) -> bool:
"""Stop CSI data collection."""
if self.is_collecting:
self.is_collecting = False
return True
return False
async def collect_frame(self, router_interface: MockRouterInterface) -> Optional[Dict[str, Any]]:
"""Collect a single CSI frame."""
if not router_interface.csi_streaming:
return None
# Simulate CSI data
return {
"timestamp": datetime.utcnow().isoformat(),
"router_id": router_interface.router_id,
"amplitude": np.random.rand(64, 32).tolist(),
"phase": np.random.rand(64, 32).tolist(),
"frequency": 5.8e9,
"bandwidth": 80e6,
"antenna_count": 4,
"subcarrier_count": 64,
"signal_quality": np.random.uniform(0.7, 0.95)
}
def get_collection_stats(self) -> Dict[str, Any]:
"""Get collection statistics."""
return {
"total_frames": len(self.collected_data),
"collection_rate": self.collection_rate,
"is_collecting": self.is_collecting,
"last_collection": self.collected_data[-1]["timestamp"] if self.collected_data else None
}
return CSICollector()
@pytest.mark.asyncio
async def test_csi_collection_start_should_fail_initially(self, csi_collector):
"""Test CSI collection start - should fail initially."""
router_interfaces = [
MockRouterInterface("router_001", "192.168.1.100"),
MockRouterInterface("router_002", "192.168.1.101")
]
result = await csi_collector.start_collection(router_interfaces)
# This will fail initially because the collector is designed to return False
assert result is False
assert csi_collector.is_collecting is False
@pytest.mark.asyncio
async def test_single_frame_collection_should_fail_initially(self, csi_collector):
"""Test single frame collection - should fail initially."""
router = MockRouterInterface("router_001", "192.168.1.100")
# Without CSI streaming enabled
frame = await csi_collector.collect_frame(router)
# This will fail initially
assert frame is None
# Enable CSI streaming (manually for testing)
router.csi_streaming = True
frame = await csi_collector.collect_frame(router)
assert frame is not None
assert "timestamp" in frame
assert "router_id" in frame
assert "amplitude" in frame
assert "phase" in frame
assert frame["router_id"] == "router_001"
@pytest.mark.asyncio
async def test_collection_statistics_should_fail_initially(self, csi_collector):
"""Test collection statistics - should fail initially."""
stats = csi_collector.get_collection_stats()
# This will fail initially
assert isinstance(stats, dict)
assert "total_frames" in stats
assert "collection_rate" in stats
assert "is_collecting" in stats
assert "last_collection" in stats
assert stats["total_frames"] == 0
assert stats["is_collecting"] is False
assert stats["last_collection"] is None
class TestHardwareErrorHandling:
"""Test hardware error handling scenarios."""
@pytest.fixture
def unreliable_router(self):
"""Create unreliable router for error testing."""
class UnreliableRouter(MockRouterInterface):
def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
super().__init__(router_id, ip_address)
self.failure_rate = 0.3 # 30% failure rate
self.connection_drops = 0
async def connect(self) -> bool:
"""Unreliable connection."""
if np.random.random() < self.failure_rate:
return False
return await super().connect()
async def send_heartbeat(self) -> bool:
"""Unreliable heartbeat."""
if np.random.random() < self.failure_rate:
self.is_connected = False
self.connection_drops += 1
return False
return await super().send_heartbeat()
async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
"""Unreliable CSI streaming."""
if np.random.random() < self.failure_rate:
return False
# Still return False for initial test failure
return False
return UnreliableRouter("unreliable_router", "192.168.1.200")
@pytest.mark.asyncio
async def test_connection_retry_mechanism_should_fail_initially(self, unreliable_router):
"""Test connection retry mechanism - should fail initially."""
max_retries = 5
success = False
for attempt in range(max_retries):
result = await unreliable_router.connect()
if result:
success = True
break
# Wait before retry
await asyncio.sleep(0.1)
# This will fail initially due to randomness, but should eventually pass
# The test demonstrates the need for retry logic
assert success or unreliable_router.connection_attempts >= max_retries
@pytest.mark.asyncio
async def test_connection_drop_detection_should_fail_initially(self, unreliable_router):
"""Test connection drop detection - should fail initially."""
# Establish connection
await unreliable_router.connect()
await unreliable_router.connect() # Ensure connection
initial_drops = unreliable_router.connection_drops
# Send multiple heartbeats to trigger potential drops
for _ in range(10):
await unreliable_router.send_heartbeat()
await asyncio.sleep(0.01)
# This will fail initially
# Should detect connection drops
final_drops = unreliable_router.connection_drops
assert final_drops >= initial_drops # May have detected drops
@pytest.mark.asyncio
async def test_hardware_timeout_handling_should_fail_initially(self):
"""Test hardware timeout handling - should fail initially."""
async def slow_operation():
"""Simulate slow hardware operation."""
await asyncio.sleep(2.0) # 2 second delay
return "success"
# Test with timeout
try:
result = await asyncio.wait_for(slow_operation(), timeout=1.0)
# This should not be reached
assert False, "Operation should have timed out"
except asyncio.TimeoutError:
# This will fail initially because we expect timeout handling
assert True # Timeout was properly handled
@pytest.mark.asyncio
async def test_network_error_simulation_should_fail_initially(self):
"""Test network error simulation - should fail initially."""
class NetworkErrorRouter(MockRouterInterface):
async def connect(self) -> bool:
"""Simulate network error."""
raise ConnectionError("Network unreachable")
router = NetworkErrorRouter("error_router", "192.168.1.999")
# This will fail initially
with pytest.raises(ConnectionError, match="Network unreachable"):
await router.connect()
class TestHardwareConfiguration:
"""Test hardware configuration management."""
@pytest.fixture
def config_manager(self):
"""Create configuration manager."""
class ConfigManager:
def __init__(self):
self.default_config = {
"frequency": 5.8e9,
"bandwidth": 80e6,
"sample_rate": 1000,
"antenna_config": "4x4_mimo",
"power_level": 20,
"channel": 36
}
self.router_configs = {}
def get_router_config(self, router_id: str) -> Dict[str, Any]:
"""Get configuration for a specific router."""
return self.router_configs.get(router_id, self.default_config.copy())
def set_router_config(self, router_id: str, config: Dict[str, Any]) -> bool:
"""Set configuration for a specific router."""
# Validate configuration
required_fields = ["frequency", "bandwidth", "sample_rate"]
if not all(field in config for field in required_fields):
return False
self.router_configs[router_id] = config
return True
def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""Validate router configuration."""
errors = []
# Frequency validation
if "frequency" in config:
freq = config["frequency"]
if not (2.4e9 <= freq <= 6e9):
errors.append("Frequency must be between 2.4GHz and 6GHz")
# Bandwidth validation
if "bandwidth" in config:
bw = config["bandwidth"]
if bw not in [20e6, 40e6, 80e6, 160e6]:
errors.append("Bandwidth must be 20, 40, 80, or 160 MHz")
# Sample rate validation
if "sample_rate" in config:
sr = config["sample_rate"]
if not (100 <= sr <= 10000):
errors.append("Sample rate must be between 100 and 10000 Hz")
return {
"valid": len(errors) == 0,
"errors": errors
}
return ConfigManager()
def test_default_configuration_should_fail_initially(self, config_manager):
"""Test default configuration retrieval - should fail initially."""
config = config_manager.get_router_config("new_router")
# This will fail initially
assert isinstance(config, dict)
assert "frequency" in config
assert "bandwidth" in config
assert "sample_rate" in config
assert "antenna_config" in config
assert config["frequency"] == 5.8e9
assert config["bandwidth"] == 80e6
def test_configuration_validation_should_fail_initially(self, config_manager):
"""Test configuration validation - should fail initially."""
# Valid configuration
valid_config = {
"frequency": 5.8e9,
"bandwidth": 80e6,
"sample_rate": 1000
}
result = config_manager.validate_config(valid_config)
# This will fail initially
assert result["valid"] is True
assert len(result["errors"]) == 0
# Invalid configuration
invalid_config = {
"frequency": 10e9, # Too high
"bandwidth": 100e6, # Invalid
"sample_rate": 50 # Too low
}
result = config_manager.validate_config(invalid_config)
assert result["valid"] is False
assert len(result["errors"]) == 3
def test_router_specific_configuration_should_fail_initially(self, config_manager):
"""Test router-specific configuration - should fail initially."""
router_id = "router_001"
custom_config = {
"frequency": 2.4e9,
"bandwidth": 40e6,
"sample_rate": 500,
"antenna_config": "2x2_mimo"
}
# Set custom configuration
result = config_manager.set_router_config(router_id, custom_config)
# This will fail initially
assert result is True
# Retrieve custom configuration
retrieved_config = config_manager.get_router_config(router_id)
assert retrieved_config["frequency"] == 2.4e9
assert retrieved_config["bandwidth"] == 40e6
assert retrieved_config["antenna_config"] == "2x2_mimo"
# Test invalid configuration
invalid_config = {"frequency": 5.8e9} # Missing required fields
result = config_manager.set_router_config(router_id, invalid_config)
assert result is False

View File

@@ -0,0 +1,577 @@
"""
Integration tests for end-to-end pose estimation pipeline.
Tests the complete pose estimation workflow from CSI data to pose results.
"""
import pytest
import asyncio
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import json
from dataclasses import dataclass
@dataclass
class CSIData:
"""CSI data structure for testing."""
timestamp: datetime
router_id: str
amplitude: np.ndarray
phase: np.ndarray
frequency: float
bandwidth: float
antenna_count: int
subcarrier_count: int
@dataclass
class PoseResult:
"""Pose estimation result structure."""
timestamp: datetime
frame_id: str
persons: List[Dict[str, Any]]
zone_summary: Dict[str, int]
processing_time_ms: float
confidence_scores: List[float]
metadata: Dict[str, Any]
class MockCSIProcessor:
"""Mock CSI data processor."""
def __init__(self):
self.is_initialized = False
self.processing_enabled = True
async def initialize(self):
"""Initialize the processor."""
self.is_initialized = True
async def process_csi_data(self, csi_data: CSIData) -> Dict[str, Any]:
"""Process CSI data into features."""
if not self.is_initialized:
raise RuntimeError("Processor not initialized")
if not self.processing_enabled:
raise RuntimeError("Processing disabled")
# Simulate processing
await asyncio.sleep(0.01) # Simulate processing time
return {
"features": np.random.rand(64, 32).tolist(), # Mock feature matrix
"quality_score": 0.85,
"signal_strength": -45.2,
"noise_level": -78.1,
"processed_at": datetime.utcnow().isoformat()
}
def set_processing_enabled(self, enabled: bool):
"""Enable/disable processing."""
self.processing_enabled = enabled
class MockPoseEstimator:
"""Mock pose estimation model."""
def __init__(self):
self.is_loaded = False
self.model_version = "1.0.0"
self.confidence_threshold = 0.5
async def load_model(self):
"""Load the pose estimation model."""
await asyncio.sleep(0.1) # Simulate model loading
self.is_loaded = True
async def estimate_poses(self, features: np.ndarray) -> Dict[str, Any]:
"""Estimate poses from features."""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
# Simulate pose estimation
await asyncio.sleep(0.05) # Simulate inference time
# Generate mock pose data
num_persons = np.random.randint(0, 4) # 0-3 persons
persons = []
for i in range(num_persons):
confidence = np.random.uniform(0.3, 0.95)
if confidence >= self.confidence_threshold:
persons.append({
"person_id": f"person_{i}",
"confidence": confidence,
"bounding_box": {
"x": np.random.uniform(0, 800),
"y": np.random.uniform(0, 600),
"width": np.random.uniform(50, 200),
"height": np.random.uniform(100, 400)
},
"keypoints": [
{
"name": "head",
"x": np.random.uniform(0, 800),
"y": np.random.uniform(0, 200),
"confidence": np.random.uniform(0.5, 0.95)
},
{
"name": "torso",
"x": np.random.uniform(0, 800),
"y": np.random.uniform(200, 400),
"confidence": np.random.uniform(0.5, 0.95)
}
],
"activity": "standing" if np.random.random() > 0.2 else "sitting"
})
return {
"persons": persons,
"processing_time_ms": np.random.uniform(20, 80),
"model_version": self.model_version,
"confidence_threshold": self.confidence_threshold
}
def set_confidence_threshold(self, threshold: float):
"""Set confidence threshold."""
self.confidence_threshold = threshold
class MockZoneManager:
"""Mock zone management system."""
def __init__(self):
self.zones = {
"zone1": {"id": "zone1", "name": "Zone 1", "bounds": [0, 0, 400, 600]},
"zone2": {"id": "zone2", "name": "Zone 2", "bounds": [400, 0, 800, 600]},
"zone3": {"id": "zone3", "name": "Zone 3", "bounds": [0, 300, 800, 600]}
}
def assign_persons_to_zones(self, persons: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Assign detected persons to zones."""
zone_summary = {zone_id: 0 for zone_id in self.zones.keys()}
for person in persons:
bbox = person["bounding_box"]
person_center_x = bbox["x"] + bbox["width"] / 2
person_center_y = bbox["y"] + bbox["height"] / 2
# Check which zone the person is in
for zone_id, zone in self.zones.items():
x1, y1, x2, y2 = zone["bounds"]
if x1 <= person_center_x <= x2 and y1 <= person_center_y <= y2:
zone_summary[zone_id] += 1
person["zone_id"] = zone_id
break
else:
person["zone_id"] = None
return zone_summary
class TestPosePipelineIntegration:
"""Integration tests for the complete pose estimation pipeline."""
@pytest.fixture
def csi_processor(self):
"""Create CSI processor."""
return MockCSIProcessor()
@pytest.fixture
def pose_estimator(self):
"""Create pose estimator."""
return MockPoseEstimator()
@pytest.fixture
def zone_manager(self):
"""Create zone manager."""
return MockZoneManager()
@pytest.fixture
def sample_csi_data(self):
"""Create sample CSI data."""
return CSIData(
timestamp=datetime.utcnow(),
router_id="router_001",
amplitude=np.random.rand(64, 32),
phase=np.random.rand(64, 32),
frequency=5.8e9, # 5.8 GHz
bandwidth=80e6, # 80 MHz
antenna_count=4,
subcarrier_count=64
)
@pytest.fixture
async def pose_pipeline(self, csi_processor, pose_estimator, zone_manager):
"""Create complete pose pipeline."""
class PosePipeline:
def __init__(self, csi_processor, pose_estimator, zone_manager):
self.csi_processor = csi_processor
self.pose_estimator = pose_estimator
self.zone_manager = zone_manager
self.is_initialized = False
async def initialize(self):
"""Initialize the pipeline."""
await self.csi_processor.initialize()
await self.pose_estimator.load_model()
self.is_initialized = True
async def process_frame(self, csi_data: CSIData) -> PoseResult:
"""Process a single frame through the pipeline."""
if not self.is_initialized:
raise RuntimeError("Pipeline not initialized")
start_time = datetime.utcnow()
# Step 1: Process CSI data
processed_data = await self.csi_processor.process_csi_data(csi_data)
# Step 2: Extract features
features = np.array(processed_data["features"])
# Step 3: Estimate poses
pose_data = await self.pose_estimator.estimate_poses(features)
# Step 4: Assign to zones
zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])
# Calculate processing time
end_time = datetime.utcnow()
processing_time = (end_time - start_time).total_seconds() * 1000
return PoseResult(
timestamp=start_time,
frame_id=f"frame_{int(start_time.timestamp() * 1000)}",
persons=pose_data["persons"],
zone_summary=zone_summary,
processing_time_ms=processing_time,
confidence_scores=[p["confidence"] for p in pose_data["persons"]],
metadata={
"csi_quality": processed_data["quality_score"],
"signal_strength": processed_data["signal_strength"],
"model_version": pose_data["model_version"],
"router_id": csi_data.router_id
}
)
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
await pipeline.initialize()
return pipeline
@pytest.mark.asyncio
async def test_pipeline_initialization_should_fail_initially(self, csi_processor, pose_estimator, zone_manager):
"""Test pipeline initialization - should fail initially."""
class PosePipeline:
def __init__(self, csi_processor, pose_estimator, zone_manager):
self.csi_processor = csi_processor
self.pose_estimator = pose_estimator
self.zone_manager = zone_manager
self.is_initialized = False
async def initialize(self):
await self.csi_processor.initialize()
await self.pose_estimator.load_model()
self.is_initialized = True
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
# Initially not initialized
assert not pipeline.is_initialized
assert not csi_processor.is_initialized
assert not pose_estimator.is_loaded
# Initialize pipeline
await pipeline.initialize()
# This will fail initially
assert pipeline.is_initialized
assert csi_processor.is_initialized
assert pose_estimator.is_loaded
@pytest.mark.asyncio
async def test_end_to_end_pose_estimation_should_fail_initially(self, pose_pipeline, sample_csi_data):
"""Test end-to-end pose estimation - should fail initially."""
result = await pose_pipeline.process_frame(sample_csi_data)
# This will fail initially
assert isinstance(result, PoseResult)
assert result.timestamp is not None
assert result.frame_id.startswith("frame_")
assert isinstance(result.persons, list)
assert isinstance(result.zone_summary, dict)
assert result.processing_time_ms > 0
assert isinstance(result.confidence_scores, list)
assert isinstance(result.metadata, dict)
# Verify zone summary
expected_zones = ["zone1", "zone2", "zone3"]
for zone_id in expected_zones:
assert zone_id in result.zone_summary
assert isinstance(result.zone_summary[zone_id], int)
assert result.zone_summary[zone_id] >= 0
# Verify metadata
assert "csi_quality" in result.metadata
assert "signal_strength" in result.metadata
assert "model_version" in result.metadata
assert "router_id" in result.metadata
assert result.metadata["router_id"] == sample_csi_data.router_id
@pytest.mark.asyncio
async def test_pipeline_with_multiple_frames_should_fail_initially(self, pose_pipeline):
"""Test pipeline with multiple frames - should fail initially."""
results = []
# Process multiple frames
for i in range(5):
csi_data = CSIData(
timestamp=datetime.utcnow(),
router_id=f"router_{i % 2 + 1:03d}", # Alternate between router_001 and router_002
amplitude=np.random.rand(64, 32),
phase=np.random.rand(64, 32),
frequency=5.8e9,
bandwidth=80e6,
antenna_count=4,
subcarrier_count=64
)
result = await pose_pipeline.process_frame(csi_data)
results.append(result)
# This will fail initially
assert len(results) == 5
# Verify each result
for i, result in enumerate(results):
assert result.frame_id != results[0].frame_id if i > 0 else True
assert result.metadata["router_id"] in ["router_001", "router_002"]
assert result.processing_time_ms > 0
@pytest.mark.asyncio
async def test_pipeline_error_handling_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
"""Test pipeline error handling - should fail initially."""
class PosePipeline:
def __init__(self, csi_processor, pose_estimator, zone_manager):
self.csi_processor = csi_processor
self.pose_estimator = pose_estimator
self.zone_manager = zone_manager
self.is_initialized = False
async def initialize(self):
await self.csi_processor.initialize()
await self.pose_estimator.load_model()
self.is_initialized = True
async def process_frame(self, csi_data):
if not self.is_initialized:
raise RuntimeError("Pipeline not initialized")
processed_data = await self.csi_processor.process_csi_data(csi_data)
features = np.array(processed_data["features"])
pose_data = await self.pose_estimator.estimate_poses(features)
return pose_data
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
# Test uninitialized pipeline
with pytest.raises(RuntimeError, match="Pipeline not initialized"):
await pipeline.process_frame(sample_csi_data)
# Initialize pipeline
await pipeline.initialize()
# Test with disabled CSI processor
csi_processor.set_processing_enabled(False)
with pytest.raises(RuntimeError, match="Processing disabled"):
await pipeline.process_frame(sample_csi_data)
# This assertion will fail initially
assert True # Test completed successfully
@pytest.mark.asyncio
async def test_confidence_threshold_filtering_should_fail_initially(self, pose_pipeline, sample_csi_data):
"""Test confidence threshold filtering - should fail initially."""
# Set high confidence threshold
pose_pipeline.pose_estimator.set_confidence_threshold(0.9)
result = await pose_pipeline.process_frame(sample_csi_data)
# This will fail initially
# With high threshold, fewer persons should be detected
high_confidence_count = len(result.persons)
# Set low confidence threshold
pose_pipeline.pose_estimator.set_confidence_threshold(0.1)
result = await pose_pipeline.process_frame(sample_csi_data)
low_confidence_count = len(result.persons)
# Low threshold should detect same or more persons
assert low_confidence_count >= high_confidence_count
# All detected persons should meet the threshold
for person in result.persons:
assert person["confidence"] >= 0.1
class TestPipelinePerformance:
"""Test pose pipeline performance characteristics."""
@pytest.mark.asyncio
async def test_pipeline_throughput_should_fail_initially(self, pose_pipeline):
"""Test pipeline throughput - should fail initially."""
frame_count = 10
start_time = datetime.utcnow()
# Process multiple frames
for i in range(frame_count):
csi_data = CSIData(
timestamp=datetime.utcnow(),
router_id="router_001",
amplitude=np.random.rand(64, 32),
phase=np.random.rand(64, 32),
frequency=5.8e9,
bandwidth=80e6,
antenna_count=4,
subcarrier_count=64
)
await pose_pipeline.process_frame(csi_data)
end_time = datetime.utcnow()
total_time = (end_time - start_time).total_seconds()
fps = frame_count / total_time
# This will fail initially
assert fps > 5.0 # Should process at least 5 FPS
assert total_time < 5.0 # Should complete 10 frames in under 5 seconds
@pytest.mark.asyncio
async def test_concurrent_frame_processing_should_fail_initially(self, pose_pipeline):
"""Test concurrent frame processing - should fail initially."""
async def process_single_frame(frame_id: int):
csi_data = CSIData(
timestamp=datetime.utcnow(),
router_id=f"router_{frame_id % 3 + 1:03d}",
amplitude=np.random.rand(64, 32),
phase=np.random.rand(64, 32),
frequency=5.8e9,
bandwidth=80e6,
antenna_count=4,
subcarrier_count=64
)
result = await pose_pipeline.process_frame(csi_data)
return result.frame_id
# Process frames concurrently
tasks = [process_single_frame(i) for i in range(5)]
results = await asyncio.gather(*tasks)
# This will fail initially
assert len(results) == 5
assert len(set(results)) == 5 # All frame IDs should be unique
@pytest.mark.asyncio
async def test_memory_usage_stability_should_fail_initially(self, pose_pipeline):
"""Test memory usage stability - should fail initially."""
import psutil
import os
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
# Process many frames
for i in range(50):
csi_data = CSIData(
timestamp=datetime.utcnow(),
router_id="router_001",
amplitude=np.random.rand(64, 32),
phase=np.random.rand(64, 32),
frequency=5.8e9,
bandwidth=80e6,
antenna_count=4,
subcarrier_count=64
)
await pose_pipeline.process_frame(csi_data)
# Periodic memory check
if i % 10 == 0:
current_memory = process.memory_info().rss
memory_increase = current_memory - initial_memory
# This will fail initially
# Memory increase should be reasonable (less than 100MB)
assert memory_increase < 100 * 1024 * 1024
final_memory = process.memory_info().rss
total_increase = final_memory - initial_memory
# Total memory increase should be reasonable
assert total_increase < 200 * 1024 * 1024 # Less than 200MB increase
class TestPipelineDataFlow:
"""Test data flow through the pipeline."""
@pytest.mark.asyncio
async def test_data_transformation_chain_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
"""Test data transformation through the pipeline - should fail initially."""
# Step 1: CSI processing
await csi_processor.initialize()
processed_data = await csi_processor.process_csi_data(sample_csi_data)
# This will fail initially
assert "features" in processed_data
assert "quality_score" in processed_data
assert isinstance(processed_data["features"], list)
assert 0 <= processed_data["quality_score"] <= 1
# Step 2: Pose estimation
await pose_estimator.load_model()
features = np.array(processed_data["features"])
pose_data = await pose_estimator.estimate_poses(features)
assert "persons" in pose_data
assert "processing_time_ms" in pose_data
assert isinstance(pose_data["persons"], list)
# Step 3: Zone assignment
zone_summary = zone_manager.assign_persons_to_zones(pose_data["persons"])
assert isinstance(zone_summary, dict)
assert all(isinstance(count, int) for count in zone_summary.values())
# Verify person zone assignments
for person in pose_data["persons"]:
if "zone_id" in person and person["zone_id"]:
assert person["zone_id"] in zone_summary
@pytest.mark.asyncio
async def test_pipeline_state_consistency_should_fail_initially(self, pose_pipeline, sample_csi_data):
"""Test pipeline state consistency - should fail initially."""
# Process the same frame multiple times
results = []
for _ in range(3):
result = await pose_pipeline.process_frame(sample_csi_data)
results.append(result)
# This will fail initially
# Results should be consistent (same input should produce similar output)
assert len(results) == 3
# All results should have the same router_id
router_ids = [r.metadata["router_id"] for r in results]
assert all(rid == router_ids[0] for rid in router_ids)
# Processing times should be reasonable and similar
processing_times = [r.processing_time_ms for r in results]
assert all(10 <= pt <= 200 for pt in processing_times) # Between 10ms and 200ms

View File

@@ -0,0 +1,565 @@
"""
Integration tests for rate limiting functionality.
Tests rate limit behavior, throttling, and quota management.
"""
import pytest
import asyncio
from datetime import datetime, timedelta
from typing import Dict, Any, List
from unittest.mock import AsyncMock, MagicMock, patch
import time
from fastapi import HTTPException, status, Request, Response
class MockRateLimiter:
"""Mock rate limiter for testing."""
def __init__(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
self.requests_per_minute = requests_per_minute
self.requests_per_hour = requests_per_hour
self.request_history = {}
self.blocked_clients = set()
def _get_client_key(self, client_id: str, endpoint: str = None) -> str:
"""Get client key for rate limiting."""
return f"{client_id}:{endpoint}" if endpoint else client_id
def _cleanup_old_requests(self, client_key: str):
"""Clean up old request records."""
if client_key not in self.request_history:
return
now = datetime.utcnow()
minute_ago = now - timedelta(minutes=1)
hour_ago = now - timedelta(hours=1)
# Keep only requests from the last hour
self.request_history[client_key] = [
req_time for req_time in self.request_history[client_key]
if req_time > hour_ago
]
def check_rate_limit(self, client_id: str, endpoint: str = None) -> Dict[str, Any]:
"""Check if client is within rate limits."""
client_key = self._get_client_key(client_id, endpoint)
if client_id in self.blocked_clients:
return {
"allowed": False,
"reason": "Client blocked",
"retry_after": 3600 # 1 hour
}
self._cleanup_old_requests(client_key)
if client_key not in self.request_history:
self.request_history[client_key] = []
now = datetime.utcnow()
minute_ago = now - timedelta(minutes=1)
# Count requests in the last minute
recent_requests = [
req_time for req_time in self.request_history[client_key]
if req_time > minute_ago
]
# Count requests in the last hour
hour_requests = len(self.request_history[client_key])
if len(recent_requests) >= self.requests_per_minute:
return {
"allowed": False,
"reason": "Rate limit exceeded (per minute)",
"retry_after": 60,
"current_requests": len(recent_requests),
"limit": self.requests_per_minute
}
if hour_requests >= self.requests_per_hour:
return {
"allowed": False,
"reason": "Rate limit exceeded (per hour)",
"retry_after": 3600,
"current_requests": hour_requests,
"limit": self.requests_per_hour
}
# Record this request
self.request_history[client_key].append(now)
return {
"allowed": True,
"remaining_minute": self.requests_per_minute - len(recent_requests) - 1,
"remaining_hour": self.requests_per_hour - hour_requests - 1,
"reset_time": minute_ago + timedelta(minutes=1)
}
def block_client(self, client_id: str):
"""Block a client."""
self.blocked_clients.add(client_id)
def unblock_client(self, client_id: str):
"""Unblock a client."""
self.blocked_clients.discard(client_id)
class TestRateLimitingBasic:
"""Test basic rate limiting functionality."""
@pytest.fixture
def rate_limiter(self):
"""Create rate limiter for testing."""
return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)
def test_rate_limit_within_bounds_should_fail_initially(self, rate_limiter):
"""Test rate limiting within bounds - should fail initially."""
client_id = "test-client-001"
# Make requests within limit
for i in range(3):
result = rate_limiter.check_rate_limit(client_id)
# This will fail initially
assert result["allowed"] is True
assert "remaining_minute" in result
assert "remaining_hour" in result
def test_rate_limit_per_minute_exceeded_should_fail_initially(self, rate_limiter):
"""Test per-minute rate limit exceeded - should fail initially."""
client_id = "test-client-002"
# Make requests up to the limit
for i in range(5):
result = rate_limiter.check_rate_limit(client_id)
assert result["allowed"] is True
# Next request should be blocked
result = rate_limiter.check_rate_limit(client_id)
# This will fail initially
assert result["allowed"] is False
assert "per minute" in result["reason"]
assert result["retry_after"] == 60
assert result["current_requests"] == 5
assert result["limit"] == 5
def test_rate_limit_per_hour_exceeded_should_fail_initially(self, rate_limiter):
"""Test per-hour rate limit exceeded - should fail initially."""
# Create rate limiter with very low hour limit for testing
limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=3)
client_id = "test-client-003"
# Make requests up to hour limit
for i in range(3):
result = limiter.check_rate_limit(client_id)
assert result["allowed"] is True
# Next request should be blocked
result = limiter.check_rate_limit(client_id)
# This will fail initially
assert result["allowed"] is False
assert "per hour" in result["reason"]
assert result["retry_after"] == 3600
def test_blocked_client_should_fail_initially(self, rate_limiter):
"""Test blocked client handling - should fail initially."""
client_id = "blocked-client"
# Block the client
rate_limiter.block_client(client_id)
# Request should be blocked
result = rate_limiter.check_rate_limit(client_id)
# This will fail initially
assert result["allowed"] is False
assert result["reason"] == "Client blocked"
assert result["retry_after"] == 3600
# Unblock and test
rate_limiter.unblock_client(client_id)
result = rate_limiter.check_rate_limit(client_id)
assert result["allowed"] is True
def test_endpoint_specific_rate_limiting_should_fail_initially(self, rate_limiter):
"""Test endpoint-specific rate limiting - should fail initially."""
client_id = "test-client-004"
# Make requests to different endpoints
result1 = rate_limiter.check_rate_limit(client_id, "/api/pose/current")
result2 = rate_limiter.check_rate_limit(client_id, "/api/stream/status")
# This will fail initially
assert result1["allowed"] is True
assert result2["allowed"] is True
# Each endpoint should have separate rate limiting
for i in range(4):
rate_limiter.check_rate_limit(client_id, "/api/pose/current")
# Pose endpoint should be at limit, but stream should still work
pose_result = rate_limiter.check_rate_limit(client_id, "/api/pose/current")
stream_result = rate_limiter.check_rate_limit(client_id, "/api/stream/status")
assert pose_result["allowed"] is False
assert stream_result["allowed"] is True
class TestRateLimitMiddleware:
"""Test rate limiting middleware functionality."""
@pytest.fixture
def mock_request(self):
"""Mock FastAPI request."""
class MockRequest:
def __init__(self, client_ip="127.0.0.1", path="/api/test", method="GET"):
self.client = MagicMock()
self.client.host = client_ip
self.url = MagicMock()
self.url.path = path
self.method = method
self.headers = {}
self.state = MagicMock()
return MockRequest
@pytest.fixture
def mock_response(self):
"""Mock FastAPI response."""
class MockResponse:
def __init__(self):
self.status_code = 200
self.headers = {}
return MockResponse()
@pytest.fixture
def rate_limit_middleware(self, rate_limiter):
"""Create rate limiting middleware."""
class RateLimitMiddleware:
def __init__(self, rate_limiter):
self.rate_limiter = rate_limiter
async def __call__(self, request, call_next):
# Get client identifier
client_id = self._get_client_id(request)
endpoint = request.url.path
# Check rate limit
limit_result = self.rate_limiter.check_rate_limit(client_id, endpoint)
if not limit_result["allowed"]:
# Return rate limit exceeded response
response = Response(
content=f"Rate limit exceeded: {limit_result['reason']}",
status_code=status.HTTP_429_TOO_MANY_REQUESTS
)
response.headers["Retry-After"] = str(limit_result["retry_after"])
response.headers["X-RateLimit-Limit"] = str(limit_result.get("limit", "unknown"))
response.headers["X-RateLimit-Remaining"] = "0"
return response
# Process request
response = await call_next(request)
# Add rate limit headers
response.headers["X-RateLimit-Limit"] = str(self.rate_limiter.requests_per_minute)
response.headers["X-RateLimit-Remaining"] = str(limit_result.get("remaining_minute", 0))
response.headers["X-RateLimit-Reset"] = str(int(limit_result.get("reset_time", datetime.utcnow()).timestamp()))
return response
def _get_client_id(self, request):
"""Get client identifier from request."""
# Check for API key in headers
api_key = request.headers.get("X-API-Key")
if api_key:
return f"api:{api_key}"
# Check for user ID in request state (from auth)
if hasattr(request.state, "user") and request.state.user:
return f"user:{request.state.user.get('id', 'unknown')}"
# Fall back to IP address
return f"ip:{request.client.host}"
return RateLimitMiddleware(rate_limiter)
@pytest.mark.asyncio
async def test_middleware_allows_normal_requests_should_fail_initially(
self, rate_limit_middleware, mock_request, mock_response
):
"""Test middleware allows normal requests - should fail initially."""
request = mock_request()
async def mock_call_next(req):
return mock_response
response = await rate_limit_middleware(request, mock_call_next)
# This will fail initially
assert response.status_code == 200
assert "X-RateLimit-Limit" in response.headers
assert "X-RateLimit-Remaining" in response.headers
assert "X-RateLimit-Reset" in response.headers
@pytest.mark.asyncio
async def test_middleware_blocks_excessive_requests_should_fail_initially(
self, rate_limit_middleware, mock_request
):
"""Test middleware blocks excessive requests - should fail initially."""
request = mock_request()
async def mock_call_next(req):
response = Response(content="OK", status_code=200)
return response
# Make requests up to the limit
for i in range(5):
response = await rate_limit_middleware(request, mock_call_next)
assert response.status_code == 200
# Next request should be blocked
response = await rate_limit_middleware(request, mock_call_next)
# This will fail initially
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert "Retry-After" in response.headers
assert "X-RateLimit-Remaining" in response.headers
assert response.headers["X-RateLimit-Remaining"] == "0"
@pytest.mark.asyncio
async def test_middleware_client_identification_should_fail_initially(
self, rate_limit_middleware, mock_request
):
"""Test middleware client identification - should fail initially."""
# Test API key identification
request_with_api_key = mock_request()
request_with_api_key.headers["X-API-Key"] = "test-api-key-123"
# Test user identification
request_with_user = mock_request()
request_with_user.state.user = {"id": "user-123"}
# Test IP identification
request_with_ip = mock_request(client_ip="192.168.1.100")
async def mock_call_next(req):
return Response(content="OK", status_code=200)
# Each should be treated as different clients
response1 = await rate_limit_middleware(request_with_api_key, mock_call_next)
response2 = await rate_limit_middleware(request_with_user, mock_call_next)
response3 = await rate_limit_middleware(request_with_ip, mock_call_next)
# This will fail initially
assert response1.status_code == 200
assert response2.status_code == 200
assert response3.status_code == 200
class TestRateLimitingStrategies:
"""Test different rate limiting strategies."""
@pytest.fixture
def sliding_window_limiter(self):
"""Create sliding window rate limiter."""
class SlidingWindowLimiter:
def __init__(self, window_size_seconds: int = 60, max_requests: int = 10):
self.window_size = window_size_seconds
self.max_requests = max_requests
self.request_times = {}
def check_limit(self, client_id: str) -> Dict[str, Any]:
now = time.time()
if client_id not in self.request_times:
self.request_times[client_id] = []
# Remove old requests outside the window
cutoff_time = now - self.window_size
self.request_times[client_id] = [
req_time for req_time in self.request_times[client_id]
if req_time > cutoff_time
]
# Check if we're at the limit
if len(self.request_times[client_id]) >= self.max_requests:
oldest_request = min(self.request_times[client_id])
retry_after = int(oldest_request + self.window_size - now)
return {
"allowed": False,
"retry_after": max(retry_after, 1),
"current_requests": len(self.request_times[client_id]),
"limit": self.max_requests
}
# Record this request
self.request_times[client_id].append(now)
return {
"allowed": True,
"remaining": self.max_requests - len(self.request_times[client_id]),
"window_reset": int(now + self.window_size)
}
return SlidingWindowLimiter(window_size_seconds=10, max_requests=3)
@pytest.fixture
def token_bucket_limiter(self):
"""Create token bucket rate limiter."""
class TokenBucketLimiter:
def __init__(self, capacity: int = 10, refill_rate: float = 1.0):
self.capacity = capacity
self.refill_rate = refill_rate # tokens per second
self.buckets = {}
def check_limit(self, client_id: str) -> Dict[str, Any]:
now = time.time()
if client_id not in self.buckets:
self.buckets[client_id] = {
"tokens": self.capacity,
"last_refill": now
}
bucket = self.buckets[client_id]
# Refill tokens based on time elapsed
time_elapsed = now - bucket["last_refill"]
tokens_to_add = time_elapsed * self.refill_rate
bucket["tokens"] = min(self.capacity, bucket["tokens"] + tokens_to_add)
bucket["last_refill"] = now
# Check if we have tokens available
if bucket["tokens"] < 1:
return {
"allowed": False,
"retry_after": int((1 - bucket["tokens"]) / self.refill_rate),
"tokens_remaining": bucket["tokens"]
}
# Consume a token
bucket["tokens"] -= 1
return {
"allowed": True,
"tokens_remaining": bucket["tokens"]
}
return TokenBucketLimiter(capacity=5, refill_rate=0.5) # 0.5 tokens per second
def test_sliding_window_limiter_should_fail_initially(self, sliding_window_limiter):
"""Test sliding window rate limiter - should fail initially."""
client_id = "sliding-test-client"
# Make requests up to limit
for i in range(3):
result = sliding_window_limiter.check_limit(client_id)
# This will fail initially
assert result["allowed"] is True
assert "remaining" in result
# Next request should be blocked
result = sliding_window_limiter.check_limit(client_id)
assert result["allowed"] is False
assert result["current_requests"] == 3
assert result["limit"] == 3
def test_token_bucket_limiter_should_fail_initially(self, token_bucket_limiter):
"""Test token bucket rate limiter - should fail initially."""
client_id = "bucket-test-client"
# Make requests up to capacity
for i in range(5):
result = token_bucket_limiter.check_limit(client_id)
# This will fail initially
assert result["allowed"] is True
assert "tokens_remaining" in result
# Next request should be blocked (no tokens left)
result = token_bucket_limiter.check_limit(client_id)
assert result["allowed"] is False
assert result["tokens_remaining"] < 1
@pytest.mark.asyncio
async def test_token_bucket_refill_should_fail_initially(self, token_bucket_limiter):
"""Test token bucket refill mechanism - should fail initially."""
client_id = "refill-test-client"
# Exhaust all tokens
for i in range(5):
token_bucket_limiter.check_limit(client_id)
# Should be blocked
result = token_bucket_limiter.check_limit(client_id)
assert result["allowed"] is False
# Wait for refill (simulate time passing)
await asyncio.sleep(2.1) # Wait for 1 token to be refilled (0.5 tokens/sec * 2.1 sec > 1)
# Should now be allowed
result = token_bucket_limiter.check_limit(client_id)
# This will fail initially
assert result["allowed"] is True
class TestRateLimitingPerformance:
"""Test rate limiting performance characteristics."""
@pytest.mark.asyncio
async def test_concurrent_rate_limit_checks_should_fail_initially(self):
"""Test concurrent rate limit checks - should fail initially."""
rate_limiter = MockRateLimiter(requests_per_minute=100, requests_per_hour=1000)
async def make_request(client_id: str, request_id: int):
result = rate_limiter.check_rate_limit(f"{client_id}-{request_id}")
return result["allowed"]
# Create many concurrent requests
tasks = [
make_request("concurrent-client", i)
for i in range(50)
]
results = await asyncio.gather(*tasks)
# This will fail initially
assert len(results) == 50
assert all(results) # All should be allowed since they're different clients
@pytest.mark.asyncio
async def test_rate_limiter_memory_cleanup_should_fail_initially(self):
"""Test rate limiter memory cleanup - should fail initially."""
rate_limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=100)
# Make requests for many different clients
for i in range(100):
rate_limiter.check_rate_limit(f"client-{i}")
initial_memory_size = len(rate_limiter.request_history)
# Simulate time passing and cleanup
for client_key in list(rate_limiter.request_history.keys()):
rate_limiter._cleanup_old_requests(client_key)
# This will fail initially
assert initial_memory_size == 100
# After cleanup, old entries should be removed
# (In a real implementation, this would clean up old timestamps)
final_memory_size = len([
key for key, history in rate_limiter.request_history.items()
if history # Only count non-empty histories
])
assert final_memory_size <= initial_memory_size

View File

@@ -0,0 +1,729 @@
"""
Integration tests for real-time streaming pipeline.
Tests the complete real-time data flow from CSI collection to client delivery.
"""
import pytest
import asyncio
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
import json
import queue
import threading
from dataclasses import dataclass
@dataclass
class StreamFrame:
"""Streaming frame data structure."""
frame_id: str
timestamp: datetime
router_id: str
pose_data: Dict[str, Any]
processing_time_ms: float
quality_score: float
class MockStreamBuffer:
"""Mock streaming buffer for testing."""
def __init__(self, max_size: int = 100):
self.max_size = max_size
self.buffer = asyncio.Queue(maxsize=max_size)
self.dropped_frames = 0
self.total_frames = 0
async def put_frame(self, frame: StreamFrame) -> bool:
"""Add frame to buffer."""
self.total_frames += 1
try:
self.buffer.put_nowait(frame)
return True
except asyncio.QueueFull:
self.dropped_frames += 1
return False
async def get_frame(self, timeout: float = 1.0) -> Optional[StreamFrame]:
"""Get frame from buffer."""
try:
return await asyncio.wait_for(self.buffer.get(), timeout=timeout)
except asyncio.TimeoutError:
return None
def get_stats(self) -> Dict[str, Any]:
"""Get buffer statistics."""
return {
"buffer_size": self.buffer.qsize(),
"max_size": self.max_size,
"total_frames": self.total_frames,
"dropped_frames": self.dropped_frames,
"drop_rate": self.dropped_frames / max(self.total_frames, 1)
}
class MockStreamProcessor:
"""Mock stream processor for testing."""
def __init__(self):
self.is_running = False
self.processing_rate = 30 # FPS
self.frame_counter = 0
self.error_rate = 0.0
async def start_processing(self, input_buffer: MockStreamBuffer, output_buffer: MockStreamBuffer):
"""Start stream processing."""
self.is_running = True
while self.is_running:
try:
# Get frame from input
frame = await input_buffer.get_frame(timeout=0.1)
if frame is None:
continue
# Simulate processing error
if np.random.random() < self.error_rate:
continue # Skip frame due to error
# Process frame
processed_frame = await self._process_frame(frame)
# Put to output buffer
await output_buffer.put_frame(processed_frame)
# Control processing rate
await asyncio.sleep(1.0 / self.processing_rate)
except Exception as e:
# Handle processing errors
continue
async def _process_frame(self, frame: StreamFrame) -> StreamFrame:
"""Process a single frame."""
# Simulate processing time
await asyncio.sleep(0.01)
# Add processing metadata
processed_pose_data = frame.pose_data.copy()
processed_pose_data["processed_at"] = datetime.utcnow().isoformat()
processed_pose_data["processor_id"] = "stream_processor_001"
return StreamFrame(
frame_id=f"processed_{frame.frame_id}",
timestamp=frame.timestamp,
router_id=frame.router_id,
pose_data=processed_pose_data,
processing_time_ms=frame.processing_time_ms + 10, # Add processing overhead
quality_score=frame.quality_score * 0.95 # Slight quality degradation
)
def stop_processing(self):
"""Stop stream processing."""
self.is_running = False
def set_error_rate(self, error_rate: float):
"""Set processing error rate."""
self.error_rate = error_rate
class MockWebSocketManager:
"""Mock WebSocket manager for testing."""
def __init__(self):
self.connected_clients = {}
self.message_queue = asyncio.Queue()
self.total_messages_sent = 0
self.failed_sends = 0
async def add_client(self, client_id: str, websocket_mock) -> bool:
"""Add WebSocket client."""
if client_id in self.connected_clients:
return False
self.connected_clients[client_id] = {
"websocket": websocket_mock,
"connected_at": datetime.utcnow(),
"messages_sent": 0,
"last_ping": datetime.utcnow()
}
return True
async def remove_client(self, client_id: str) -> bool:
"""Remove WebSocket client."""
if client_id in self.connected_clients:
del self.connected_clients[client_id]
return True
return False
async def broadcast_frame(self, frame: StreamFrame) -> Dict[str, bool]:
"""Broadcast frame to all connected clients."""
results = {}
message = {
"type": "pose_update",
"frame_id": frame.frame_id,
"timestamp": frame.timestamp.isoformat(),
"router_id": frame.router_id,
"pose_data": frame.pose_data,
"processing_time_ms": frame.processing_time_ms,
"quality_score": frame.quality_score
}
for client_id, client_info in self.connected_clients.items():
try:
# Simulate WebSocket send
success = await self._send_to_client(client_id, message)
results[client_id] = success
if success:
client_info["messages_sent"] += 1
self.total_messages_sent += 1
else:
self.failed_sends += 1
except Exception:
results[client_id] = False
self.failed_sends += 1
return results
async def _send_to_client(self, client_id: str, message: Dict[str, Any]) -> bool:
"""Send message to specific client."""
# Simulate network issues
if np.random.random() < 0.05: # 5% failure rate
return False
# Simulate send delay
await asyncio.sleep(0.001)
return True
def get_client_stats(self) -> Dict[str, Any]:
"""Get client statistics."""
return {
"connected_clients": len(self.connected_clients),
"total_messages_sent": self.total_messages_sent,
"failed_sends": self.failed_sends,
"clients": {
client_id: {
"messages_sent": info["messages_sent"],
"connected_duration": (datetime.utcnow() - info["connected_at"]).total_seconds()
}
for client_id, info in self.connected_clients.items()
}
}
class TestStreamingPipelineBasic:
"""Test basic streaming pipeline functionality."""
@pytest.fixture
def stream_buffer(self):
"""Create stream buffer."""
return MockStreamBuffer(max_size=50)
@pytest.fixture
def stream_processor(self):
"""Create stream processor."""
return MockStreamProcessor()
@pytest.fixture
def websocket_manager(self):
"""Create WebSocket manager."""
return MockWebSocketManager()
@pytest.fixture
def sample_frame(self):
"""Create sample stream frame."""
return StreamFrame(
frame_id="frame_001",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={
"persons": [
{
"person_id": "person_1",
"confidence": 0.85,
"bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
"activity": "standing"
}
],
"zone_summary": {"zone1": 1, "zone2": 0}
},
processing_time_ms=45.2,
quality_score=0.92
)
@pytest.mark.asyncio
async def test_buffer_frame_operations_should_fail_initially(self, stream_buffer, sample_frame):
"""Test buffer frame operations - should fail initially."""
# Put frame in buffer
result = await stream_buffer.put_frame(sample_frame)
# This will fail initially
assert result is True
# Get frame from buffer
retrieved_frame = await stream_buffer.get_frame()
assert retrieved_frame is not None
assert retrieved_frame.frame_id == sample_frame.frame_id
assert retrieved_frame.router_id == sample_frame.router_id
# Buffer should be empty now
empty_frame = await stream_buffer.get_frame(timeout=0.1)
assert empty_frame is None
@pytest.mark.asyncio
async def test_buffer_overflow_handling_should_fail_initially(self, sample_frame):
"""Test buffer overflow handling - should fail initially."""
small_buffer = MockStreamBuffer(max_size=2)
# Fill buffer to capacity
result1 = await small_buffer.put_frame(sample_frame)
result2 = await small_buffer.put_frame(sample_frame)
# This will fail initially
assert result1 is True
assert result2 is True
# Next frame should be dropped
result3 = await small_buffer.put_frame(sample_frame)
assert result3 is False
# Check statistics
stats = small_buffer.get_stats()
assert stats["total_frames"] == 3
assert stats["dropped_frames"] == 1
assert stats["drop_rate"] > 0
@pytest.mark.asyncio
async def test_stream_processing_should_fail_initially(self, stream_processor, sample_frame):
"""Test stream processing - should fail initially."""
input_buffer = MockStreamBuffer()
output_buffer = MockStreamBuffer()
# Add frame to input buffer
await input_buffer.put_frame(sample_frame)
# Start processing task
processing_task = asyncio.create_task(
stream_processor.start_processing(input_buffer, output_buffer)
)
# Wait for processing
await asyncio.sleep(0.2)
# Stop processing
stream_processor.stop_processing()
await processing_task
# Check output
processed_frame = await output_buffer.get_frame(timeout=0.1)
# This will fail initially
assert processed_frame is not None
assert processed_frame.frame_id.startswith("processed_")
assert "processed_at" in processed_frame.pose_data
assert processed_frame.processing_time_ms > sample_frame.processing_time_ms
@pytest.mark.asyncio
async def test_websocket_client_management_should_fail_initially(self, websocket_manager):
"""Test WebSocket client management - should fail initially."""
mock_websocket = MagicMock()
# Add client
result = await websocket_manager.add_client("client_001", mock_websocket)
# This will fail initially
assert result is True
assert "client_001" in websocket_manager.connected_clients
# Try to add duplicate client
result = await websocket_manager.add_client("client_001", mock_websocket)
assert result is False
# Remove client
result = await websocket_manager.remove_client("client_001")
assert result is True
assert "client_001" not in websocket_manager.connected_clients
@pytest.mark.asyncio
async def test_frame_broadcasting_should_fail_initially(self, websocket_manager, sample_frame):
"""Test frame broadcasting - should fail initially."""
# Add multiple clients
for i in range(3):
await websocket_manager.add_client(f"client_{i:03d}", MagicMock())
# Broadcast frame
results = await websocket_manager.broadcast_frame(sample_frame)
# This will fail initially
assert len(results) == 3
assert all(isinstance(success, bool) for success in results.values())
# Check statistics
stats = websocket_manager.get_client_stats()
assert stats["connected_clients"] == 3
assert stats["total_messages_sent"] >= 0
class TestStreamingPipelineIntegration:
"""Test complete streaming pipeline integration."""
@pytest.fixture
async def streaming_pipeline(self):
"""Create complete streaming pipeline."""
class StreamingPipeline:
def __init__(self):
self.input_buffer = MockStreamBuffer(max_size=100)
self.output_buffer = MockStreamBuffer(max_size=100)
self.processor = MockStreamProcessor()
self.websocket_manager = MockWebSocketManager()
self.is_running = False
self.processing_task = None
self.broadcasting_task = None
async def start(self):
"""Start the streaming pipeline."""
if self.is_running:
return False
self.is_running = True
# Start processing task
self.processing_task = asyncio.create_task(
self.processor.start_processing(self.input_buffer, self.output_buffer)
)
# Start broadcasting task
self.broadcasting_task = asyncio.create_task(
self._broadcast_loop()
)
return True
async def stop(self):
"""Stop the streaming pipeline."""
if not self.is_running:
return False
self.is_running = False
self.processor.stop_processing()
# Cancel tasks
if self.processing_task:
self.processing_task.cancel()
if self.broadcasting_task:
self.broadcasting_task.cancel()
return True
async def add_frame(self, frame: StreamFrame) -> bool:
"""Add frame to pipeline."""
return await self.input_buffer.put_frame(frame)
async def add_client(self, client_id: str, websocket_mock) -> bool:
"""Add WebSocket client."""
return await self.websocket_manager.add_client(client_id, websocket_mock)
async def _broadcast_loop(self):
"""Broadcasting loop."""
while self.is_running:
try:
frame = await self.output_buffer.get_frame(timeout=0.1)
if frame:
await self.websocket_manager.broadcast_frame(frame)
except asyncio.TimeoutError:
continue
except Exception:
continue
def get_pipeline_stats(self) -> Dict[str, Any]:
"""Get pipeline statistics."""
return {
"is_running": self.is_running,
"input_buffer": self.input_buffer.get_stats(),
"output_buffer": self.output_buffer.get_stats(),
"websocket_clients": self.websocket_manager.get_client_stats()
}
return StreamingPipeline()
@pytest.mark.asyncio
async def test_end_to_end_streaming_should_fail_initially(self, streaming_pipeline):
"""Test end-to-end streaming - should fail initially."""
# Start pipeline
result = await streaming_pipeline.start()
# This will fail initially
assert result is True
assert streaming_pipeline.is_running is True
# Add clients
for i in range(2):
await streaming_pipeline.add_client(f"client_{i}", MagicMock())
# Add frames
for i in range(5):
frame = StreamFrame(
frame_id=f"frame_{i:03d}",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={"persons": [], "zone_summary": {}},
processing_time_ms=30.0,
quality_score=0.9
)
await streaming_pipeline.add_frame(frame)
# Wait for processing
await asyncio.sleep(0.5)
# Stop pipeline
await streaming_pipeline.stop()
# Check statistics
stats = streaming_pipeline.get_pipeline_stats()
assert stats["input_buffer"]["total_frames"] == 5
assert stats["websocket_clients"]["connected_clients"] == 2
@pytest.mark.asyncio
async def test_pipeline_performance_should_fail_initially(self, streaming_pipeline):
"""Test pipeline performance - should fail initially."""
await streaming_pipeline.start()
# Add multiple clients
for i in range(10):
await streaming_pipeline.add_client(f"client_{i:03d}", MagicMock())
# Measure throughput
start_time = datetime.utcnow()
frame_count = 50
for i in range(frame_count):
frame = StreamFrame(
frame_id=f"perf_frame_{i:03d}",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={"persons": [], "zone_summary": {}},
processing_time_ms=25.0,
quality_score=0.88
)
await streaming_pipeline.add_frame(frame)
# Wait for processing
await asyncio.sleep(2.0)
end_time = datetime.utcnow()
duration = (end_time - start_time).total_seconds()
await streaming_pipeline.stop()
# This will fail initially
# Check performance metrics
stats = streaming_pipeline.get_pipeline_stats()
throughput = frame_count / duration
assert throughput > 10 # Should process at least 10 FPS
assert stats["input_buffer"]["drop_rate"] < 0.1 # Less than 10% drop rate
@pytest.mark.asyncio
async def test_pipeline_error_recovery_should_fail_initially(self, streaming_pipeline):
"""Test pipeline error recovery - should fail initially."""
await streaming_pipeline.start()
# Set high error rate
streaming_pipeline.processor.set_error_rate(0.5) # 50% error rate
# Add frames
for i in range(20):
frame = StreamFrame(
frame_id=f"error_frame_{i:03d}",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={"persons": [], "zone_summary": {}},
processing_time_ms=30.0,
quality_score=0.9
)
await streaming_pipeline.add_frame(frame)
# Wait for processing
await asyncio.sleep(1.0)
await streaming_pipeline.stop()
# This will fail initially
# Pipeline should continue running despite errors
stats = streaming_pipeline.get_pipeline_stats()
assert stats["input_buffer"]["total_frames"] == 20
# Some frames should be processed despite errors
assert stats["output_buffer"]["total_frames"] > 0
class TestStreamingLatency:
"""Test streaming latency characteristics."""
@pytest.mark.asyncio
async def test_end_to_end_latency_should_fail_initially(self):
"""Test end-to-end latency - should fail initially."""
class LatencyTracker:
def __init__(self):
self.latencies = []
async def measure_latency(self, frame: StreamFrame) -> float:
"""Measure processing latency."""
start_time = datetime.utcnow()
# Simulate processing pipeline
await asyncio.sleep(0.05) # 50ms processing time
end_time = datetime.utcnow()
latency = (end_time - start_time).total_seconds() * 1000 # Convert to ms
self.latencies.append(latency)
return latency
tracker = LatencyTracker()
# Measure latency for multiple frames
for i in range(10):
frame = StreamFrame(
frame_id=f"latency_frame_{i}",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={},
processing_time_ms=0,
quality_score=1.0
)
latency = await tracker.measure_latency(frame)
# This will fail initially
assert latency > 0
assert latency < 200 # Should be less than 200ms
# Check average latency
avg_latency = sum(tracker.latencies) / len(tracker.latencies)
assert avg_latency < 100 # Average should be less than 100ms
@pytest.mark.asyncio
async def test_concurrent_stream_handling_should_fail_initially(self):
"""Test concurrent stream handling - should fail initially."""
async def process_stream(stream_id: str, frame_count: int) -> Dict[str, Any]:
"""Process a single stream."""
buffer = MockStreamBuffer()
processed_frames = 0
for i in range(frame_count):
frame = StreamFrame(
frame_id=f"{stream_id}_frame_{i}",
timestamp=datetime.utcnow(),
router_id=stream_id,
pose_data={},
processing_time_ms=20.0,
quality_score=0.9
)
success = await buffer.put_frame(frame)
if success:
processed_frames += 1
await asyncio.sleep(0.01) # Simulate frame rate
return {
"stream_id": stream_id,
"processed_frames": processed_frames,
"total_frames": frame_count
}
# Process multiple streams concurrently
streams = ["router_001", "router_002", "router_003"]
tasks = [process_stream(stream_id, 20) for stream_id in streams]
results = await asyncio.gather(*tasks)
# This will fail initially
assert len(results) == 3
for result in results:
assert result["processed_frames"] == result["total_frames"]
assert result["stream_id"] in streams
class TestStreamingResilience:
"""Test streaming pipeline resilience."""
@pytest.mark.asyncio
async def test_client_disconnection_handling_should_fail_initially(self):
"""Test client disconnection handling - should fail initially."""
websocket_manager = MockWebSocketManager()
# Add clients
client_ids = [f"client_{i:03d}" for i in range(5)]
for client_id in client_ids:
await websocket_manager.add_client(client_id, MagicMock())
# Simulate frame broadcasting
frame = StreamFrame(
frame_id="disconnect_test_frame",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={},
processing_time_ms=30.0,
quality_score=0.9
)
# Broadcast to all clients
results = await websocket_manager.broadcast_frame(frame)
# This will fail initially
assert len(results) == 5
# Simulate client disconnections
await websocket_manager.remove_client("client_001")
await websocket_manager.remove_client("client_003")
# Broadcast again
results = await websocket_manager.broadcast_frame(frame)
assert len(results) == 3 # Only remaining clients
# Check statistics
stats = websocket_manager.get_client_stats()
assert stats["connected_clients"] == 3
@pytest.mark.asyncio
async def test_memory_pressure_handling_should_fail_initially(self):
"""Test memory pressure handling - should fail initially."""
# Create small buffers to simulate memory pressure
small_buffer = MockStreamBuffer(max_size=5)
# Generate many frames quickly
frames_generated = 0
frames_accepted = 0
for i in range(20):
frame = StreamFrame(
frame_id=f"memory_pressure_frame_{i}",
timestamp=datetime.utcnow(),
router_id="router_001",
pose_data={},
processing_time_ms=25.0,
quality_score=0.85
)
frames_generated += 1
success = await small_buffer.put_frame(frame)
if success:
frames_accepted += 1
# This will fail initially
# Buffer should handle memory pressure gracefully
stats = small_buffer.get_stats()
assert stats["total_frames"] == frames_generated
assert stats["dropped_frames"] > 0 # Some frames should be dropped
assert frames_accepted <= small_buffer.max_size
# Drop rate should be reasonable
assert stats["drop_rate"] > 0.5 # More than 50% dropped due to small buffer

View File

@@ -0,0 +1,419 @@
"""
Integration tests for WebSocket streaming functionality.
Tests WebSocket connections, message handling, and real-time data streaming.
"""
import pytest
import asyncio
import json
from datetime import datetime
from typing import Dict, Any, List
from unittest.mock import AsyncMock, MagicMock, patch
import websockets
from fastapi import FastAPI, WebSocket
from fastapi.testclient import TestClient
class MockWebSocket:
"""Mock WebSocket for testing."""
def __init__(self):
self.messages_sent = []
self.messages_received = []
self.closed = False
self.accept_called = False
async def accept(self):
"""Mock accept method."""
self.accept_called = True
async def send_json(self, data: Dict[str, Any]):
"""Mock send_json method."""
self.messages_sent.append(data)
async def send_text(self, text: str):
"""Mock send_text method."""
self.messages_sent.append(text)
async def receive_text(self) -> str:
"""Mock receive_text method."""
if self.messages_received:
return self.messages_received.pop(0)
# Simulate WebSocket disconnect
from fastapi import WebSocketDisconnect
raise WebSocketDisconnect()
async def close(self):
"""Mock close method."""
self.closed = True
def add_received_message(self, message: str):
"""Add a message to be received."""
self.messages_received.append(message)
class TestWebSocketStreaming:
"""Integration tests for WebSocket streaming."""
@pytest.fixture
def mock_websocket(self):
"""Create mock WebSocket."""
return MockWebSocket()
@pytest.fixture
def mock_connection_manager(self):
"""Mock connection manager."""
manager = AsyncMock()
manager.connect.return_value = "client-001"
manager.disconnect.return_value = True
manager.get_connection_stats.return_value = {
"total_clients": 1,
"active_streams": ["pose"]
}
manager.broadcast.return_value = 1
return manager
@pytest.fixture
def mock_stream_service(self):
"""Mock stream service."""
service = AsyncMock()
service.get_status.return_value = {
"is_active": True,
"active_streams": [],
"uptime_seconds": 3600.0
}
service.is_active.return_value = True
service.start.return_value = None
service.stop.return_value = None
return service
@pytest.mark.asyncio
async def test_websocket_pose_connection_should_fail_initially(self, mock_websocket, mock_connection_manager):
"""Test WebSocket pose connection establishment - should fail initially."""
# This test should fail because we haven't implemented the WebSocket handler properly
# Simulate WebSocket connection
zone_ids = "zone1,zone2"
min_confidence = 0.7
max_fps = 30
# Mock the websocket_pose_stream function
async def mock_websocket_handler(websocket, zone_ids, min_confidence, max_fps):
await websocket.accept()
# Parse zone IDs
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
# Register client
client_id = await mock_connection_manager.connect(
websocket=websocket,
stream_type="pose",
zone_ids=zone_list,
min_confidence=min_confidence,
max_fps=max_fps
)
# Send confirmation
await websocket.send_json({
"type": "connection_established",
"client_id": client_id,
"timestamp": datetime.utcnow().isoformat(),
"config": {
"zone_ids": zone_list,
"min_confidence": min_confidence,
"max_fps": max_fps
}
})
return client_id
# Execute the handler
client_id = await mock_websocket_handler(mock_websocket, zone_ids, min_confidence, max_fps)
# This assertion will fail initially, driving us to implement the WebSocket handler
assert mock_websocket.accept_called
assert len(mock_websocket.messages_sent) == 1
assert mock_websocket.messages_sent[0]["type"] == "connection_established"
assert mock_websocket.messages_sent[0]["client_id"] == "client-001"
assert "config" in mock_websocket.messages_sent[0]
@pytest.mark.asyncio
async def test_websocket_message_handling_should_fail_initially(self, mock_websocket):
"""Test WebSocket message handling - should fail initially."""
# Mock message handler
async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket):
message_type = data.get("type")
if message_type == "ping":
await websocket.send_json({
"type": "pong",
"timestamp": datetime.utcnow().isoformat()
})
elif message_type == "update_config":
config = data.get("config", {})
await websocket.send_json({
"type": "config_updated",
"timestamp": datetime.utcnow().isoformat(),
"config": config
})
else:
await websocket.send_json({
"type": "error",
"message": f"Unknown message type: {message_type}"
})
# Test ping message
ping_data = {"type": "ping"}
await handle_websocket_message("client-001", ping_data, mock_websocket)
# This will fail initially
assert len(mock_websocket.messages_sent) == 1
assert mock_websocket.messages_sent[0]["type"] == "pong"
# Test config update
mock_websocket.messages_sent.clear()
config_data = {
"type": "update_config",
"config": {"min_confidence": 0.8, "max_fps": 15}
}
await handle_websocket_message("client-001", config_data, mock_websocket)
# This will fail initially
assert len(mock_websocket.messages_sent) == 1
assert mock_websocket.messages_sent[0]["type"] == "config_updated"
assert mock_websocket.messages_sent[0]["config"]["min_confidence"] == 0.8
@pytest.mark.asyncio
async def test_websocket_events_stream_should_fail_initially(self, mock_websocket, mock_connection_manager):
"""Test WebSocket events stream - should fail initially."""
# Mock events stream handler
async def mock_events_handler(websocket, event_types, zone_ids):
await websocket.accept()
# Parse parameters
event_list = [event.strip() for event in event_types.split(",") if event.strip()] if event_types else None
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()] if zone_ids else None
# Register client
client_id = await mock_connection_manager.connect(
websocket=websocket,
stream_type="events",
zone_ids=zone_list,
event_types=event_list
)
# Send confirmation
await websocket.send_json({
"type": "connection_established",
"client_id": client_id,
"timestamp": datetime.utcnow().isoformat(),
"config": {
"event_types": event_list,
"zone_ids": zone_list
}
})
return client_id
# Execute handler
client_id = await mock_events_handler(mock_websocket, "fall_detection,intrusion", "zone1")
# This will fail initially
assert mock_websocket.accept_called
assert len(mock_websocket.messages_sent) == 1
assert mock_websocket.messages_sent[0]["type"] == "connection_established"
assert mock_websocket.messages_sent[0]["config"]["event_types"] == ["fall_detection", "intrusion"]
@pytest.mark.asyncio
async def test_websocket_disconnect_handling_should_fail_initially(self, mock_websocket, mock_connection_manager):
"""Test WebSocket disconnect handling - should fail initially."""
# Mock disconnect scenario
client_id = "client-001"
# Simulate disconnect
disconnect_result = await mock_connection_manager.disconnect(client_id)
# This will fail initially
assert disconnect_result is True
mock_connection_manager.disconnect.assert_called_once_with(client_id)
class TestWebSocketConnectionManager:
"""Test WebSocket connection management."""
@pytest.fixture
def connection_manager(self):
"""Create connection manager for testing."""
# Mock connection manager implementation
class MockConnectionManager:
def __init__(self):
self.connections = {}
self.client_counter = 0
async def connect(self, websocket, stream_type, zone_ids=None, **kwargs):
self.client_counter += 1
client_id = f"client-{self.client_counter:03d}"
self.connections[client_id] = {
"websocket": websocket,
"stream_type": stream_type,
"zone_ids": zone_ids or [],
"connected_at": datetime.utcnow(),
**kwargs
}
return client_id
async def disconnect(self, client_id):
if client_id in self.connections:
del self.connections[client_id]
return True
return False
async def get_connected_clients(self):
return list(self.connections.keys())
async def get_connection_stats(self):
return {
"total_clients": len(self.connections),
"active_streams": list(set(conn["stream_type"] for conn in self.connections.values()))
}
async def broadcast(self, data, stream_type=None, zone_ids=None):
sent_count = 0
for client_id, conn in self.connections.items():
if stream_type and conn["stream_type"] != stream_type:
continue
if zone_ids and not any(zone in conn["zone_ids"] for zone in zone_ids):
continue
# Mock sending data
sent_count += 1
return sent_count
return MockConnectionManager()
@pytest.mark.asyncio
async def test_connection_manager_connect_should_fail_initially(self, connection_manager, mock_websocket):
"""Test connection manager connect functionality - should fail initially."""
client_id = await connection_manager.connect(
websocket=mock_websocket,
stream_type="pose",
zone_ids=["zone1", "zone2"],
min_confidence=0.7
)
# This will fail initially
assert client_id == "client-001"
assert client_id in connection_manager.connections
assert connection_manager.connections[client_id]["stream_type"] == "pose"
assert connection_manager.connections[client_id]["zone_ids"] == ["zone1", "zone2"]
@pytest.mark.asyncio
async def test_connection_manager_disconnect_should_fail_initially(self, connection_manager, mock_websocket):
"""Test connection manager disconnect functionality - should fail initially."""
# Connect first
client_id = await connection_manager.connect(
websocket=mock_websocket,
stream_type="pose"
)
# Disconnect
result = await connection_manager.disconnect(client_id)
# This will fail initially
assert result is True
assert client_id not in connection_manager.connections
@pytest.mark.asyncio
async def test_connection_manager_broadcast_should_fail_initially(self, connection_manager):
"""Test connection manager broadcast functionality - should fail initially."""
# Connect multiple clients
ws1 = MockWebSocket()
ws2 = MockWebSocket()
client1 = await connection_manager.connect(ws1, "pose", zone_ids=["zone1"])
client2 = await connection_manager.connect(ws2, "events", zone_ids=["zone2"])
# Broadcast to pose stream
sent_count = await connection_manager.broadcast(
data={"type": "pose_data", "data": {}},
stream_type="pose"
)
# This will fail initially
assert sent_count == 1
# Broadcast to specific zone
sent_count = await connection_manager.broadcast(
data={"type": "zone_event", "data": {}},
zone_ids=["zone1"]
)
# This will fail initially
assert sent_count == 1
class TestWebSocketPerformance:
"""Test WebSocket performance characteristics."""
@pytest.mark.asyncio
async def test_multiple_concurrent_connections_should_fail_initially(self):
"""Test handling multiple concurrent WebSocket connections - should fail initially."""
# Mock multiple connections
connection_count = 10
connections = []
for i in range(connection_count):
mock_ws = MockWebSocket()
connections.append(mock_ws)
# Simulate concurrent connections
async def simulate_connection(websocket, client_id):
await websocket.accept()
await websocket.send_json({
"type": "connection_established",
"client_id": client_id
})
return True
# Execute concurrent connections
tasks = [
simulate_connection(ws, f"client-{i:03d}")
for i, ws in enumerate(connections)
]
results = await asyncio.gather(*tasks)
# This will fail initially
assert len(results) == connection_count
assert all(results)
assert all(ws.accept_called for ws in connections)
@pytest.mark.asyncio
async def test_websocket_message_throughput_should_fail_initially(self):
"""Test WebSocket message throughput - should fail initially."""
mock_ws = MockWebSocket()
message_count = 100
# Simulate high-frequency message sending
start_time = datetime.utcnow()
for i in range(message_count):
await mock_ws.send_json({
"type": "pose_data",
"frame_id": f"frame-{i:04d}",
"timestamp": datetime.utcnow().isoformat()
})
end_time = datetime.utcnow()
duration = (end_time - start_time).total_seconds()
# This will fail initially
assert len(mock_ws.messages_sent) == message_count
assert duration < 1.0 # Should handle 100 messages in under 1 second
# Calculate throughput
throughput = message_count / duration if duration > 0 else float('inf')
assert throughput > 100 # Should handle at least 100 messages per second

View File

@@ -0,0 +1,712 @@
"""
Hardware simulation mocks for testing.
Provides realistic hardware behavior simulation for routers and sensors.
"""
import asyncio
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Callable, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock
import json
import random
from dataclasses import dataclass, field
from enum import Enum
class RouterStatus(Enum):
"""Router status enumeration."""
OFFLINE = "offline"
CONNECTING = "connecting"
ONLINE = "online"
ERROR = "error"
MAINTENANCE = "maintenance"
class SignalQuality(Enum):
"""Signal quality levels."""
POOR = "poor"
FAIR = "fair"
GOOD = "good"
EXCELLENT = "excellent"
@dataclass
class RouterConfig:
"""Router configuration."""
router_id: str
frequency: float = 5.8e9 # 5.8 GHz
bandwidth: float = 80e6 # 80 MHz
num_antennas: int = 4
num_subcarriers: int = 64
tx_power: float = 20.0 # dBm
location: Dict[str, float] = field(default_factory=lambda: {"x": 0, "y": 0, "z": 0})
firmware_version: str = "1.2.3"
class MockWiFiRouter:
"""Mock WiFi router with CSI capabilities."""
def __init__(self, config: RouterConfig):
self.config = config
self.status = RouterStatus.OFFLINE
self.signal_quality = SignalQuality.GOOD
self.is_streaming = False
self.connected_devices = []
self.csi_data_buffer = []
self.error_rate = 0.01 # 1% error rate
self.latency_ms = 5.0
self.throughput_mbps = 100.0
self.temperature_celsius = 45.0
self.uptime_seconds = 0
self.last_heartbeat = None
self.callbacks = {
"on_status_change": [],
"on_csi_data": [],
"on_error": []
}
self._streaming_task = None
self._heartbeat_task = None
async def connect(self) -> bool:
"""Connect to router."""
if self.status != RouterStatus.OFFLINE:
return False
self.status = RouterStatus.CONNECTING
await self._notify_status_change()
# Simulate connection delay
await asyncio.sleep(0.1)
# Simulate occasional connection failures
if random.random() < 0.05: # 5% failure rate
self.status = RouterStatus.ERROR
await self._notify_error("Connection failed")
return False
self.status = RouterStatus.ONLINE
self.last_heartbeat = datetime.utcnow()
await self._notify_status_change()
# Start heartbeat
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
return True
async def disconnect(self):
"""Disconnect from router."""
if self.status == RouterStatus.OFFLINE:
return
# Stop streaming if active
if self.is_streaming:
await self.stop_csi_streaming()
# Stop heartbeat
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
self.status = RouterStatus.OFFLINE
await self._notify_status_change()
async def start_csi_streaming(self, sample_rate: int = 1000) -> bool:
"""Start CSI data streaming."""
if self.status != RouterStatus.ONLINE:
return False
if self.is_streaming:
return False
self.is_streaming = True
self._streaming_task = asyncio.create_task(self._csi_streaming_loop(sample_rate))
return True
async def stop_csi_streaming(self):
"""Stop CSI data streaming."""
if not self.is_streaming:
return
self.is_streaming = False
if self._streaming_task:
self._streaming_task.cancel()
try:
await self._streaming_task
except asyncio.CancelledError:
pass
async def _csi_streaming_loop(self, sample_rate: int):
"""CSI data streaming loop."""
interval = 1.0 / sample_rate
try:
while self.is_streaming:
# Generate CSI data
csi_data = self._generate_csi_sample()
# Add to buffer
self.csi_data_buffer.append(csi_data)
# Keep buffer size manageable
if len(self.csi_data_buffer) > 1000:
self.csi_data_buffer = self.csi_data_buffer[-1000:]
# Notify callbacks
await self._notify_csi_data(csi_data)
# Simulate processing delay and jitter
actual_interval = interval * random.uniform(0.9, 1.1)
await asyncio.sleep(actual_interval)
except asyncio.CancelledError:
pass
async def _heartbeat_loop(self):
"""Heartbeat loop to maintain connection."""
try:
while self.status == RouterStatus.ONLINE:
self.last_heartbeat = datetime.utcnow()
self.uptime_seconds += 1
# Simulate temperature variations
self.temperature_celsius += random.uniform(-1, 1)
self.temperature_celsius = max(30, min(80, self.temperature_celsius))
# Check for overheating
if self.temperature_celsius > 75:
self.signal_quality = SignalQuality.POOR
await self._notify_error("High temperature warning")
await asyncio.sleep(1.0)
except asyncio.CancelledError:
pass
def _generate_csi_sample(self) -> Dict[str, Any]:
"""Generate realistic CSI sample."""
# Base amplitude and phase matrices
amplitude = np.random.uniform(0.2, 0.8, (self.config.num_antennas, self.config.num_subcarriers))
phase = np.random.uniform(-np.pi, np.pi, (self.config.num_antennas, self.config.num_subcarriers))
# Add signal quality effects
if self.signal_quality == SignalQuality.POOR:
noise_level = 0.3
elif self.signal_quality == SignalQuality.FAIR:
noise_level = 0.2
elif self.signal_quality == SignalQuality.GOOD:
noise_level = 0.1
else: # EXCELLENT
noise_level = 0.05
# Add noise
amplitude += np.random.normal(0, noise_level, amplitude.shape)
phase += np.random.normal(0, noise_level * np.pi, phase.shape)
# Clip values
amplitude = np.clip(amplitude, 0, 1)
phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi
# Simulate packet errors
if random.random() < self.error_rate:
# Corrupt some data
corruption_mask = np.random.random(amplitude.shape) < 0.1
amplitude[corruption_mask] = 0
phase[corruption_mask] = 0
return {
"timestamp": datetime.utcnow().isoformat(),
"router_id": self.config.router_id,
"amplitude": amplitude.tolist(),
"phase": phase.tolist(),
"frequency": self.config.frequency,
"bandwidth": self.config.bandwidth,
"num_antennas": self.config.num_antennas,
"num_subcarriers": self.config.num_subcarriers,
"signal_quality": self.signal_quality.value,
"temperature": self.temperature_celsius,
"tx_power": self.config.tx_power,
"sequence_number": len(self.csi_data_buffer)
}
def register_callback(self, event: str, callback: Callable):
"""Register event callback."""
if event in self.callbacks:
self.callbacks[event].append(callback)
def unregister_callback(self, event: str, callback: Callable):
"""Unregister event callback."""
if event in self.callbacks and callback in self.callbacks[event]:
self.callbacks[event].remove(callback)
async def _notify_status_change(self):
"""Notify status change callbacks."""
for callback in self.callbacks["on_status_change"]:
try:
if asyncio.iscoroutinefunction(callback):
await callback(self.status)
else:
callback(self.status)
except Exception:
pass # Ignore callback errors
async def _notify_csi_data(self, data: Dict[str, Any]):
"""Notify CSI data callbacks."""
for callback in self.callbacks["on_csi_data"]:
try:
if asyncio.iscoroutinefunction(callback):
await callback(data)
else:
callback(data)
except Exception:
pass
async def _notify_error(self, error_message: str):
"""Notify error callbacks."""
for callback in self.callbacks["on_error"]:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error_message)
else:
callback(error_message)
except Exception:
pass
def get_status(self) -> Dict[str, Any]:
"""Get router status information."""
return {
"router_id": self.config.router_id,
"status": self.status.value,
"signal_quality": self.signal_quality.value,
"is_streaming": self.is_streaming,
"connected_devices": len(self.connected_devices),
"temperature": self.temperature_celsius,
"uptime_seconds": self.uptime_seconds,
"last_heartbeat": self.last_heartbeat.isoformat() if self.last_heartbeat else None,
"error_rate": self.error_rate,
"latency_ms": self.latency_ms,
"throughput_mbps": self.throughput_mbps,
"firmware_version": self.config.firmware_version,
"location": self.config.location
}
def set_signal_quality(self, quality: SignalQuality):
"""Set signal quality for testing."""
self.signal_quality = quality
def set_error_rate(self, error_rate: float):
"""Set error rate for testing."""
self.error_rate = max(0, min(1, error_rate))
def simulate_interference(self, duration_seconds: float = 5.0):
"""Simulate interference for testing."""
async def interference_task():
original_quality = self.signal_quality
self.signal_quality = SignalQuality.POOR
await asyncio.sleep(duration_seconds)
self.signal_quality = original_quality
asyncio.create_task(interference_task())
def get_csi_buffer(self) -> List[Dict[str, Any]]:
"""Get CSI data buffer."""
return self.csi_data_buffer.copy()
def clear_csi_buffer(self):
"""Clear CSI data buffer."""
self.csi_data_buffer.clear()
class MockRouterNetwork:
"""Mock network of WiFi routers."""
def __init__(self):
self.routers = {}
self.network_topology = {}
self.interference_sources = []
self.global_callbacks = {
"on_router_added": [],
"on_router_removed": [],
"on_network_event": []
}
def add_router(self, config: RouterConfig) -> MockWiFiRouter:
"""Add router to network."""
if config.router_id in self.routers:
raise ValueError(f"Router {config.router_id} already exists")
router = MockWiFiRouter(config)
self.routers[config.router_id] = router
# Register for router events
router.register_callback("on_status_change", self._on_router_status_change)
router.register_callback("on_error", self._on_router_error)
# Notify callbacks
for callback in self.global_callbacks["on_router_added"]:
callback(router)
return router
def remove_router(self, router_id: str) -> bool:
"""Remove router from network."""
if router_id not in self.routers:
return False
router = self.routers[router_id]
# Disconnect if connected
if router.status != RouterStatus.OFFLINE:
asyncio.create_task(router.disconnect())
del self.routers[router_id]
# Notify callbacks
for callback in self.global_callbacks["on_router_removed"]:
callback(router_id)
return True
def get_router(self, router_id: str) -> Optional[MockWiFiRouter]:
"""Get router by ID."""
return self.routers.get(router_id)
def get_all_routers(self) -> Dict[str, MockWiFiRouter]:
"""Get all routers."""
return self.routers.copy()
async def connect_all_routers(self) -> Dict[str, bool]:
"""Connect all routers."""
results = {}
tasks = []
for router_id, router in self.routers.items():
task = asyncio.create_task(router.connect())
tasks.append((router_id, task))
for router_id, task in tasks:
try:
result = await task
results[router_id] = result
except Exception:
results[router_id] = False
return results
async def disconnect_all_routers(self):
"""Disconnect all routers."""
tasks = []
for router in self.routers.values():
if router.status != RouterStatus.OFFLINE:
task = asyncio.create_task(router.disconnect())
tasks.append(task)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def start_all_streaming(self, sample_rate: int = 1000) -> Dict[str, bool]:
"""Start CSI streaming on all routers."""
results = {}
for router_id, router in self.routers.items():
if router.status == RouterStatus.ONLINE:
result = await router.start_csi_streaming(sample_rate)
results[router_id] = result
else:
results[router_id] = False
return results
async def stop_all_streaming(self):
"""Stop CSI streaming on all routers."""
tasks = []
for router in self.routers.values():
if router.is_streaming:
task = asyncio.create_task(router.stop_csi_streaming())
tasks.append(task)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def get_network_status(self) -> Dict[str, Any]:
"""Get overall network status."""
total_routers = len(self.routers)
online_routers = sum(1 for r in self.routers.values() if r.status == RouterStatus.ONLINE)
streaming_routers = sum(1 for r in self.routers.values() if r.is_streaming)
return {
"total_routers": total_routers,
"online_routers": online_routers,
"streaming_routers": streaming_routers,
"network_health": online_routers / max(total_routers, 1),
"interference_sources": len(self.interference_sources),
"timestamp": datetime.utcnow().isoformat()
}
def simulate_network_partition(self, router_ids: List[str], duration_seconds: float = 10.0):
"""Simulate network partition for testing."""
async def partition_task():
# Disconnect specified routers
affected_routers = [self.routers[rid] for rid in router_ids if rid in self.routers]
for router in affected_routers:
if router.status == RouterStatus.ONLINE:
router.status = RouterStatus.ERROR
await router._notify_status_change()
await asyncio.sleep(duration_seconds)
# Reconnect routers
for router in affected_routers:
if router.status == RouterStatus.ERROR:
await router.connect()
asyncio.create_task(partition_task())
def add_interference_source(self, location: Dict[str, float], strength: float, frequency: float):
"""Add interference source."""
interference = {
"id": f"interference_{len(self.interference_sources)}",
"location": location,
"strength": strength,
"frequency": frequency,
"active": True
}
self.interference_sources.append(interference)
# Affect nearby routers
for router in self.routers.values():
distance = self._calculate_distance(router.config.location, location)
if distance < 50: # Within 50 meters
if strength > 0.5:
router.set_signal_quality(SignalQuality.POOR)
elif strength > 0.3:
router.set_signal_quality(SignalQuality.FAIR)
def _calculate_distance(self, loc1: Dict[str, float], loc2: Dict[str, float]) -> float:
"""Calculate distance between two locations."""
dx = loc1.get("x", 0) - loc2.get("x", 0)
dy = loc1.get("y", 0) - loc2.get("y", 0)
dz = loc1.get("z", 0) - loc2.get("z", 0)
return np.sqrt(dx**2 + dy**2 + dz**2)
async def _on_router_status_change(self, status: RouterStatus):
"""Handle router status change."""
for callback in self.global_callbacks["on_network_event"]:
await callback("router_status_change", {"status": status})
async def _on_router_error(self, error_message: str):
"""Handle router error."""
for callback in self.global_callbacks["on_network_event"]:
await callback("router_error", {"error": error_message})
def register_global_callback(self, event: str, callback: Callable):
"""Register global network callback."""
if event in self.global_callbacks:
self.global_callbacks[event].append(callback)
class MockSensorArray:
"""Mock sensor array for environmental monitoring."""
def __init__(self, sensor_id: str, location: Dict[str, float]):
self.sensor_id = sensor_id
self.location = location
self.is_active = False
self.sensors = {
"temperature": {"value": 22.0, "unit": "celsius", "range": (15, 35)},
"humidity": {"value": 45.0, "unit": "percent", "range": (30, 70)},
"pressure": {"value": 1013.25, "unit": "hPa", "range": (980, 1050)},
"light": {"value": 300.0, "unit": "lux", "range": (0, 1000)},
"motion": {"value": False, "unit": "boolean", "range": (False, True)},
"sound": {"value": 35.0, "unit": "dB", "range": (20, 80)}
}
self.reading_history = []
self.callbacks = []
async def start_monitoring(self, interval_seconds: float = 1.0):
"""Start sensor monitoring."""
if self.is_active:
return False
self.is_active = True
asyncio.create_task(self._monitoring_loop(interval_seconds))
return True
def stop_monitoring(self):
"""Stop sensor monitoring."""
self.is_active = False
async def _monitoring_loop(self, interval: float):
"""Sensor monitoring loop."""
try:
while self.is_active:
reading = self._generate_sensor_reading()
self.reading_history.append(reading)
# Keep history manageable
if len(self.reading_history) > 1000:
self.reading_history = self.reading_history[-1000:]
# Notify callbacks
for callback in self.callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(reading)
else:
callback(reading)
except Exception:
pass
await asyncio.sleep(interval)
except asyncio.CancelledError:
pass
def _generate_sensor_reading(self) -> Dict[str, Any]:
"""Generate realistic sensor reading."""
reading = {
"sensor_id": self.sensor_id,
"timestamp": datetime.utcnow().isoformat(),
"location": self.location,
"readings": {}
}
for sensor_name, config in self.sensors.items():
if sensor_name == "motion":
# Motion detection with some randomness
reading["readings"][sensor_name] = random.random() < 0.1 # 10% chance of motion
else:
# Continuous sensors with drift
current_value = config["value"]
min_val, max_val = config["range"]
# Add small random drift
drift = random.uniform(-0.1, 0.1) * (max_val - min_val)
new_value = current_value + drift
# Keep within range
new_value = max(min_val, min(max_val, new_value))
config["value"] = new_value
reading["readings"][sensor_name] = {
"value": round(new_value, 2),
"unit": config["unit"]
}
return reading
def register_callback(self, callback: Callable):
"""Register sensor callback."""
self.callbacks.append(callback)
def unregister_callback(self, callback: Callable):
"""Unregister sensor callback."""
if callback in self.callbacks:
self.callbacks.remove(callback)
def get_latest_reading(self) -> Optional[Dict[str, Any]]:
"""Get latest sensor reading."""
return self.reading_history[-1] if self.reading_history else None
def get_reading_history(self, limit: int = 100) -> List[Dict[str, Any]]:
"""Get sensor reading history."""
return self.reading_history[-limit:]
def simulate_event(self, event_type: str, duration_seconds: float = 5.0):
"""Simulate environmental event."""
async def event_task():
if event_type == "motion_detected":
self.sensors["motion"]["value"] = True
await asyncio.sleep(duration_seconds)
self.sensors["motion"]["value"] = False
elif event_type == "temperature_spike":
original_temp = self.sensors["temperature"]["value"]
self.sensors["temperature"]["value"] = min(35, original_temp + 10)
await asyncio.sleep(duration_seconds)
self.sensors["temperature"]["value"] = original_temp
elif event_type == "loud_noise":
original_sound = self.sensors["sound"]["value"]
self.sensors["sound"]["value"] = min(80, original_sound + 20)
await asyncio.sleep(duration_seconds)
self.sensors["sound"]["value"] = original_sound
asyncio.create_task(event_task())
# Utility functions for creating test hardware setups
def create_test_router_network(num_routers: int = 3) -> MockRouterNetwork:
"""Create test router network."""
network = MockRouterNetwork()
for i in range(num_routers):
config = RouterConfig(
router_id=f"router_{i:03d}",
location={"x": i * 10, "y": 0, "z": 2.5}
)
network.add_router(config)
return network
def create_test_sensor_array(num_sensors: int = 2) -> List[MockSensorArray]:
"""Create test sensor array."""
sensors = []
for i in range(num_sensors):
sensor = MockSensorArray(
sensor_id=f"sensor_{i:03d}",
location={"x": i * 5, "y": 5, "z": 1.0}
)
sensors.append(sensor)
return sensors
async def setup_test_hardware_environment() -> Dict[str, Any]:
"""Setup complete test hardware environment."""
# Create router network
router_network = create_test_router_network(3)
# Create sensor arrays
sensor_arrays = create_test_sensor_array(2)
# Connect all routers
router_results = await router_network.connect_all_routers()
# Start sensor monitoring
sensor_tasks = []
for sensor in sensor_arrays:
task = asyncio.create_task(sensor.start_monitoring(1.0))
sensor_tasks.append(task)
sensor_results = await asyncio.gather(*sensor_tasks)
return {
"router_network": router_network,
"sensor_arrays": sensor_arrays,
"router_connection_results": router_results,
"sensor_start_results": sensor_results,
"setup_timestamp": datetime.utcnow().isoformat()
}
async def teardown_test_hardware_environment(environment: Dict[str, Any]):
"""Teardown test hardware environment."""
# Stop sensor monitoring
for sensor in environment["sensor_arrays"]:
sensor.stop_monitoring()
# Disconnect all routers
await environment["router_network"].disconnect_all_routers()

View File

@@ -0,0 +1,649 @@
"""
Performance tests for API throughput and load testing.
Tests API endpoint performance under various load conditions.
"""
import pytest
import asyncio
import aiohttp
import time
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import json
import statistics
class MockAPIServer:
"""Mock API server for load testing."""
def __init__(self):
self.request_count = 0
self.response_times = []
self.error_count = 0
self.concurrent_requests = 0
self.max_concurrent = 0
self.is_running = False
self.rate_limit_enabled = False
self.rate_limit_per_second = 100
self.request_timestamps = []
async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
"""Handle API request."""
start_time = time.time()
self.concurrent_requests += 1
self.max_concurrent = max(self.max_concurrent, self.concurrent_requests)
self.request_count += 1
self.request_timestamps.append(start_time)
try:
# Check rate limiting
if self.rate_limit_enabled:
recent_requests = [
ts for ts in self.request_timestamps
if start_time - ts <= 1.0
]
if len(recent_requests) > self.rate_limit_per_second:
self.error_count += 1
return {
"status": 429,
"error": "Rate limit exceeded",
"response_time_ms": 1.0
}
# Simulate processing time based on endpoint
processing_time = self._get_processing_time(endpoint, method)
await asyncio.sleep(processing_time)
# Generate response
response = self._generate_response(endpoint, method, data)
end_time = time.time()
response_time = (end_time - start_time) * 1000
self.response_times.append(response_time)
return {
"status": 200,
"data": response,
"response_time_ms": response_time
}
except Exception as e:
self.error_count += 1
return {
"status": 500,
"error": str(e),
"response_time_ms": (time.time() - start_time) * 1000
}
finally:
self.concurrent_requests -= 1
def _get_processing_time(self, endpoint: str, method: str) -> float:
"""Get processing time for endpoint."""
processing_times = {
"/health": 0.001,
"/pose/detect": 0.05,
"/pose/stream": 0.02,
"/auth/login": 0.01,
"/auth/refresh": 0.005,
"/config": 0.003
}
base_time = processing_times.get(endpoint, 0.01)
# Add some variance
return base_time * np.random.uniform(0.8, 1.2)
def _generate_response(self, endpoint: str, method: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""Generate response for endpoint."""
if endpoint == "/health":
return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
elif endpoint == "/pose/detect":
return {
"persons": [
{
"person_id": "person_1",
"confidence": 0.85,
"bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
"keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
}
],
"processing_time_ms": 45.2,
"model_version": "v1.0"
}
elif endpoint == "/auth/login":
return {
"access_token": "mock_access_token",
"refresh_token": "mock_refresh_token",
"expires_in": 3600
}
else:
return {"message": "Success", "endpoint": endpoint, "method": method}
def get_performance_stats(self) -> Dict[str, Any]:
"""Get performance statistics."""
if not self.response_times:
return {
"total_requests": self.request_count,
"error_count": self.error_count,
"error_rate": 0,
"avg_response_time_ms": 0,
"median_response_time_ms": 0,
"p95_response_time_ms": 0,
"p99_response_time_ms": 0,
"max_concurrent_requests": self.max_concurrent,
"requests_per_second": 0
}
return {
"total_requests": self.request_count,
"error_count": self.error_count,
"error_rate": self.error_count / self.request_count,
"avg_response_time_ms": statistics.mean(self.response_times),
"median_response_time_ms": statistics.median(self.response_times),
"p95_response_time_ms": np.percentile(self.response_times, 95),
"p99_response_time_ms": np.percentile(self.response_times, 99),
"max_concurrent_requests": self.max_concurrent,
"requests_per_second": self._calculate_rps()
}
def _calculate_rps(self) -> float:
"""Calculate requests per second."""
if len(self.request_timestamps) < 2:
return 0
duration = self.request_timestamps[-1] - self.request_timestamps[0]
return len(self.request_timestamps) / max(duration, 0.001)
def enable_rate_limiting(self, requests_per_second: int):
"""Enable rate limiting."""
self.rate_limit_enabled = True
self.rate_limit_per_second = requests_per_second
def reset_stats(self):
"""Reset performance statistics."""
self.request_count = 0
self.response_times = []
self.error_count = 0
self.concurrent_requests = 0
self.max_concurrent = 0
self.request_timestamps = []
class TestAPIThroughput:
"""Test API throughput under various conditions."""
@pytest.fixture
def api_server(self):
"""Create mock API server."""
return MockAPIServer()
@pytest.mark.asyncio
async def test_single_request_performance_should_fail_initially(self, api_server):
"""Test single request performance - should fail initially."""
start_time = time.time()
response = await api_server.handle_request("/health")
end_time = time.time()
response_time = (end_time - start_time) * 1000
# This will fail initially
assert response["status"] == 200
assert response_time < 50 # Should respond within 50ms
assert response["response_time_ms"] > 0
stats = api_server.get_performance_stats()
assert stats["total_requests"] == 1
assert stats["error_count"] == 0
@pytest.mark.asyncio
async def test_concurrent_request_handling_should_fail_initially(self, api_server):
"""Test concurrent request handling - should fail initially."""
# Send multiple concurrent requests
concurrent_requests = 10
tasks = []
for i in range(concurrent_requests):
task = asyncio.create_task(api_server.handle_request("/health"))
tasks.append(task)
start_time = time.time()
responses = await asyncio.gather(*tasks)
end_time = time.time()
total_time = (end_time - start_time) * 1000
# This will fail initially
assert len(responses) == concurrent_requests
assert all(r["status"] == 200 for r in responses)
# All requests should complete within reasonable time
assert total_time < 200 # Should complete within 200ms
stats = api_server.get_performance_stats()
assert stats["total_requests"] == concurrent_requests
assert stats["max_concurrent_requests"] <= concurrent_requests
@pytest.mark.asyncio
async def test_sustained_load_performance_should_fail_initially(self, api_server):
"""Test sustained load performance - should fail initially."""
duration_seconds = 3
target_rps = 50 # 50 requests per second
async def send_requests():
"""Send requests at target rate."""
interval = 1.0 / target_rps
end_time = time.time() + duration_seconds
while time.time() < end_time:
await api_server.handle_request("/health")
await asyncio.sleep(interval)
start_time = time.time()
await send_requests()
actual_duration = time.time() - start_time
stats = api_server.get_performance_stats()
actual_rps = stats["requests_per_second"]
# This will fail initially
assert actual_rps >= target_rps * 0.8 # Within 80% of target
assert stats["error_rate"] < 0.05 # Less than 5% error rate
assert stats["avg_response_time_ms"] < 100 # Average response time under 100ms
@pytest.mark.asyncio
async def test_different_endpoint_performance_should_fail_initially(self, api_server):
"""Test different endpoint performance - should fail initially."""
endpoints = [
"/health",
"/pose/detect",
"/auth/login",
"/config"
]
results = {}
for endpoint in endpoints:
# Test each endpoint multiple times
response_times = []
for _ in range(10):
response = await api_server.handle_request(endpoint)
response_times.append(response["response_time_ms"])
results[endpoint] = {
"avg_response_time": statistics.mean(response_times),
"min_response_time": min(response_times),
"max_response_time": max(response_times)
}
# This will fail initially
# Health endpoint should be fastest
assert results["/health"]["avg_response_time"] < results["/pose/detect"]["avg_response_time"]
# All endpoints should respond within reasonable time
for endpoint, metrics in results.items():
assert metrics["avg_response_time"] < 200 # Less than 200ms average
assert metrics["max_response_time"] < 500 # Less than 500ms max
@pytest.mark.asyncio
async def test_rate_limiting_behavior_should_fail_initially(self, api_server):
"""Test rate limiting behavior - should fail initially."""
# Enable rate limiting
api_server.enable_rate_limiting(requests_per_second=10)
# Send requests faster than rate limit
rapid_requests = 20
tasks = []
for i in range(rapid_requests):
task = asyncio.create_task(api_server.handle_request("/health"))
tasks.append(task)
responses = await asyncio.gather(*tasks)
# This will fail initially
# Some requests should be rate limited
success_responses = [r for r in responses if r["status"] == 200]
rate_limited_responses = [r for r in responses if r["status"] == 429]
assert len(success_responses) > 0
assert len(rate_limited_responses) > 0
assert len(success_responses) + len(rate_limited_responses) == rapid_requests
stats = api_server.get_performance_stats()
assert stats["error_count"] > 0 # Should have rate limit errors
class TestAPILoadTesting:
"""Test API under heavy load conditions."""
@pytest.fixture
def load_test_server(self):
"""Create server for load testing."""
server = MockAPIServer()
return server
@pytest.mark.asyncio
async def test_high_concurrency_load_should_fail_initially(self, load_test_server):
"""Test high concurrency load - should fail initially."""
concurrent_users = 50
requests_per_user = 5
async def user_session(user_id: int):
"""Simulate user session."""
session_responses = []
for i in range(requests_per_user):
response = await load_test_server.handle_request("/health")
session_responses.append(response)
# Small delay between requests
await asyncio.sleep(0.01)
return session_responses
# Create user sessions
user_tasks = [user_session(i) for i in range(concurrent_users)]
start_time = time.time()
all_sessions = await asyncio.gather(*user_tasks)
end_time = time.time()
total_duration = end_time - start_time
total_requests = concurrent_users * requests_per_user
# This will fail initially
# All sessions should complete
assert len(all_sessions) == concurrent_users
# Check performance metrics
stats = load_test_server.get_performance_stats()
assert stats["total_requests"] == total_requests
assert stats["error_rate"] < 0.1 # Less than 10% error rate
assert stats["requests_per_second"] > 100 # Should handle at least 100 RPS
@pytest.mark.asyncio
async def test_mixed_endpoint_load_should_fail_initially(self, load_test_server):
"""Test mixed endpoint load - should fail initially."""
# Define endpoint mix (realistic usage pattern)
endpoint_mix = [
("/health", 0.4), # 40% health checks
("/pose/detect", 0.3), # 30% pose detection
("/auth/login", 0.1), # 10% authentication
("/config", 0.2) # 20% configuration
]
total_requests = 100
async def send_mixed_requests():
"""Send requests with mixed endpoints."""
tasks = []
for i in range(total_requests):
# Select endpoint based on distribution
rand = np.random.random()
cumulative = 0
for endpoint, probability in endpoint_mix:
cumulative += probability
if rand <= cumulative:
task = asyncio.create_task(
load_test_server.handle_request(endpoint)
)
tasks.append(task)
break
return await asyncio.gather(*tasks)
start_time = time.time()
responses = await send_mixed_requests()
end_time = time.time()
duration = end_time - start_time
# This will fail initially
assert len(responses) == total_requests
# Check response distribution
success_responses = [r for r in responses if r["status"] == 200]
assert len(success_responses) >= total_requests * 0.9 # At least 90% success
stats = load_test_server.get_performance_stats()
assert stats["requests_per_second"] > 50 # Should handle at least 50 RPS
assert stats["avg_response_time_ms"] < 150 # Average response time under 150ms
@pytest.mark.asyncio
async def test_stress_testing_should_fail_initially(self, load_test_server):
"""Test stress testing - should fail initially."""
# Gradually increase load to find breaking point
load_levels = [10, 25, 50, 100, 200]
results = {}
for concurrent_requests in load_levels:
load_test_server.reset_stats()
# Send concurrent requests
tasks = [
load_test_server.handle_request("/health")
for _ in range(concurrent_requests)
]
start_time = time.time()
responses = await asyncio.gather(*tasks)
end_time = time.time()
duration = end_time - start_time
stats = load_test_server.get_performance_stats()
results[concurrent_requests] = {
"duration": duration,
"rps": stats["requests_per_second"],
"error_rate": stats["error_rate"],
"avg_response_time": stats["avg_response_time_ms"],
"p95_response_time": stats["p95_response_time_ms"]
}
# This will fail initially
# Performance should degrade gracefully with increased load
for load_level, metrics in results.items():
assert metrics["error_rate"] < 0.2 # Less than 20% error rate
assert metrics["avg_response_time"] < 1000 # Less than 1 second average
# Higher loads should have higher response times
assert results[10]["avg_response_time"] <= results[200]["avg_response_time"]
@pytest.mark.asyncio
async def test_memory_usage_under_load_should_fail_initially(self, load_test_server):
"""Test memory usage under load - should fail initially."""
import psutil
import os
process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss
# Generate sustained load
duration_seconds = 5
target_rps = 100
async def sustained_load():
"""Generate sustained load."""
interval = 1.0 / target_rps
end_time = time.time() + duration_seconds
while time.time() < end_time:
await load_test_server.handle_request("/pose/detect")
await asyncio.sleep(interval)
await sustained_load()
final_memory = process.memory_info().rss
memory_increase = final_memory - initial_memory
# This will fail initially
# Memory increase should be reasonable (less than 100MB)
assert memory_increase < 100 * 1024 * 1024
stats = load_test_server.get_performance_stats()
assert stats["total_requests"] > duration_seconds * target_rps * 0.8
class TestAPIPerformanceOptimization:
"""Test API performance optimization techniques."""
@pytest.mark.asyncio
async def test_response_caching_effect_should_fail_initially(self):
"""Test response caching effect - should fail initially."""
class CachedAPIServer(MockAPIServer):
def __init__(self):
super().__init__()
self.cache = {}
self.cache_hits = 0
self.cache_misses = 0
async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
cache_key = f"{method}:{endpoint}"
if cache_key in self.cache:
self.cache_hits += 1
cached_response = self.cache[cache_key].copy()
cached_response["response_time_ms"] = 1.0 # Cached responses are fast
return cached_response
self.cache_misses += 1
response = await super().handle_request(endpoint, method, data)
# Cache successful responses
if response["status"] == 200:
self.cache[cache_key] = response.copy()
return response
cached_server = CachedAPIServer()
# First request (cache miss)
response1 = await cached_server.handle_request("/health")
# Second request (cache hit)
response2 = await cached_server.handle_request("/health")
# This will fail initially
assert response1["status"] == 200
assert response2["status"] == 200
assert response2["response_time_ms"] < response1["response_time_ms"]
assert cached_server.cache_hits == 1
assert cached_server.cache_misses == 1
@pytest.mark.asyncio
async def test_connection_pooling_effect_should_fail_initially(self):
"""Test connection pooling effect - should fail initially."""
# Simulate connection overhead
class ConnectionPoolServer(MockAPIServer):
def __init__(self, pool_size: int = 10):
super().__init__()
self.pool_size = pool_size
self.active_connections = 0
self.connection_overhead = 0.01 # 10ms connection overhead
async def handle_request(self, endpoint: str, method: str = "GET", data: Dict[str, Any] = None) -> Dict[str, Any]:
# Simulate connection acquisition
if self.active_connections < self.pool_size:
# New connection needed
await asyncio.sleep(self.connection_overhead)
self.active_connections += 1
try:
return await super().handle_request(endpoint, method, data)
finally:
# Connection returned to pool (not closed)
pass
pooled_server = ConnectionPoolServer(pool_size=5)
# Send requests that exceed pool size
concurrent_requests = 10
tasks = [
pooled_server.handle_request("/health")
for _ in range(concurrent_requests)
]
start_time = time.time()
responses = await asyncio.gather(*tasks)
end_time = time.time()
total_time = (end_time - start_time) * 1000
# This will fail initially
assert len(responses) == concurrent_requests
assert all(r["status"] == 200 for r in responses)
# With connection pooling, should complete reasonably fast
assert total_time < 500 # Should complete within 500ms
@pytest.mark.asyncio
async def test_request_batching_performance_should_fail_initially(self):
"""Test request batching performance - should fail initially."""
class BatchingServer(MockAPIServer):
def __init__(self):
super().__init__()
self.batch_size = 5
self.pending_requests = []
self.batch_processing = False
async def handle_batch_request(self, requests: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Handle batch of requests."""
# Batch processing is more efficient
batch_overhead = 0.01 # 10ms overhead for entire batch
await asyncio.sleep(batch_overhead)
responses = []
for req in requests:
# Individual processing is faster in batch
processing_time = self._get_processing_time(req["endpoint"], req["method"]) * 0.5
await asyncio.sleep(processing_time)
response = self._generate_response(req["endpoint"], req["method"], req.get("data"))
responses.append({
"status": 200,
"data": response,
"response_time_ms": processing_time * 1000
})
return responses
batching_server = BatchingServer()
# Test individual requests vs batch
individual_requests = 5
# Individual requests
start_time = time.time()
individual_tasks = [
batching_server.handle_request("/health")
for _ in range(individual_requests)
]
individual_responses = await asyncio.gather(*individual_tasks)
individual_time = (time.time() - start_time) * 1000
# Batch request
batch_requests = [
{"endpoint": "/health", "method": "GET"}
for _ in range(individual_requests)
]
start_time = time.time()
batch_responses = await batching_server.handle_batch_request(batch_requests)
batch_time = (time.time() - start_time) * 1000
# This will fail initially
assert len(individual_responses) == individual_requests
assert len(batch_responses) == individual_requests
# Batch should be more efficient
assert batch_time < individual_time
assert all(r["status"] == 200 for r in batch_responses)

View File

@@ -0,0 +1,507 @@
"""
Performance tests for ML model inference speed.
Tests pose estimation model performance, throughput, and optimization.
"""
import pytest
import asyncio
import numpy as np
import time
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import psutil
import os
class MockPoseModel:
"""Mock pose estimation model for performance testing."""
def __init__(self, model_complexity: str = "standard"):
self.model_complexity = model_complexity
self.is_loaded = False
self.inference_count = 0
self.total_inference_time = 0.0
self.batch_size = 1
# Model complexity affects inference time
self.base_inference_time = {
"lightweight": 0.02, # 20ms
"standard": 0.05, # 50ms
"high_accuracy": 0.15 # 150ms
}.get(model_complexity, 0.05)
async def load_model(self):
"""Load the model."""
# Simulate model loading time
load_time = {
"lightweight": 0.5,
"standard": 2.0,
"high_accuracy": 5.0
}.get(self.model_complexity, 2.0)
await asyncio.sleep(load_time)
self.is_loaded = True
async def predict(self, features: np.ndarray) -> Dict[str, Any]:
"""Run inference on features."""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
start_time = time.time()
# Simulate inference computation
batch_size = features.shape[0] if len(features.shape) > 2 else 1
inference_time = self.base_inference_time * batch_size
# Add some variance
inference_time *= np.random.uniform(0.8, 1.2)
await asyncio.sleep(inference_time)
end_time = time.time()
actual_inference_time = end_time - start_time
self.inference_count += batch_size
self.total_inference_time += actual_inference_time
# Generate mock predictions
predictions = []
for i in range(batch_size):
predictions.append({
"person_id": f"person_{i}",
"confidence": np.random.uniform(0.5, 0.95),
"keypoints": np.random.rand(17, 3).tolist(), # 17 keypoints with x,y,confidence
"bounding_box": {
"x": np.random.uniform(0, 640),
"y": np.random.uniform(0, 480),
"width": np.random.uniform(50, 200),
"height": np.random.uniform(100, 300)
}
})
return {
"predictions": predictions,
"inference_time_ms": actual_inference_time * 1000,
"model_complexity": self.model_complexity,
"batch_size": batch_size
}
def get_performance_stats(self) -> Dict[str, Any]:
"""Get performance statistics."""
avg_inference_time = (
self.total_inference_time / self.inference_count
if self.inference_count > 0 else 0
)
return {
"total_inferences": self.inference_count,
"total_time_seconds": self.total_inference_time,
"average_inference_time_ms": avg_inference_time * 1000,
"throughput_fps": 1.0 / avg_inference_time if avg_inference_time > 0 else 0,
"model_complexity": self.model_complexity
}
class TestInferenceSpeed:
"""Test inference speed for different model configurations."""
@pytest.fixture
def lightweight_model(self):
"""Create lightweight model."""
return MockPoseModel("lightweight")
@pytest.fixture
def standard_model(self):
"""Create standard model."""
return MockPoseModel("standard")
@pytest.fixture
def high_accuracy_model(self):
"""Create high accuracy model."""
return MockPoseModel("high_accuracy")
@pytest.fixture
def sample_features(self):
"""Create sample feature data."""
return np.random.rand(64, 32) # 64x32 feature matrix
@pytest.mark.asyncio
async def test_single_inference_speed_should_fail_initially(self, standard_model, sample_features):
"""Test single inference speed - should fail initially."""
await standard_model.load_model()
start_time = time.time()
result = await standard_model.predict(sample_features)
end_time = time.time()
inference_time = (end_time - start_time) * 1000 # Convert to ms
# This will fail initially
assert inference_time < 100 # Should be less than 100ms
assert result["inference_time_ms"] > 0
assert len(result["predictions"]) > 0
assert result["model_complexity"] == "standard"
@pytest.mark.asyncio
async def test_model_complexity_comparison_should_fail_initially(self, sample_features):
"""Test model complexity comparison - should fail initially."""
models = {
"lightweight": MockPoseModel("lightweight"),
"standard": MockPoseModel("standard"),
"high_accuracy": MockPoseModel("high_accuracy")
}
# Load all models
for model in models.values():
await model.load_model()
# Run inference on each model
results = {}
for name, model in models.items():
start_time = time.time()
result = await model.predict(sample_features)
end_time = time.time()
results[name] = {
"inference_time_ms": (end_time - start_time) * 1000,
"result": result
}
# This will fail initially
# Lightweight should be fastest
assert results["lightweight"]["inference_time_ms"] < results["standard"]["inference_time_ms"]
assert results["standard"]["inference_time_ms"] < results["high_accuracy"]["inference_time_ms"]
# All should complete within reasonable time
for name, result in results.items():
assert result["inference_time_ms"] < 500 # Less than 500ms
@pytest.mark.asyncio
async def test_batch_inference_performance_should_fail_initially(self, standard_model):
"""Test batch inference performance - should fail initially."""
await standard_model.load_model()
# Test different batch sizes
batch_sizes = [1, 4, 8, 16]
results = {}
for batch_size in batch_sizes:
# Create batch of features
batch_features = np.random.rand(batch_size, 64, 32)
start_time = time.time()
result = await standard_model.predict(batch_features)
end_time = time.time()
total_time = (end_time - start_time) * 1000
per_sample_time = total_time / batch_size
results[batch_size] = {
"total_time_ms": total_time,
"per_sample_time_ms": per_sample_time,
"throughput_fps": 1000 / per_sample_time,
"predictions": len(result["predictions"])
}
# This will fail initially
# Batch processing should be more efficient per sample
assert results[1]["per_sample_time_ms"] > results[4]["per_sample_time_ms"]
assert results[4]["per_sample_time_ms"] > results[8]["per_sample_time_ms"]
# Verify correct number of predictions
for batch_size, result in results.items():
assert result["predictions"] == batch_size
@pytest.mark.asyncio
async def test_sustained_inference_performance_should_fail_initially(self, standard_model, sample_features):
"""Test sustained inference performance - should fail initially."""
await standard_model.load_model()
# Run many inferences to test sustained performance
num_inferences = 50
inference_times = []
for i in range(num_inferences):
start_time = time.time()
await standard_model.predict(sample_features)
end_time = time.time()
inference_times.append((end_time - start_time) * 1000)
# This will fail initially
# Calculate performance metrics
avg_time = np.mean(inference_times)
std_time = np.std(inference_times)
min_time = np.min(inference_times)
max_time = np.max(inference_times)
assert avg_time < 100 # Average should be less than 100ms
assert std_time < 20 # Standard deviation should be low (consistent performance)
assert max_time < avg_time * 2 # No inference should take more than 2x average
# Check model statistics
stats = standard_model.get_performance_stats()
assert stats["total_inferences"] == num_inferences
assert stats["throughput_fps"] > 10 # Should achieve at least 10 FPS
class TestInferenceOptimization:
"""Test inference optimization techniques."""
@pytest.mark.asyncio
async def test_model_warmup_effect_should_fail_initially(self, standard_model, sample_features):
"""Test model warmup effect - should fail initially."""
await standard_model.load_model()
# First inference (cold start)
start_time = time.time()
await standard_model.predict(sample_features)
cold_start_time = (time.time() - start_time) * 1000
# Subsequent inferences (warmed up)
warm_times = []
for _ in range(5):
start_time = time.time()
await standard_model.predict(sample_features)
warm_times.append((time.time() - start_time) * 1000)
avg_warm_time = np.mean(warm_times)
# This will fail initially
# Warm inferences should be faster than cold start
assert avg_warm_time <= cold_start_time
assert cold_start_time > 0
assert avg_warm_time > 0
@pytest.mark.asyncio
async def test_concurrent_inference_performance_should_fail_initially(self, sample_features):
"""Test concurrent inference performance - should fail initially."""
# Create multiple model instances
models = [MockPoseModel("standard") for _ in range(3)]
# Load all models
for model in models:
await model.load_model()
async def run_inference(model, features):
start_time = time.time()
result = await model.predict(features)
end_time = time.time()
return (end_time - start_time) * 1000
# Run concurrent inferences
tasks = [run_inference(model, sample_features) for model in models]
inference_times = await asyncio.gather(*tasks)
# This will fail initially
# All inferences should complete
assert len(inference_times) == 3
assert all(time > 0 for time in inference_times)
# Concurrent execution shouldn't be much slower than sequential
avg_concurrent_time = np.mean(inference_times)
assert avg_concurrent_time < 200 # Should complete within 200ms each
@pytest.mark.asyncio
async def test_memory_usage_during_inference_should_fail_initially(self, standard_model, sample_features):
"""Test memory usage during inference - should fail initially."""
process = psutil.Process(os.getpid())
await standard_model.load_model()
initial_memory = process.memory_info().rss
# Run multiple inferences
for i in range(20):
await standard_model.predict(sample_features)
# Check memory every 5 inferences
if i % 5 == 0:
current_memory = process.memory_info().rss
memory_increase = current_memory - initial_memory
# This will fail initially
# Memory increase should be reasonable (less than 50MB)
assert memory_increase < 50 * 1024 * 1024
final_memory = process.memory_info().rss
total_increase = final_memory - initial_memory
# Total memory increase should be reasonable
assert total_increase < 100 * 1024 * 1024 # Less than 100MB
class TestInferenceAccuracy:
"""Test inference accuracy and quality metrics."""
@pytest.mark.asyncio
async def test_prediction_consistency_should_fail_initially(self, standard_model, sample_features):
"""Test prediction consistency - should fail initially."""
await standard_model.load_model()
# Run same inference multiple times
results = []
for _ in range(5):
result = await standard_model.predict(sample_features)
results.append(result)
# This will fail initially
# All results should have similar structure
for result in results:
assert "predictions" in result
assert "inference_time_ms" in result
assert len(result["predictions"]) > 0
# Inference times should be consistent
inference_times = [r["inference_time_ms"] for r in results]
avg_time = np.mean(inference_times)
std_time = np.std(inference_times)
assert std_time < avg_time * 0.5 # Standard deviation should be less than 50% of mean
@pytest.mark.asyncio
async def test_confidence_score_distribution_should_fail_initially(self, standard_model, sample_features):
"""Test confidence score distribution - should fail initially."""
await standard_model.load_model()
# Collect confidence scores from multiple inferences
all_confidences = []
for _ in range(20):
result = await standard_model.predict(sample_features)
for prediction in result["predictions"]:
all_confidences.append(prediction["confidence"])
# This will fail initially
if all_confidences: # Only test if we have predictions
# Confidence scores should be in valid range
assert all(0.0 <= conf <= 1.0 for conf in all_confidences)
# Should have reasonable distribution
avg_confidence = np.mean(all_confidences)
assert 0.3 <= avg_confidence <= 0.95 # Reasonable average confidence
@pytest.mark.asyncio
async def test_keypoint_detection_quality_should_fail_initially(self, standard_model, sample_features):
"""Test keypoint detection quality - should fail initially."""
await standard_model.load_model()
result = await standard_model.predict(sample_features)
# This will fail initially
for prediction in result["predictions"]:
keypoints = prediction["keypoints"]
# Should have correct number of keypoints
assert len(keypoints) == 17 # Standard pose has 17 keypoints
# Each keypoint should have x, y, confidence
for keypoint in keypoints:
assert len(keypoint) == 3
x, y, conf = keypoint
assert isinstance(x, (int, float))
assert isinstance(y, (int, float))
assert 0.0 <= conf <= 1.0
class TestInferenceScaling:
"""Test inference scaling characteristics."""
@pytest.mark.asyncio
async def test_input_size_scaling_should_fail_initially(self, standard_model):
"""Test inference scaling with input size - should fail initially."""
await standard_model.load_model()
# Test different input sizes
input_sizes = [(32, 16), (64, 32), (128, 64), (256, 128)]
results = {}
for height, width in input_sizes:
features = np.random.rand(height, width)
start_time = time.time()
result = await standard_model.predict(features)
end_time = time.time()
inference_time = (end_time - start_time) * 1000
input_size = height * width
results[input_size] = {
"inference_time_ms": inference_time,
"dimensions": (height, width),
"predictions": len(result["predictions"])
}
# This will fail initially
# Larger inputs should generally take longer
sizes = sorted(results.keys())
for i in range(len(sizes) - 1):
current_size = sizes[i]
next_size = sizes[i + 1]
# Allow some variance, but larger inputs should generally be slower
time_ratio = results[next_size]["inference_time_ms"] / results[current_size]["inference_time_ms"]
assert time_ratio >= 0.8 # Next size shouldn't be much faster
@pytest.mark.asyncio
async def test_throughput_under_load_should_fail_initially(self, standard_model, sample_features):
"""Test throughput under sustained load - should fail initially."""
await standard_model.load_model()
# Simulate sustained load
duration_seconds = 5
start_time = time.time()
inference_count = 0
while time.time() - start_time < duration_seconds:
await standard_model.predict(sample_features)
inference_count += 1
actual_duration = time.time() - start_time
throughput = inference_count / actual_duration
# This will fail initially
# Should maintain reasonable throughput under load
assert throughput > 5 # At least 5 FPS
assert inference_count > 20 # Should complete at least 20 inferences in 5 seconds
# Check model statistics
stats = standard_model.get_performance_stats()
assert stats["total_inferences"] >= inference_count
assert stats["throughput_fps"] > 0
@pytest.mark.benchmark
class TestInferenceBenchmarks:
"""Benchmark tests for inference performance."""
@pytest.mark.asyncio
async def test_benchmark_lightweight_model_should_fail_initially(self, benchmark):
"""Benchmark lightweight model performance - should fail initially."""
model = MockPoseModel("lightweight")
await model.load_model()
features = np.random.rand(64, 32)
async def run_inference():
return await model.predict(features)
# This will fail initially
# Benchmark the inference
result = await run_inference()
assert result["inference_time_ms"] < 50 # Should be less than 50ms
@pytest.mark.asyncio
async def test_benchmark_batch_processing_should_fail_initially(self, benchmark):
"""Benchmark batch processing performance - should fail initially."""
model = MockPoseModel("standard")
await model.load_model()
batch_features = np.random.rand(8, 64, 32) # Batch of 8
async def run_batch_inference():
return await model.predict(batch_features)
# This will fail initially
result = await run_batch_inference()
assert len(result["predictions"]) == 8
assert result["inference_time_ms"] < 200 # Batch should be efficient