# Source: wifi-densepose/tests/fixtures/api_client.py
# Snapshot: 2025-06-07 11:44:19 +00:00 (661 lines, 23 KiB, Python)

"""
Test client utilities for API testing.
Provides mock and real API clients for comprehensive testing.
"""
import asyncio
import inspect
import json
import random
import time
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from unittest.mock import AsyncMock, MagicMock

import aiohttp
import jwt
import websockets
class AuthenticationError(Exception):
    """Raised when logging in or refreshing a token does not succeed."""
class APIError(Exception):
    """Raised for general API failures, such as a simulated server error."""
class RateLimitError(Exception):
    """Raised when a request exceeds the configured rate limit."""
@dataclass
class APIResponse:
    """Value object describing the outcome of one API call.

    Attributes:
        status_code: HTTP-style status code of the response.
        data: Decoded response payload.
        headers: Response headers.
        response_time_ms: Time spent serving the request, in milliseconds.
        timestamp: Moment at which the response object was created.
    """

    status_code: int
    data: Dict[str, Any]
    headers: Dict[str, str]
    response_time_ms: float
    timestamp: datetime
class MockAPIClient:
    """Mock HTTP API client that answers requests from canned data.

    Responses are produced locally by ``_generate_mock_response``; the
    aiohttp session created by ``connect`` is never used to send anything.
    On top of the canned responses the client can simulate per-endpoint
    latency, per-endpoint failures and a fixed-window rate limit, and it
    keeps a history of every request that produced a response.
    """

    # Parameters of the throw-away JWTs handed out by the mock auth endpoints.
    _JWT_SECRET = "secret"
    _JWT_ALGORITHM = "HS256"
    _TOKEN_TTL = timedelta(hours=1)

    # Length of one rate-limit window, in seconds.
    _RATE_LIMIT_WINDOW_SECONDS = 60

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Create a disconnected, unauthenticated client.

        Args:
            base_url: Base URL the client nominally talks to. It is stored
                but not used when generating mock responses.
        """
        self.base_url = base_url
        self.session = None
        self.auth_token = None
        self.refresh_token = None
        self.token_expires_at = None
        self.request_history = []
        self.response_delays = {}
        self.error_simulation = {}
        self.rate_limit_config = {
            "enabled": False,
            "requests_per_minute": 60,
            "current_count": 0,
            "window_start": time.time(),
        }

    async def __aenter__(self):
        """Connect on entering ``async with`` and return the client."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving ``async with``; exceptions propagate."""
        await self.disconnect()

    async def connect(self):
        """Create the aiohttp session unless an open one already exists."""
        # Unconditionally creating a session here would overwrite (and leak)
        # the previous, still-open one when connect() is called twice.
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()

    async def disconnect(self):
        """Close the aiohttp session if one was created."""
        if self.session:
            await self.session.close()

    def set_response_delay(self, endpoint: str, delay_ms: float):
        """Make every request to ``endpoint`` wait ``delay_ms`` milliseconds."""
        self.response_delays[endpoint] = delay_ms

    def simulate_error(self, endpoint: str, error_type: str, probability: float = 1.0):
        """Make requests to ``endpoint`` fail with the given probability.

        Args:
            endpoint: Endpoint path the simulation applies to.
            error_type: ``"timeout"``, ``"connection"`` or ``"server_error"``.
                Any other value is stored but never raises.
            probability: Chance in ``[0, 1]`` that a request fails.
        """
        self.error_simulation[endpoint] = {
            "type": error_type,
            "probability": probability,
        }

    def enable_rate_limiting(self, requests_per_minute: int = 60):
        """Turn rate limiting on and start a fresh counting window."""
        self.rate_limit_config.update({
            "enabled": True,
            "requests_per_minute": requests_per_minute,
            "current_count": 0,
            "window_start": time.time(),
        })

    async def _check_rate_limit(self):
        """Count one request against the current window.

        Raises:
            RateLimitError: If the window's request budget is already spent.
        """
        limits = self.rate_limit_config
        if not limits["enabled"]:
            return

        now = time.time()
        # Start a new window once the previous one has fully elapsed.
        if now - limits["window_start"] > self._RATE_LIMIT_WINDOW_SECONDS:
            limits["current_count"] = 0
            limits["window_start"] = now

        if limits["current_count"] >= limits["requests_per_minute"]:
            raise RateLimitError("Rate limit exceeded")
        limits["current_count"] += 1

    async def _simulate_network_delay(self, endpoint: str):
        """Sleep for the delay configured for ``endpoint``, if any."""
        delay_ms = self.response_delays.get(endpoint, 0)
        if delay_ms > 0:
            await asyncio.sleep(delay_ms / 1000)  # milliseconds -> seconds

    async def _check_error_simulation(self, endpoint: str):
        """Raise the failure configured for ``endpoint`` when the dice say so.

        Raises:
            asyncio.TimeoutError: For the ``"timeout"`` error type.
            aiohttp.ClientConnectionError: For the ``"connection"`` error type.
            APIError: For the ``"server_error"`` error type.
        """
        config = self.error_simulation.get(endpoint)
        if config is None:
            return
        # random.random() is in [0.0, 1.0): probability 1.0 always fires and
        # probability 0.0 never does.
        if random.random() >= config["probability"]:
            return

        error_type = config["type"]
        if error_type == "timeout":
            raise asyncio.TimeoutError("Simulated timeout")
        if error_type == "connection":
            raise aiohttp.ClientConnectionError("Simulated connection error")
        if error_type == "server_error":
            raise APIError("Simulated server error")

    async def _make_request(self, method: str, endpoint: str, **kwargs) -> APIResponse:
        """Run one simulated request and wrap the canned response.

        The rate limit, the artificial delay and the error simulation are
        applied in that order before the request is recorded, so a request
        rejected by any of them does not appear in ``request_history``.

        Args:
            method: HTTP method name, e.g. ``"GET"``.
            endpoint: Endpoint path, e.g. ``"/health"``.
            **kwargs: Extra request arguments; stored in the history and
                passed to the response generator.

        Returns:
            The mock response, including the measured response time.
        """
        started = time.perf_counter()

        await self._check_rate_limit()
        await self._simulate_network_delay(endpoint)
        await self._check_error_simulation(endpoint)

        request_record = {
            "method": method,
            "endpoint": endpoint,
            "timestamp": datetime.utcnow(),
            "kwargs": kwargs,
        }
        self.request_history.append(request_record)

        response_data = await self._generate_mock_response(method, endpoint, kwargs)
        elapsed_ms = (time.perf_counter() - started) * 1000

        # Keep the outcome next to the request so that consumers of the
        # history (e.g. performance metrics) see real timings, not defaults.
        request_record["status_code"] = response_data["status_code"]
        request_record["response_time_ms"] = elapsed_ms

        return APIResponse(
            status_code=response_data["status_code"],
            data=response_data["data"],
            headers=response_data.get("headers", {}),
            response_time_ms=elapsed_ms,
            timestamp=datetime.utcnow(),
        )

    def _issue_access_token(self) -> str:
        """Mint a new access token and store it together with its expiry."""
        claims = {
            "user_id": "test_user",
            "exp": datetime.utcnow() + self._TOKEN_TTL,
        }
        token = jwt.encode(claims, self._JWT_SECRET, algorithm=self._JWT_ALGORITHM)
        self.auth_token = token
        self.token_expires_at = claims["exp"]
        return token

    async def _generate_mock_response(self, method: str, endpoint: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the canned response for ``method`` + ``endpoint``.

        Known combinations are ``GET``/any method on ``/health``, ``POST
        /auth/login``, ``POST /auth/refresh`` (only once a refresh token
        exists), ``POST /pose/detect`` and ``GET /config``. Everything else,
        including a known endpoint called with another method, yields the
        404 response.

        Returns:
            A dict with ``status_code`` and ``data`` keys.
        """
        expires_in = int(self._TOKEN_TTL.total_seconds())

        if endpoint == "/health":
            return {
                "status_code": 200,
                "data": {
                    "status": "healthy",
                    "timestamp": datetime.utcnow().isoformat(),
                    "version": "1.0.0",
                },
            }

        if endpoint == "/auth/login" and method == "POST":
            access_token = self._issue_access_token()
            # The refresh token carries no expiry claim.
            refresh_token = jwt.encode(
                {"user_id": "test_user"},
                self._JWT_SECRET,
                algorithm=self._JWT_ALGORITHM,
            )
            self.refresh_token = refresh_token
            return {
                "status_code": 200,
                "data": {
                    "access_token": access_token,
                    "refresh_token": refresh_token,
                    "token_type": "bearer",
                    "expires_in": expires_in,
                },
            }

        if endpoint == "/auth/refresh" and method == "POST" and self.refresh_token:
            access_token = self._issue_access_token()
            return {
                "status_code": 200,
                "data": {
                    "access_token": access_token,
                    "token_type": "bearer",
                    "expires_in": expires_in,
                },
            }

        if endpoint == "/pose/detect" and method == "POST":
            return {
                "status_code": 200,
                "data": {
                    "persons": [
                        {
                            "person_id": "person_1",
                            "confidence": 0.85,
                            "bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
                            # 17 placeholder keypoints: [x, y, confidence].
                            "keypoints": [[i, i, 0.9] for i in range(17)],
                            "activity": "standing",
                        }
                    ],
                    "processing_time_ms": 45.2,
                    "model_version": "v1.0",
                    "timestamp": datetime.utcnow().isoformat(),
                },
            }

        if endpoint == "/config" and method == "GET":
            return {
                "status_code": 200,
                "data": {
                    "model_config": {
                        "confidence_threshold": 0.7,
                        "nms_threshold": 0.5,
                        "max_persons": 10,
                    },
                    "processing_config": {
                        "batch_size": 1,
                        "use_gpu": True,
                        "preprocessing": "standard",
                    },
                },
            }

        # Unknown endpoint, or known endpoint with an unsupported method.
        return {
            "status_code": 404,
            "data": {"error": "Endpoint not found"},
        }

    async def get(self, endpoint: str, **kwargs) -> APIResponse:
        """Issue a simulated GET request."""
        return await self._make_request("GET", endpoint, **kwargs)

    async def post(self, endpoint: str, **kwargs) -> APIResponse:
        """Issue a simulated POST request."""
        return await self._make_request("POST", endpoint, **kwargs)

    async def put(self, endpoint: str, **kwargs) -> APIResponse:
        """Issue a simulated PUT request."""
        return await self._make_request("PUT", endpoint, **kwargs)

    async def delete(self, endpoint: str, **kwargs) -> APIResponse:
        """Issue a simulated DELETE request."""
        return await self._make_request("DELETE", endpoint, **kwargs)

    async def login(self, username: str, password: str) -> bool:
        """Log in through ``POST /auth/login``.

        The mock endpoint does not inspect the credentials.

        Returns:
            ``True`` when the endpoint answers with status 200.

        Raises:
            AuthenticationError: If the endpoint answers with another status.
        """
        response = await self.post("/auth/login", json={
            "username": username,
            "password": password,
        })
        if response.status_code != 200:
            raise AuthenticationError("Login failed")
        return True

    async def refresh_auth_token(self) -> bool:
        """Obtain a new access token through ``POST /auth/refresh``.

        Returns:
            ``True`` when the endpoint answers with status 200.

        Raises:
            AuthenticationError: If no refresh token is stored or the
                endpoint answers with another status.
        """
        if not self.refresh_token:
            raise AuthenticationError("No refresh token available")
        response = await self.post("/auth/refresh", json={
            "refresh_token": self.refresh_token,
        })
        if response.status_code != 200:
            raise AuthenticationError("Token refresh failed")
        return True

    def is_authenticated(self) -> bool:
        """Return whether an access token is stored and not yet expired."""
        if not self.auth_token or not self.token_expires_at:
            return False
        return datetime.utcnow() < self.token_expires_at

    def get_request_history(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the recorded requests."""
        return self.request_history.copy()

    def clear_request_history(self):
        """Forget all recorded requests."""
        self.request_history.clear()
class MockWebSocketClient:
    """Mock WebSocket client backed by in-memory message lists.

    Nothing is sent over a network. Outgoing messages are recorded in
    ``messages_sent``; incoming messages are whatever the auto-responder or
    ``inject_message`` places in ``messages_received``.
    """

    # How long receive_message() sleeps between looks at the inbox.
    _POLL_INTERVAL_SECONDS = 0.01

    def __init__(self, uri: str = "ws://localhost:8000/ws"):
        """Create a disconnected client with auto-response enabled.

        Args:
            uri: WebSocket URI the client nominally connects to (stored only).
        """
        self.uri = uri
        self.websocket = None
        self.is_connected = False
        self.messages_received = []
        self.messages_sent = []
        self.connection_errors = []
        self.auto_respond = True
        self.response_delay = 0.01  # seconds (10 ms)

    async def connect(self) -> bool:
        """Mark the client as connected after a short simulated handshake.

        Returns:
            ``True`` on success; ``False`` if an exception occurred, in which
            case its text is appended to ``connection_errors``.
        """
        try:
            await asyncio.sleep(0.01)  # simulated connection latency
            self.is_connected = True
            return True
        except Exception as e:
            self.connection_errors.append(str(e))
            return False

    async def disconnect(self):
        """Mark the client as disconnected."""
        self.is_connected = False
        self.websocket = None

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Record an outgoing message and queue the automatic reply, if any.

        Args:
            message: Message to send; its ``"type"`` selects the auto-reply.

        Returns:
            Always ``True``.

        Raises:
            ConnectionError: If the client is not connected.
        """
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        self.messages_sent.append({
            "message": message,
            "timestamp": datetime.utcnow(),
        })

        if self.auto_respond:
            await asyncio.sleep(self.response_delay)
            response = await self._generate_auto_response(message)
            if response:
                self.messages_received.append({
                    "message": response,
                    "timestamp": datetime.utcnow(),
                })
        return True

    async def receive_message(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """Return the oldest queued incoming message, waiting up to ``timeout``.

        The inbox is always inspected at least once, so ``timeout=0`` acts as
        a non-blocking poll.

        Args:
            timeout: Maximum time to wait, in seconds.

        Returns:
            The message payload, or ``None`` if nothing arrived in time.

        Raises:
            ConnectionError: If the client is not connected.
        """
        if not self.is_connected:
            raise ConnectionError("WebSocket not connected")

        # A monotonic clock keeps the timeout immune to wall-clock changes.
        deadline = time.monotonic() + timeout
        while True:
            if self.messages_received:
                return self.messages_received.pop(0)["message"]
            if time.monotonic() >= deadline:
                return None
            await asyncio.sleep(self._POLL_INTERVAL_SECONDS)

    async def _generate_auto_response(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build the canned reply for ``message``.

        Returns:
            A reply for the ``"subscribe"``, ``"pose_request"`` and ``"ping"``
            message types, ``None`` for every other type.
        """
        message_type = message.get("type")

        if message_type == "subscribe":
            return {
                "type": "subscription_confirmed",
                "channel": message.get("channel"),
                "timestamp": datetime.utcnow().isoformat(),
            }

        if message_type == "pose_request":
            return {
                "type": "pose_data",
                "data": {
                    "persons": [
                        {
                            "person_id": "person_1",
                            "confidence": 0.88,
                            "bounding_box": {"x": 150, "y": 200, "width": 80, "height": 180},
                            # 17 placeholder keypoints: [x, y, confidence].
                            "keypoints": [[i, i, 0.9] for i in range(17)],
                        }
                    ],
                    "timestamp": datetime.utcnow().isoformat(),
                },
                "request_id": message.get("request_id"),
            }

        if message_type == "ping":
            return {
                "type": "pong",
                "timestamp": datetime.utcnow().isoformat(),
            }

        return None

    def set_auto_respond(self, enabled: bool, delay_ms: float = 10):
        """Enable or disable auto-replies and set their delay.

        Args:
            enabled: Whether ``send_message`` should queue automatic replies.
            delay_ms: Delay before a reply is queued, in milliseconds.
        """
        self.auto_respond = enabled
        self.response_delay = delay_ms / 1000

    def inject_message(self, message: Dict[str, Any]):
        """Queue ``message`` as if the server had sent it."""
        self.messages_received.append({
            "message": message,
            "timestamp": datetime.utcnow(),
        })

    def get_sent_messages(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the sent-message records."""
        return self.messages_sent.copy()

    def get_received_messages(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the incoming messages not yet consumed."""
        return self.messages_received.copy()

    def clear_message_history(self):
        """Discard all sent and pending incoming messages."""
        self.messages_sent.clear()
        self.messages_received.clear()
class APITestClient:
    """High-level test client that pairs a mock HTTP and WebSocket client."""

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Create the underlying mock clients.

        Args:
            base_url: HTTP(S) base URL; the WebSocket URL is derived from it.
        """
        self.base_url = base_url
        self.ws_url = self._derive_ws_url(base_url)
        self.http_client = MockAPIClient(base_url)
        self.ws_client = MockWebSocketClient(self.ws_url)
        self.test_session_id = None

    @staticmethod
    def _derive_ws_url(base_url: str) -> str:
        """Map an ``http(s)://`` base URL to its ``ws(s)://.../ws`` endpoint.

        Only the leading scheme is rewritten, so hosts or paths that happen
        to contain ``"http"`` are left intact, and a trailing slash on the
        base URL does not produce a double slash.
        """
        root = base_url.rstrip("/")
        if root.startswith("http"):
            # "http" -> "ws" and "https" -> "wss".
            root = "ws" + root[len("http"):]
        return root + "/ws"

    async def __aenter__(self):
        """Run ``setup`` on entering ``async with`` and return the client."""
        await self.setup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Run ``teardown`` on leaving ``async with``; exceptions propagate."""
        await self.teardown()

    async def setup(self):
        """Connect both clients and assign a time-based session id."""
        await self.http_client.connect()
        await self.ws_client.connect()
        self.test_session_id = f"test_session_{int(time.time())}"

    async def teardown(self):
        """Disconnect the WebSocket client, then the HTTP client."""
        await self.ws_client.disconnect()
        await self.http_client.disconnect()

    async def authenticate(self, username: str = "test_user", password: str = "test_pass") -> bool:
        """Log in through the HTTP client and return its result."""
        return await self.http_client.login(username, password)

    async def test_health_endpoint(self) -> APIResponse:
        """Request ``GET /health`` and return the response."""
        return await self.http_client.get("/health")

    async def test_pose_detection(self, csi_data: Dict[str, Any]) -> APIResponse:
        """Post ``csi_data`` as JSON to ``/pose/detect`` and return the response."""
        return await self.http_client.post("/pose/detect", json=csi_data)

    async def test_websocket_streaming(self, duration_seconds: int = 5) -> List[Dict[str, Any]]:
        """Subscribe to ``pose_stream`` and collect messages for a while.

        Args:
            duration_seconds: How long to keep collecting messages.

        Returns:
            Every message received during that period, in arrival order.
        """
        await self.ws_client.send_message({
            "type": "subscribe",
            "channel": "pose_stream",
            "session_id": self.test_session_id,
        })

        collected = []
        deadline = time.monotonic() + duration_seconds
        while time.monotonic() < deadline:
            message = await self.ws_client.receive_message(timeout=0.1)
            if message:
                collected.append(message)
        return collected

    async def simulate_concurrent_requests(self, num_requests: int = 10) -> List[APIResponse]:
        """Fire ``num_requests`` concurrent ``GET /health`` requests.

        Returns:
            One entry per request. Because failures are collected rather than
            raised, an entry is the exception object when its request failed.
        """
        pending = [
            asyncio.create_task(self.http_client.get("/health"))
            for _ in range(num_requests)
        ]
        return await asyncio.gather(*pending, return_exceptions=True)

    async def simulate_websocket_load(self, num_connections: int = 5, duration_seconds: int = 3) -> Dict[str, Any]:
        """Drive several mock WebSocket clients in parallel.

        Args:
            num_connections: Number of clients to create.
            duration_seconds: How long each client keeps sending pings.

        Returns:
            A summary with the connection count, the duration, the number of
            messages each client sent and their total.
        """
        ws_clients = []
        message_counts = []
        try:
            # Connecting inside the try block guarantees that clients which
            # are already connected get cleaned up if a later one fails.
            for _ in range(num_connections):
                client = MockWebSocketClient(self.ws_url)
                await client.connect()
                ws_clients.append(client)

            senders = [
                asyncio.create_task(
                    self._send_messages_for_duration(client, duration_seconds, index)
                )
                for index, client in enumerate(ws_clients)
            ]
            message_counts = await asyncio.gather(*senders)
        finally:
            for client in ws_clients:
                await client.disconnect()

        return {
            "num_connections": num_connections,
            "duration_seconds": duration_seconds,
            "messages_per_connection": message_counts,
            "total_messages": sum(message_counts),
        }

    async def _send_messages_for_duration(self, client: MockWebSocketClient, duration: int, client_id: int) -> int:
        """Send pings through ``client`` for ``duration`` seconds.

        Returns:
            The number of messages sent.
        """
        message_count = 0
        deadline = time.monotonic() + duration
        while time.monotonic() < deadline:
            await client.send_message({
                "type": "ping",
                "client_id": client_id,
                "message_id": message_count,
            })
            message_count += 1
            await asyncio.sleep(0.1)  # pause between sends: at most ~10 per second
        return message_count

    def configure_error_simulation(self, endpoint: str, error_type: str, probability: float = 0.1):
        """Forward an error-simulation setting to the HTTP client."""
        self.http_client.simulate_error(endpoint, error_type, probability)

    def configure_rate_limiting(self, requests_per_minute: int = 60):
        """Enable rate limiting on the HTTP client."""
        self.http_client.enable_rate_limiting(requests_per_minute)

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Summarise HTTP and WebSocket activity of this session.

        HTTP timings are read from the ``response_time_ms`` key of each
        request-history record; records lacking that key count as 0.

        Returns:
            A dict with the session id, HTTP metrics, WebSocket metrics and
            the time the summary was produced.
        """
        http_history = self.http_client.get_request_history()

        if http_history:
            response_times = [record.get("response_time_ms", 0) for record in http_history]
            http_metrics = {
                "total_requests": len(http_history),
                "avg_response_time_ms": sum(response_times) / len(response_times),
                "min_response_time_ms": min(response_times),
                "max_response_time_ms": max(response_times),
            }
        else:
            http_metrics = {"total_requests": 0}

        ws_metrics = {
            "messages_sent": len(self.ws_client.get_sent_messages()),
            "messages_received": len(self.ws_client.get_received_messages()),
            "connection_active": self.ws_client.is_connected,
        }

        return {
            "session_id": self.test_session_id,
            "http_metrics": http_metrics,
            "websocket_metrics": ws_metrics,
            "timestamp": datetime.utcnow().isoformat(),
        }
# Utility functions for test data generation
def generate_test_csi_data(num_antennas: int = 4, num_subcarriers: int = 64) -> Dict[str, Any]:
    """Generate a random CSI payload for API tests.

    Args:
        num_antennas: Number of rows in the amplitude and phase matrices.
        num_subcarriers: Number of columns in the amplitude and phase matrices.

    Returns:
        A JSON-serialisable dict. ``amplitude`` holds uniform samples from
        ``[0, 1)`` and ``phase`` uniform samples from ``[-pi, pi)``, both as
        nested lists of shape ``num_antennas x num_subcarriers``; the
        ``num_antennas`` / ``num_subcarriers`` fields repeat that shape.
    """
    # Imported lazily so that merely importing this module does not need numpy.
    import numpy as np

    shape = (num_antennas, num_subcarriers)
    return {
        "timestamp": datetime.utcnow().isoformat(),
        "router_id": "test_router_001",
        "amplitude": np.random.uniform(0, 1, shape).tolist(),
        "phase": np.random.uniform(-np.pi, np.pi, shape).tolist(),
        "frequency": 5.8e9,
        "bandwidth": 80e6,
        "num_antennas": num_antennas,
        "num_subcarriers": num_subcarriers,
    }
def create_test_user_credentials() -> Dict[str, str]:
    """Return a fresh dict with the fixed username, password and email used in tests."""
    return {
        "username": "test_user",
        "password": "test_password_123",
        "email": "test@example.com",
    }
async def wait_for_condition(condition_func, timeout: float = 5.0, interval: float = 0.1) -> bool:
    """Poll ``condition_func`` until it yields a truthy value or time runs out.

    Args:
        condition_func: Zero-argument callable. It may be a plain function, a
            coroutine function, or any callable returning an awaitable (for
            example a lambda or ``functools.partial`` wrapping a coroutine).
        timeout: Maximum time to keep polling, in seconds. The condition is
            always evaluated at least once, even when this is 0.
        interval: Pause between two evaluations, in seconds.

    Returns:
        ``True`` as soon as the condition holds, ``False`` once the timeout
        has elapsed without it holding.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        outcome = condition_func()
        # Inspect the result rather than the callable: a sync wrapper that
        # returns a coroutine would otherwise hand back an always-truthy
        # coroutine object that is never awaited.
        if inspect.isawaitable(outcome):
            outcome = await outcome
        if outcome:
            return True
        if time.monotonic() >= deadline:
            return False
        await asyncio.sleep(interval)