updates
This commit is contained in:
661
tests/fixtures/api_client.py
vendored
Normal file
661
tests/fixtures/api_client.py
vendored
Normal file
@@ -0,0 +1,661 @@
|
||||
"""
|
||||
Test client utilities for API testing.
|
||||
|
||||
Provides mock and real API clients for comprehensive testing.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import random
import time
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, Any, List, Optional, Union, AsyncGenerator
from unittest.mock import AsyncMock, MagicMock

import aiohttp
import jwt
import websockets
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
    """Raised when login or token-refresh operations fail."""
|
||||
|
||||
|
||||
class APIError(Exception):
    """Raised for general, non-specific API failures."""
|
||||
|
||||
|
||||
class RateLimitError(Exception):
    """Raised when the simulated per-minute request budget is exhausted."""
|
||||
|
||||
|
||||
@dataclass
class APIResponse:
    """API response wrapper.

    Bundles the HTTP result with client-side timing metadata so tests can
    assert on both payload and latency.
    """

    # HTTP status code returned by the (mock) server.
    status_code: int
    # Decoded JSON body of the response.
    data: Dict[str, Any]
    # Response headers (empty dict for most mock responses).
    headers: Dict[str, str]
    # Wall-clock duration of the request, in milliseconds.
    response_time_ms: float
    # Client-side creation time of this response (naive UTC per utcnow()).
    timestamp: datetime
|
||||
|
||||
|
||||
class MockAPIClient:
|
||||
"""Mock API client for testing."""
|
||||
|
||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||
self.base_url = base_url
|
||||
self.session = None
|
||||
self.auth_token = None
|
||||
self.refresh_token = None
|
||||
self.token_expires_at = None
|
||||
self.request_history = []
|
||||
self.response_delays = {}
|
||||
self.error_simulation = {}
|
||||
self.rate_limit_config = {
|
||||
"enabled": False,
|
||||
"requests_per_minute": 60,
|
||||
"current_count": 0,
|
||||
"window_start": time.time()
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
"""Initialize connection."""
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
async def disconnect(self):
|
||||
"""Close connection."""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
|
||||
def set_response_delay(self, endpoint: str, delay_ms: float):
|
||||
"""Set artificial delay for endpoint."""
|
||||
self.response_delays[endpoint] = delay_ms
|
||||
|
||||
def simulate_error(self, endpoint: str, error_type: str, probability: float = 1.0):
|
||||
"""Simulate errors for endpoint."""
|
||||
self.error_simulation[endpoint] = {
|
||||
"type": error_type,
|
||||
"probability": probability
|
||||
}
|
||||
|
||||
def enable_rate_limiting(self, requests_per_minute: int = 60):
|
||||
"""Enable rate limiting simulation."""
|
||||
self.rate_limit_config.update({
|
||||
"enabled": True,
|
||||
"requests_per_minute": requests_per_minute,
|
||||
"current_count": 0,
|
||||
"window_start": time.time()
|
||||
})
|
||||
|
||||
async def _check_rate_limit(self):
|
||||
"""Check rate limiting."""
|
||||
if not self.rate_limit_config["enabled"]:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
window_duration = 60 # 1 minute
|
||||
|
||||
# Reset window if needed
|
||||
if current_time - self.rate_limit_config["window_start"] > window_duration:
|
||||
self.rate_limit_config["current_count"] = 0
|
||||
self.rate_limit_config["window_start"] = current_time
|
||||
|
||||
# Check limit
|
||||
if self.rate_limit_config["current_count"] >= self.rate_limit_config["requests_per_minute"]:
|
||||
raise RateLimitError("Rate limit exceeded")
|
||||
|
||||
self.rate_limit_config["current_count"] += 1
|
||||
|
||||
async def _simulate_network_delay(self, endpoint: str):
|
||||
"""Simulate network delay."""
|
||||
delay = self.response_delays.get(endpoint, 0)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay / 1000) # Convert ms to seconds
|
||||
|
||||
async def _check_error_simulation(self, endpoint: str):
|
||||
"""Check if error should be simulated."""
|
||||
if endpoint in self.error_simulation:
|
||||
config = self.error_simulation[endpoint]
|
||||
if random.random() < config["probability"]:
|
||||
error_type = config["type"]
|
||||
if error_type == "timeout":
|
||||
raise asyncio.TimeoutError("Simulated timeout")
|
||||
elif error_type == "connection":
|
||||
raise aiohttp.ClientConnectionError("Simulated connection error")
|
||||
elif error_type == "server_error":
|
||||
raise APIError("Simulated server error")
|
||||
|
||||
async def _make_request(self, method: str, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make HTTP request with simulation."""
|
||||
start_time = time.time()
|
||||
|
||||
# Check rate limiting
|
||||
await self._check_rate_limit()
|
||||
|
||||
# Simulate network delay
|
||||
await self._simulate_network_delay(endpoint)
|
||||
|
||||
# Check error simulation
|
||||
await self._check_error_simulation(endpoint)
|
||||
|
||||
# Record request
|
||||
request_record = {
|
||||
"method": method,
|
||||
"endpoint": endpoint,
|
||||
"timestamp": datetime.utcnow(),
|
||||
"kwargs": kwargs
|
||||
}
|
||||
self.request_history.append(request_record)
|
||||
|
||||
# Generate mock response
|
||||
response_data = await self._generate_mock_response(method, endpoint, kwargs)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = (end_time - start_time) * 1000
|
||||
|
||||
return APIResponse(
|
||||
status_code=response_data["status_code"],
|
||||
data=response_data["data"],
|
||||
headers=response_data.get("headers", {}),
|
||||
response_time_ms=response_time,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
|
||||
async def _generate_mock_response(self, method: str, endpoint: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate mock response based on endpoint."""
|
||||
if endpoint == "/health":
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/auth/login":
|
||||
if method == "POST":
|
||||
# Generate mock JWT tokens
|
||||
payload = {
|
||||
"user_id": "test_user",
|
||||
"exp": datetime.utcnow() + timedelta(hours=1)
|
||||
}
|
||||
access_token = jwt.encode(payload, "secret", algorithm="HS256")
|
||||
refresh_token = jwt.encode({"user_id": "test_user"}, "secret", algorithm="HS256")
|
||||
|
||||
self.auth_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
self.token_expires_at = payload["exp"]
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/auth/refresh":
|
||||
if method == "POST" and self.refresh_token:
|
||||
# Generate new access token
|
||||
payload = {
|
||||
"user_id": "test_user",
|
||||
"exp": datetime.utcnow() + timedelta(hours=1)
|
||||
}
|
||||
access_token = jwt.encode(payload, "secret", algorithm="HS256")
|
||||
|
||||
self.auth_token = access_token
|
||||
self.token_expires_at = payload["exp"]
|
||||
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/pose/detect":
|
||||
if method == "POST":
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"persons": [
|
||||
{
|
||||
"person_id": "person_1",
|
||||
"confidence": 0.85,
|
||||
"bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
|
||||
"keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))],
|
||||
"activity": "standing"
|
||||
}
|
||||
],
|
||||
"processing_time_ms": 45.2,
|
||||
"model_version": "v1.0",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
elif endpoint == "/config":
|
||||
if method == "GET":
|
||||
return {
|
||||
"status_code": 200,
|
||||
"data": {
|
||||
"model_config": {
|
||||
"confidence_threshold": 0.7,
|
||||
"nms_threshold": 0.5,
|
||||
"max_persons": 10
|
||||
},
|
||||
"processing_config": {
|
||||
"batch_size": 1,
|
||||
"use_gpu": True,
|
||||
"preprocessing": "standard"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Default response
|
||||
return {
|
||||
"status_code": 404,
|
||||
"data": {"error": "Endpoint not found"}
|
||||
}
|
||||
|
||||
async def get(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make GET request."""
|
||||
return await self._make_request("GET", endpoint, **kwargs)
|
||||
|
||||
async def post(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make POST request."""
|
||||
return await self._make_request("POST", endpoint, **kwargs)
|
||||
|
||||
async def put(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make PUT request."""
|
||||
return await self._make_request("PUT", endpoint, **kwargs)
|
||||
|
||||
async def delete(self, endpoint: str, **kwargs) -> APIResponse:
|
||||
"""Make DELETE request."""
|
||||
return await self._make_request("DELETE", endpoint, **kwargs)
|
||||
|
||||
async def login(self, username: str, password: str) -> bool:
|
||||
"""Authenticate with API."""
|
||||
response = await self.post("/auth/login", json={
|
||||
"username": username,
|
||||
"password": password
|
||||
})
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise AuthenticationError("Login failed")
|
||||
|
||||
async def refresh_auth_token(self) -> bool:
|
||||
"""Refresh authentication token."""
|
||||
if not self.refresh_token:
|
||||
raise AuthenticationError("No refresh token available")
|
||||
|
||||
response = await self.post("/auth/refresh", json={
|
||||
"refresh_token": self.refresh_token
|
||||
})
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise AuthenticationError("Token refresh failed")
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if client is authenticated."""
|
||||
if not self.auth_token or not self.token_expires_at:
|
||||
return False
|
||||
|
||||
return datetime.utcnow() < self.token_expires_at
|
||||
|
||||
def get_request_history(self) -> List[Dict[str, Any]]:
|
||||
"""Get request history."""
|
||||
return self.request_history.copy()
|
||||
|
||||
def clear_request_history(self):
|
||||
"""Clear request history."""
|
||||
self.request_history.clear()
|
||||
|
||||
|
||||
class MockWebSocketClient:
    """In-memory stand-in for a WebSocket client.

    Records every outbound message, optionally synthesizes server replies
    (subscription confirmations, pose data, pong), and lets tests inject
    inbound messages directly.
    """

    def __init__(self, uri: str = "ws://localhost:8000/ws"):
        self.uri = uri
        self.websocket = None
        self.is_connected = False
        self.messages_received = []
        self.messages_sent = []
        self.connection_errors = []
        self.auto_respond = True
        self.response_delay = 0.01  # 10ms default delay

    async def connect(self) -> bool:
        """Pretend to open the connection; record any failure and report it."""
        try:
            await asyncio.sleep(0.01)  # mimic connection latency
            self.is_connected = True
            return True
        except Exception as exc:
            self.connection_errors.append(str(exc))
            return False

    async def disconnect(self):
        """Drop the pretend connection."""
        self.is_connected = False
        self.websocket = None

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Record an outbound message and, if enabled, queue an auto-reply."""
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        self.messages_sent.append({
            "message": message,
            "timestamp": datetime.utcnow()
        })

        if self.auto_respond:
            await asyncio.sleep(self.response_delay)
            reply = await self._generate_auto_response(message)
            if reply:
                self.messages_received.append({
                    "message": reply,
                    "timestamp": datetime.utcnow()
                })

        return True

    async def receive_message(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """Pop the oldest queued inbound message, polling until *timeout*."""
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        deadline = time.time() + timeout
        while time.time() < deadline:
            if self.messages_received:
                return self.messages_received.pop(0)["message"]
            await asyncio.sleep(0.01)

        return None

    async def _generate_auto_response(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build the canned server reply for *message*, or None if unhandled."""
        builders = {
            "subscribe": self._build_subscription_reply,
            "pose_request": self._build_pose_reply,
            "ping": self._build_pong_reply,
        }
        builder = builders.get(message.get("type"))
        return builder(message) if builder else None

    @staticmethod
    def _build_subscription_reply(message: Dict[str, Any]) -> Dict[str, Any]:
        # Acknowledge a channel subscription.
        return {
            "type": "subscription_confirmed",
            "channel": message.get("channel"),
            "timestamp": datetime.utcnow().isoformat()
        }

    @staticmethod
    def _build_pose_reply(message: Dict[str, Any]) -> Dict[str, Any]:
        # Canned single-person pose payload echoing the request id.
        return {
            "type": "pose_data",
            "data": {
                "persons": [
                    {
                        "person_id": "person_1",
                        "confidence": 0.88,
                        "bounding_box": {"x": 150, "y": 200, "width": 80, "height": 180},
                        "keypoints": [[x, y, 0.9] for x, y in zip(range(17), range(17))]
                    }
                ],
                "timestamp": datetime.utcnow().isoformat()
            },
            "request_id": message.get("request_id")
        }

    @staticmethod
    def _build_pong_reply(message: Dict[str, Any]) -> Dict[str, Any]:
        # Heartbeat answer.
        return {
            "type": "pong",
            "timestamp": datetime.utcnow().isoformat()
        }

    def set_auto_respond(self, enabled: bool, delay_ms: float = 10):
        """Configure auto-response behavior (delay given in milliseconds)."""
        self.auto_respond = enabled
        self.response_delay = delay_ms / 1000

    def inject_message(self, message: Dict[str, Any]):
        """Queue *message* as if it had arrived from the server."""
        self.messages_received.append({
            "message": message,
            "timestamp": datetime.utcnow()
        })

    def get_sent_messages(self) -> List[Dict[str, Any]]:
        """Return a copy of all sent-message records."""
        return self.messages_sent.copy()

    def get_received_messages(self) -> List[Dict[str, Any]]:
        """Return a copy of all received-message records."""
        return self.messages_received.copy()

    def clear_message_history(self):
        """Forget every sent and received message."""
        self.messages_sent.clear()
        self.messages_received.clear()
|
||||
|
||||
|
||||
class APITestClient:
    """High-level test client combining HTTP and WebSocket.

    Wraps a MockAPIClient and a MockWebSocketClient behind one async
    context manager, and offers convenience scenarios (streaming,
    concurrency, load, metrics) built on top of them.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        # Derive the WS URL from the HTTP URL ("http" -> "ws", "https" -> "wss").
        self.ws_url = base_url.replace("http", "ws") + "/ws"
        self.http_client = MockAPIClient(base_url)
        self.ws_client = MockWebSocketClient(self.ws_url)
        self.test_session_id = None  # set during setup()

    async def __aenter__(self):
        """Async context manager entry."""
        await self.setup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.teardown()

    async def setup(self):
        """Connect both clients and mint a unique session id."""
        await self.http_client.connect()
        await self.ws_client.connect()
        self.test_session_id = f"test_session_{int(time.time())}"

    async def teardown(self):
        """Disconnect both clients (WS first, mirroring setup order)."""
        await self.ws_client.disconnect()
        await self.http_client.disconnect()

    async def authenticate(self, username: str = "test_user", password: str = "test_pass") -> bool:
        """Authenticate with the API via the HTTP client."""
        return await self.http_client.login(username, password)

    async def test_health_endpoint(self) -> APIResponse:
        """Hit the health endpoint and return the response."""
        return await self.http_client.get("/health")

    async def test_pose_detection(self, csi_data: Dict[str, Any]) -> APIResponse:
        """POST CSI data to the pose-detection endpoint."""
        return await self.http_client.post("/pose/detect", json=csi_data)

    async def test_websocket_streaming(self, duration_seconds: int = 5) -> List[Dict[str, Any]]:
        """Subscribe to the pose stream and collect messages for a while.

        Returns every message received within *duration_seconds*.
        """
        # Subscribe to pose stream
        await self.ws_client.send_message({
            "type": "subscribe",
            "channel": "pose_stream",
            "session_id": self.test_session_id
        })

        # Collect messages for the specified duration, polling in 100ms slices.
        messages = []
        end_time = time.time() + duration_seconds

        while time.time() < end_time:
            message = await self.ws_client.receive_message(timeout=0.1)
            if message:
                messages.append(message)

        return messages

    async def simulate_concurrent_requests(self, num_requests: int = 10) -> List[APIResponse]:
        """Fire *num_requests* health checks concurrently.

        Exceptions are returned in-line (gather with return_exceptions=True),
        so callers can count failures without the whole batch aborting.
        """
        tasks = []

        for i in range(num_requests):
            task = asyncio.create_task(self.http_client.get("/health"))
            tasks.append(task)

        responses = await asyncio.gather(*tasks, return_exceptions=True)
        return responses

    async def simulate_websocket_load(self, num_connections: int = 5, duration_seconds: int = 3) -> Dict[str, Any]:
        """Drive several WebSocket clients in parallel and report throughput.

        Returns a summary dict with per-connection and total message counts.
        """
        # Create multiple WebSocket clients
        ws_clients = []
        for i in range(num_connections):
            client = MockWebSocketClient(self.ws_url)
            await client.connect()
            ws_clients.append(client)

        # Send messages from all clients
        message_counts = []

        try:
            tasks = []
            for i, client in enumerate(ws_clients):
                task = asyncio.create_task(self._send_messages_for_duration(client, duration_seconds, i))
                tasks.append(task)

            results = await asyncio.gather(*tasks)
            message_counts = results

        finally:
            # Always disconnect, even if a sender task raised.
            for client in ws_clients:
                await client.disconnect()

        return {
            "num_connections": num_connections,
            "duration_seconds": duration_seconds,
            "messages_per_connection": message_counts,
            "total_messages": sum(message_counts)
        }

    async def _send_messages_for_duration(self, client: MockWebSocketClient, duration: int, client_id: int) -> int:
        """Send ping messages on *client* until *duration* seconds elapse.

        Returns the number of messages sent.
        """
        message_count = 0
        end_time = time.time() + duration

        while time.time() < end_time:
            await client.send_message({
                "type": "ping",
                "client_id": client_id,
                "message_id": message_count
            })
            message_count += 1
            await asyncio.sleep(0.1)  # 10 messages per second

        return message_count

    def configure_error_simulation(self, endpoint: str, error_type: str, probability: float = 0.1):
        """Configure error simulation on the HTTP client."""
        self.http_client.simulate_error(endpoint, error_type, probability)

    def configure_rate_limiting(self, requests_per_minute: int = 60):
        """Configure rate limiting on the HTTP client."""
        self.http_client.enable_rate_limiting(requests_per_minute)

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Summarize HTTP latency and WebSocket traffic for the session."""
        http_history = self.http_client.get_request_history()
        ws_sent = self.ws_client.get_sent_messages()
        ws_received = self.ws_client.get_received_messages()

        # Calculate HTTP metrics.
        # NOTE(review): history records may lack "response_time_ms" depending
        # on what MockAPIClient stores; the .get default of 0 then drags the
        # average to 0 — confirm the history actually carries latencies.
        if http_history:
            response_times = [r.get("response_time_ms", 0) for r in http_history]
            http_metrics = {
                "total_requests": len(http_history),
                "avg_response_time_ms": sum(response_times) / len(response_times),
                "min_response_time_ms": min(response_times),
                "max_response_time_ms": max(response_times)
            }
        else:
            http_metrics = {"total_requests": 0}

        # Calculate WebSocket metrics
        ws_metrics = {
            "messages_sent": len(ws_sent),
            "messages_received": len(ws_received),
            "connection_active": self.ws_client.is_connected
        }

        return {
            "session_id": self.test_session_id,
            "http_metrics": http_metrics,
            "websocket_metrics": ws_metrics,
            "timestamp": datetime.utcnow().isoformat()
        }
|
||||
|
||||
|
||||
# Utility functions for test data generation
|
||||
def generate_test_csi_data() -> Dict[str, Any]:
    """Build one random CSI payload suitable for API testing.

    Returns a dict shaped like a real CSI sample: 4 antennas x 64
    subcarriers of amplitude in [0, 1] and phase in [-pi, pi].
    """
    import numpy as np

    amplitude = np.random.uniform(0, 1, (4, 64))
    phase = np.random.uniform(-np.pi, np.pi, (4, 64))

    return {
        "timestamp": datetime.utcnow().isoformat(),
        "router_id": "test_router_001",
        "amplitude": amplitude.tolist(),
        "phase": phase.tolist(),
        "frequency": 5.8e9,
        "bandwidth": 80e6,
        "num_antennas": 4,
        "num_subcarriers": 64,
    }
|
||||
|
||||
|
||||
def create_test_user_credentials() -> Dict[str, str]:
    """Return the fixed credential set used by authentication tests."""
    username = "test_user"
    return {
        "username": username,
        "password": "test_password_123",
        "email": "test@example.com",
    }
|
||||
|
||||
|
||||
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1) -> bool:
    """Poll *condition_func* until it is truthy or *timeout* elapses.

    Args:
        condition_func: Zero-argument callable; may be sync or async.
        timeout: Maximum total wait in seconds.
        interval: Delay between polls in seconds.

    Returns:
        True as soon as the condition is truthy, False on timeout.
    """
    # Hoist the sync/async check out of the loop — the original re-evaluated
    # iscoroutinefunction and used a hard-to-read nested conditional-await
    # expression on every iteration.
    is_async = asyncio.iscoroutinefunction(condition_func)
    end_time = time.time() + timeout

    while time.time() < end_time:
        result = await condition_func() if is_async else condition_func()
        if result:
            return True
        await asyncio.sleep(interval)

    return False
|
||||
487
tests/fixtures/csi_data.py
vendored
Normal file
487
tests/fixtures/csi_data.py
vendored
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
Test data generation utilities for CSI data.
|
||||
|
||||
Provides realistic CSI data samples for testing pose estimation pipeline.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import json
|
||||
import random
|
||||
|
||||
|
||||
class CSIDataGenerator:
|
||||
"""Generate realistic CSI data for testing."""
|
||||
|
||||
    def __init__(self,
                 frequency: float = 5.8e9,
                 bandwidth: float = 80e6,
                 num_antennas: int = 4,
                 num_subcarriers: int = 64):
        """Configure the generator's radio parameters.

        Args:
            frequency: Carrier frequency in Hz (default 5.8 GHz).
            bandwidth: Channel bandwidth in Hz (default 80 MHz).
            num_antennas: Number of antennas (rows of each CSI matrix).
            num_subcarriers: Number of subcarriers (columns of each matrix).
        """
        self.frequency = frequency
        self.bandwidth = bandwidth
        self.num_antennas = num_antennas
        self.num_subcarriers = num_subcarriers
        self.sample_rate = 1000  # Hz — samples generated per second of simulated time
        self.noise_level = 0.1   # default relative noise level used by add_noise()

        # Pre-computed patterns for different scenarios
        self._initialize_patterns()
|
||||
|
||||
    def _initialize_patterns(self):
        """Initialize CSI patterns for different scenarios.

        Each pattern dict parameterizes the random draws in the generate_*
        methods: amplitude mean/std feed np.random.normal, phase_variance
        scales phase noise, temporal_stability is the smoothing weight
        toward the previous sample, and movement_frequency (Hz) drives the
        sinusoidal amplitude modulation.
        """
        # Empty room pattern (baseline)
        self.empty_room_pattern = {
            "amplitude_mean": 0.3,
            "amplitude_std": 0.05,
            "phase_variance": 0.1,
            "temporal_stability": 0.95
        }

        # Single person patterns, keyed by activity name.
        self.single_person_patterns = {
            "standing": {
                "amplitude_mean": 0.5,
                "amplitude_std": 0.08,
                "phase_variance": 0.2,
                "temporal_stability": 0.85,
                "movement_frequency": 0.1
            },
            "walking": {
                # Walking: strongest amplitude and fastest modulation (2 Hz).
                "amplitude_mean": 0.6,
                "amplitude_std": 0.15,
                "phase_variance": 0.4,
                "temporal_stability": 0.6,
                "movement_frequency": 2.0
            },
            "sitting": {
                "amplitude_mean": 0.4,
                "amplitude_std": 0.06,
                "phase_variance": 0.15,
                "temporal_stability": 0.9,
                "movement_frequency": 0.05
            },
            "fallen": {
                # Fallen: nearly static — high stability, minimal movement.
                "amplitude_mean": 0.35,
                "amplitude_std": 0.04,
                "phase_variance": 0.08,
                "temporal_stability": 0.95,
                "movement_frequency": 0.02
            }
        }

        # Multi-person patterns, keyed by person count (2-4): global
        # multipliers applied after summing per-person contributions.
        self.multi_person_patterns = {
            2: {"amplitude_multiplier": 1.4, "phase_complexity": 1.6},
            3: {"amplitude_multiplier": 1.7, "phase_complexity": 2.1},
            4: {"amplitude_multiplier": 2.0, "phase_complexity": 2.8}
        }
|
||||
|
||||
    def generate_empty_room_sample(self, timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Generate one CSI sample for an empty room.

        Args:
            timestamp: Sample time; defaults to the current UTC time.

        Returns:
            A JSON-serializable dict with amplitude/phase matrices
            (num_antennas x num_subcarriers) plus radio metadata.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        pattern = self.empty_room_pattern

        # Amplitude: Gaussian around the baseline mean, clipped to [0, 1].
        amplitude = np.random.normal(
            pattern["amplitude_mean"],
            pattern["amplitude_std"],
            (self.num_antennas, self.num_subcarriers)
        )
        amplitude = np.clip(amplitude, 0, 1)

        # Phase: uniform over [-pi, pi].
        phase = np.random.uniform(
            -np.pi, np.pi,
            (self.num_antennas, self.num_subcarriers)
        )

        # Temporal stability: blend toward the previous empty-room sample so
        # consecutive samples are correlated (95% carry-over by default).
        if hasattr(self, '_last_empty_sample'):
            stability = pattern["temporal_stability"]
            amplitude = stability * self._last_empty_sample["amplitude"] + (1 - stability) * amplitude
            phase = stability * self._last_empty_sample["phase"] + (1 - stability) * phase

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": "empty_room",
            "signal_quality": np.random.uniform(0.85, 0.95)
        }

        # Remember the raw arrays (not the list form) for the next blend.
        self._last_empty_sample = {
            "amplitude": amplitude,
            "phase": phase
        }

        return sample
|
||||
|
||||
    def generate_single_person_sample(self,
                                      activity: str = "standing",
                                      timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Generate one CSI sample for a single person performing *activity*.

        Args:
            activity: One of the keys of self.single_person_patterns
                ("standing", "walking", "sitting", "fallen").
            timestamp: Sample time; defaults to the current UTC time.

        Returns:
            A JSON-serializable CSI sample dict including "activity".

        Raises:
            ValueError: If *activity* is not a known pattern.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        if activity not in self.single_person_patterns:
            raise ValueError(f"Unknown activity: {activity}")

        pattern = self.single_person_patterns[activity]

        # Base amplitude drawn around the activity's mean.
        amplitude = np.random.normal(
            pattern["amplitude_mean"],
            pattern["amplitude_std"],
            (self.num_antennas, self.num_subcarriers)
        )

        # Movement-induced modulation: a sinusoid at the activity's movement
        # frequency, evaluated at the sample's wall-clock time.
        movement_freq = pattern["movement_frequency"]
        time_factor = timestamp.timestamp()
        movement_modulation = 0.1 * np.sin(2 * np.pi * movement_freq * time_factor)
        amplitude += movement_modulation
        amplitude = np.clip(amplitude, 0, 1)

        # Phase: uniform base plus activity-specific Gaussian noise, wrapped
        # back into [-pi, pi].
        phase_base = np.random.uniform(-np.pi, np.pi, (self.num_antennas, self.num_subcarriers))
        phase_variance = pattern["phase_variance"]
        phase_noise = np.random.normal(0, phase_variance, (self.num_antennas, self.num_subcarriers))
        phase = phase_base + phase_noise
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi  # Wrap to [-π, π]

        # Temporal correlation: blend toward the previous sample of the SAME
        # activity (state is kept per-activity via dynamic attribute names).
        if hasattr(self, f'_last_{activity}_sample'):
            stability = pattern["temporal_stability"]
            last_sample = getattr(self, f'_last_{activity}_sample')
            amplitude = stability * last_sample["amplitude"] + (1 - stability) * amplitude
            phase = stability * last_sample["phase"] + (1 - stability) * phase

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": f"single_person_{activity}",
            "signal_quality": np.random.uniform(0.7, 0.9),
            "activity": activity
        }

        # Store the raw arrays for the next blend of this activity.
        setattr(self, f'_last_{activity}_sample', {
            "amplitude": amplitude,
            "phase": phase
        })

        return sample
|
||||
|
||||
    def generate_multi_person_sample(self,
                                     num_persons: int = 2,
                                     activities: Optional[List[str]] = None,
                                     timestamp: Optional[datetime] = None) -> Dict[str, Any]:
        """Generate one CSI sample for multiple persons.

        Args:
            num_persons: Number of persons, 2-4 inclusive.
            activities: One activity name per person; randomly chosen
                (with replacement) when None.
            timestamp: Sample time; defaults to the current UTC time.

        Returns:
            A JSON-serializable CSI sample dict including "num_persons"
            and "activities".

        Raises:
            ValueError: If num_persons is outside 2-4 or len(activities)
                does not match num_persons.
        """
        if timestamp is None:
            timestamp = datetime.utcnow()

        if num_persons < 2 or num_persons > 4:
            raise ValueError("Number of persons must be between 2 and 4")

        if activities is None:
            activities = random.choices(list(self.single_person_patterns.keys()), k=num_persons)

        if len(activities) != num_persons:
            raise ValueError("Number of activities must match number of persons")

        # Start with empty room baseline
        amplitude = np.random.normal(
            self.empty_room_pattern["amplitude_mean"],
            self.empty_room_pattern["amplitude_std"],
            (self.num_antennas, self.num_subcarriers)
        )

        phase = np.random.uniform(
            -np.pi, np.pi,
            (self.num_antennas, self.num_subcarriers)
        )

        # Superimpose each person's contribution on the baseline.
        for i, activity in enumerate(activities):
            person_pattern = self.single_person_patterns[activity]

            # Per-person amplitude, attenuated vs. the single-person case.
            person_amplitude = np.random.normal(
                person_pattern["amplitude_mean"] * 0.7,  # Reduced for multi-person
                person_pattern["amplitude_std"],
                (self.num_antennas, self.num_subcarriers)
            )

            # Spatial variation: shift each person's pattern along the
            # subcarrier axis so persons occupy distinct "locations".
            spatial_offset = i * self.num_subcarriers // num_persons
            person_amplitude = np.roll(person_amplitude, spatial_offset, axis=1)

            # Movement modulation with a 0.5 s phase offset per person so the
            # sinusoids are not synchronized.
            movement_freq = person_pattern["movement_frequency"]
            time_factor = timestamp.timestamp() + i * 0.5  # Phase offset between persons
            movement_modulation = 0.05 * np.sin(2 * np.pi * movement_freq * time_factor)
            person_amplitude += movement_modulation

            amplitude += person_amplitude

            # Phase contribution, shifted by the same spatial offset.
            person_phase = np.random.normal(0, person_pattern["phase_variance"],
                                            (self.num_antennas, self.num_subcarriers))
            person_phase = np.roll(person_phase, spatial_offset, axis=1)
            phase += person_phase

        # Apply the global multi-person complexity multipliers.
        pattern = self.multi_person_patterns[num_persons]
        amplitude *= pattern["amplitude_multiplier"]
        phase *= pattern["phase_complexity"]

        # Clip amplitude to [0, 1] and wrap phase back into [-pi, pi].
        amplitude = np.clip(amplitude, 0, 1)
        phase = np.mod(phase + np.pi, 2 * np.pi) - np.pi

        sample = {
            "timestamp": timestamp.isoformat(),
            "router_id": "router_001",
            "amplitude": amplitude.tolist(),
            "phase": phase.tolist(),
            "frequency": self.frequency,
            "bandwidth": self.bandwidth,
            "num_antennas": self.num_antennas,
            "num_subcarriers": self.num_subcarriers,
            "sample_rate": self.sample_rate,
            "scenario": f"multi_person_{num_persons}",
            "signal_quality": np.random.uniform(0.6, 0.8),
            "num_persons": num_persons,
            "activities": activities
        }

        return sample
|
||||
|
||||
def generate_time_series(self,
                         duration_seconds: int = 10,
                         scenario: str = "single_person_walking",
                         **kwargs) -> List[Dict[str, Any]]:
    """Produce a chronological list of CSI samples for one scenario.

    Samples are spaced 1/sample_rate seconds apart, starting at the
    current UTC time. The scenario string selects the per-sample
    generator: "empty_room", "single_person_<activity>", or
    "multi_person_<N>" (extra kwargs are forwarded to the multi-person
    generator).

    Raises:
        ValueError: if the scenario name matches no known pattern.
    """
    base_time = datetime.utcnow()
    total = duration_seconds * self.sample_rate

    series: List[Dict[str, Any]] = []
    for idx in range(total):
        ts = base_time + timedelta(seconds=idx / self.sample_rate)

        if scenario == "empty_room":
            series.append(self.generate_empty_room_sample(ts))
        elif scenario.startswith("single_person_"):
            activity = scenario.replace("single_person_", "")
            series.append(self.generate_single_person_sample(activity, ts))
        elif scenario.startswith("multi_person_"):
            count = int(scenario.split("_")[-1])
            series.append(self.generate_multi_person_sample(count, timestamp=ts, **kwargs))
        else:
            raise ValueError(f"Unknown scenario: {scenario}")

    return series
|
||||
|
||||
def add_noise(self, sample: Dict[str, Any], noise_level: Optional[float] = None) -> Dict[str, Any]:
    """Return a copy of *sample* with Gaussian noise applied.

    Amplitude noise has std ``noise_level`` and the result is clipped
    back into [0, 1]; phase noise has std ``noise_level * pi`` and the
    result is re-wrapped into [-pi, pi]. Signal quality is scaled by
    ``1 - noise_level``. When ``noise_level`` is None the generator's
    configured default level is used.
    """
    level = self.noise_level if noise_level is None else noise_level

    result = dict(sample)

    # Perturb amplitudes, keeping them inside the normalized [0, 1] range.
    amp = np.array(sample["amplitude"])
    amp = amp + np.random.normal(0, level, amp.shape)
    result["amplitude"] = np.clip(amp, 0, 1).tolist()

    # Perturb phases and wrap back into the canonical [-pi, pi] interval.
    ph = np.array(sample["phase"])
    ph = ph + np.random.normal(0, level * np.pi, ph.shape)
    result["phase"] = (np.mod(ph + np.pi, 2 * np.pi) - np.pi).tolist()

    # Heavier noise implies a proportionally lower perceived signal quality.
    result["signal_quality"] *= (1 - level)

    return result
|
||||
|
||||
def simulate_hardware_artifacts(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *sample* with synthetic hardware imperfections.

    Models three effects: random near-identity inter-antenna coupling,
    a +-10% sinusoidal gain ripple across the subcarrier band, and a
    random linear per-subcarrier phase drift. Amplitudes are re-clipped
    to [0, 1] and phases re-wrapped into [-pi, pi] afterwards.
    """
    result = dict(sample)

    amp = np.array(sample["amplitude"])
    ph = np.array(sample["phase"])

    # Near-identity coupling matrix mixes signals between antennas.
    coupling = np.random.uniform(0.95, 1.05, (self.num_antennas, self.num_antennas))
    amp = coupling @ amp

    # Frequency-dependent gain ripple over one full sine period.
    ripple = 1 + 0.1 * np.sin(np.linspace(0, 2 * np.pi, self.num_subcarriers))
    amp = amp * ripple[np.newaxis, :]

    # Linear phase drift: a single random slope applied across subcarriers.
    slope = np.random.uniform(-0.1, 0.1)
    ph = ph + (slope * np.arange(self.num_subcarriers))[np.newaxis, :]

    # Restore canonical value ranges.
    result["amplitude"] = np.clip(amp, 0, 1).tolist()
    result["phase"] = (np.mod(ph + np.pi, 2 * np.pi) - np.pi).tolist()

    return result
|
||||
|
||||
|
||||
# Convenience functions for common test scenarios
|
||||
def generate_fall_detection_sequence() -> List[Dict[str, Any]]:
    """Build a CSI sequence that walks through a fall-detection scenario.

    Phases, in order: standing (5 s), walking (3 s), the fall event
    transition (1 s), then lying on the floor (3 s).
    """
    generator = CSIDataGenerator()

    phases = [
        (5, "single_person_standing"),  # normal standing
        (3, "single_person_walking"),   # walking around
        (1, "single_person_fallen"),    # fall event transition
        (3, "single_person_fallen"),    # remains on the floor
    ]

    frames: List[Dict[str, Any]] = []
    for seconds, scenario in phases:
        frames.extend(generator.generate_time_series(seconds, scenario))

    return frames
|
||||
|
||||
|
||||
def generate_multi_person_scenario() -> List[Dict[str, Any]]:
    """Build a CSI sequence where people enter a room one at a time.

    Progression: empty room (2 s) -> one walker (3 s) -> two people
    (5 s) -> three people (4 s), with per-person activity lists passed
    to the multi-person generator.
    """
    generator = CSIDataGenerator()

    frames: List[Dict[str, Any]] = []
    # Empty room baseline.
    frames.extend(generator.generate_time_series(2, "empty_room"))
    # First person walks in.
    frames.extend(generator.generate_time_series(3, "single_person_walking"))
    # A second person joins.
    frames.extend(generator.generate_time_series(
        5, "multi_person_2", activities=["standing", "walking"]))
    # A third person joins.
    frames.extend(generator.generate_time_series(
        4, "multi_person_3", activities=["standing", "walking", "sitting"]))

    return frames
|
||||
|
||||
|
||||
def generate_noisy_environment_data() -> List[Dict[str, Any]]:
    """Generate walking CSI samples corrupted at several noise levels.

    The first 10 clean samples are re-emitted once per noise level
    (0.05, 0.1, 0.2, 0.3), yielding 40 noisy samples ordered by level.
    """
    generator = CSIDataGenerator()

    # Clean reference data to corrupt.
    clean = generator.generate_time_series(5, "single_person_walking")

    return [
        generator.add_noise(sample, level)
        for level in (0.05, 0.1, 0.2, 0.3)
        for sample in clean[:10]  # first 10 samples per level
    ]
|
||||
|
||||
|
||||
def generate_hardware_test_data() -> List[Dict[str, Any]]:
    """Generate standing-person CSI samples with simulated hardware artifacts.

    Three seconds of base samples are each passed through the hardware
    artifact simulator (antenna coupling, gain ripple, phase drift).
    """
    generator = CSIDataGenerator()

    return [
        generator.simulate_hardware_artifacts(sample)
        for sample in generator.generate_time_series(3, "single_person_standing")
    ]
|
||||
|
||||
|
||||
# Test data validation utilities
|
||||
def validate_csi_sample(sample: Dict[str, Any]) -> bool:
    """Validate CSI sample structure and data ranges.

    Checks that all required fields are present, that amplitude and
    phase parse into numeric arrays of shape
    (num_antennas, num_subcarriers), that amplitudes lie in [0, 1] and
    phases in [-pi, pi].

    Unlike a naive implementation, malformed payloads (ragged lists,
    non-numeric entries, zero-sized arrays) make this return False
    instead of raising, so it can be used as a safe predicate on
    untrusted data.

    Args:
        sample: CSI sample dictionary as produced by the generator.

    Returns:
        True if the sample is structurally valid, False otherwise.
    """
    required_fields = [
        "timestamp", "router_id", "amplitude", "phase",
        "frequency", "bandwidth", "num_antennas", "num_subcarriers"
    ]

    # Check required fields.
    for field in required_fields:
        if field not in sample:
            return False

    # Parse arrays defensively: ragged or non-numeric payloads are simply
    # invalid, not an error condition (np.asarray would otherwise raise).
    try:
        amplitude = np.asarray(sample["amplitude"], dtype=float)
        phase = np.asarray(sample["phase"], dtype=float)
    except (TypeError, ValueError):
        return False

    # Check shapes.
    expected_shape = (sample["num_antennas"], sample["num_subcarriers"])
    if amplitude.shape != expected_shape or phase.shape != expected_shape:
        return False

    # Empty arrays have no min/max; a zero-sized sample is invalid.
    if amplitude.size == 0 or phase.size == 0:
        return False

    # Check value ranges.
    if not (0 <= amplitude.min() and amplitude.max() <= 1):
        return False

    if not (-np.pi <= phase.min() and phase.max() <= np.pi):
        return False

    return True
|
||||
|
||||
|
||||
def extract_features_from_csi(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Compute scalar summary features of a CSI sample for testing.

    Returns basic amplitude/phase statistics plus simple signal metrics:
    total energy, phase coherence (mean resultant length of unit
    phasors), average antenna-row correlation, and the spread of the
    per-subcarrier mean amplitude across frequency.
    """
    amp = np.asarray(sample["amplitude"])
    ph = np.asarray(sample["phase"])

    feats: Dict[str, Any] = {}
    feats["amplitude_mean"] = float(np.mean(amp))
    feats["amplitude_std"] = float(np.std(amp))
    feats["amplitude_max"] = float(np.max(amp))
    feats["amplitude_min"] = float(np.min(amp))
    feats["phase_variance"] = float(np.var(ph))
    feats["phase_range"] = float(np.max(ph) - np.min(ph))
    feats["signal_energy"] = float(np.sum(amp ** 2))
    # Mean resultant length of unit phasors: 1.0 when all phases align.
    feats["phase_coherence"] = float(np.abs(np.mean(np.exp(1j * ph))))
    # Average pairwise correlation between antenna rows.
    feats["spatial_correlation"] = float(np.mean(np.corrcoef(amp)))
    # Variation of the mean amplitude across subcarriers.
    feats["frequency_diversity"] = float(np.std(np.mean(amp, axis=0)))

    return feats
|
||||
Reference in New Issue
Block a user