updates
This commit is contained in:
338
tests/integration/test_api_endpoints.py
Normal file
338
tests/integration/test_api_endpoints.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Integration tests for WiFi-DensePose API endpoints.
|
||||
|
||||
Tests all REST API endpoints with real service dependencies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
import httpx
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_pose_service,
|
||||
get_stream_service,
|
||||
get_hardware_service,
|
||||
get_current_user
|
||||
)
|
||||
from src.api.routers.health import router as health_router
|
||||
from src.api.routers.pose import router as pose_router
|
||||
from src.api.routers.stream import router as stream_router
|
||||
|
||||
|
||||
class TestAPIEndpoints:
|
||||
"""Integration tests for API endpoints."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create FastAPI app with test dependencies."""
|
||||
app = FastAPI()
|
||||
app.include_router(health_router, prefix="/health", tags=["health"])
|
||||
app.include_router(pose_router, prefix="/pose", tags=["pose"])
|
||||
app.include_router(stream_router, prefix="/stream", tags=["stream"])
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pose_service(self):
|
||||
"""Mock pose service."""
|
||||
service = AsyncMock()
|
||||
service.health_check.return_value = {
|
||||
"status": "healthy",
|
||||
"message": "Service operational",
|
||||
"uptime_seconds": 3600.0,
|
||||
"metrics": {"processed_frames": 1000}
|
||||
}
|
||||
service.is_ready.return_value = True
|
||||
service.estimate_poses.return_value = {
|
||||
"timestamp": datetime.utcnow(),
|
||||
"frame_id": "test-frame-001",
|
||||
"persons": [],
|
||||
"zone_summary": {"zone1": 0},
|
||||
"processing_time_ms": 50.0,
|
||||
"metadata": {}
|
||||
}
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_service(self):
|
||||
"""Mock stream service."""
|
||||
service = AsyncMock()
|
||||
service.health_check.return_value = {
|
||||
"status": "healthy",
|
||||
"message": "Stream service operational",
|
||||
"uptime_seconds": 1800.0
|
||||
}
|
||||
service.is_ready.return_value = True
|
||||
service.get_status.return_value = {
|
||||
"is_active": True,
|
||||
"active_streams": [],
|
||||
"uptime_seconds": 1800.0
|
||||
}
|
||||
service.is_active.return_value = True
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hardware_service(self):
|
||||
"""Mock hardware service."""
|
||||
service = AsyncMock()
|
||||
service.health_check.return_value = {
|
||||
"status": "healthy",
|
||||
"message": "Hardware connected",
|
||||
"uptime_seconds": 7200.0,
|
||||
"metrics": {"connected_routers": 3}
|
||||
}
|
||||
service.is_ready.return_value = True
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Mock authenticated user."""
|
||||
return {
|
||||
"id": "test-user-001",
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"is_admin": False,
|
||||
"is_active": True,
|
||||
"permissions": ["read", "write"]
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app, mock_pose_service, mock_stream_service, mock_hardware_service, mock_user):
|
||||
"""Create test client with mocked dependencies."""
|
||||
app.dependency_overrides[get_pose_service] = lambda: mock_pose_service
|
||||
app.dependency_overrides[get_stream_service] = lambda: mock_stream_service
|
||||
app.dependency_overrides[get_hardware_service] = lambda: mock_hardware_service
|
||||
app.dependency_overrides[get_current_user] = lambda: mock_user
|
||||
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
def test_health_check_endpoint_should_fail_initially(self, client):
|
||||
"""Test health check endpoint - should fail initially."""
|
||||
# This test should fail because we haven't implemented the endpoint properly
|
||||
response = client.get("/health/health")
|
||||
|
||||
# This assertion will fail initially, driving us to implement the endpoint
|
||||
assert response.status_code == 200
|
||||
assert "status" in response.json()
|
||||
assert "components" in response.json()
|
||||
assert "system_metrics" in response.json()
|
||||
|
||||
def test_readiness_check_endpoint_should_fail_initially(self, client):
|
||||
"""Test readiness check endpoint - should fail initially."""
|
||||
response = client.get("/health/ready")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "ready" in data
|
||||
assert "checks" in data
|
||||
assert isinstance(data["checks"], dict)
|
||||
|
||||
def test_liveness_check_endpoint_should_fail_initially(self, client):
|
||||
"""Test liveness check endpoint - should fail initially."""
|
||||
response = client.get("/health/live")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "status" in data
|
||||
assert data["status"] == "alive"
|
||||
|
||||
def test_version_info_endpoint_should_fail_initially(self, client):
|
||||
"""Test version info endpoint - should fail initially."""
|
||||
response = client.get("/health/version")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "name" in data
|
||||
assert "version" in data
|
||||
assert "environment" in data
|
||||
|
||||
def test_pose_current_endpoint_should_fail_initially(self, client):
|
||||
"""Test current pose estimation endpoint - should fail initially."""
|
||||
response = client.get("/pose/current")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "timestamp" in data
|
||||
assert "frame_id" in data
|
||||
assert "persons" in data
|
||||
assert "zone_summary" in data
|
||||
|
||||
def test_pose_analyze_endpoint_should_fail_initially(self, client):
|
||||
"""Test pose analysis endpoint - should fail initially."""
|
||||
request_data = {
|
||||
"zone_ids": ["zone1", "zone2"],
|
||||
"confidence_threshold": 0.7,
|
||||
"max_persons": 10,
|
||||
"include_keypoints": True,
|
||||
"include_segmentation": False
|
||||
}
|
||||
|
||||
response = client.post("/pose/analyze", json=request_data)
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "timestamp" in data
|
||||
assert "persons" in data
|
||||
|
||||
def test_zone_occupancy_endpoint_should_fail_initially(self, client):
|
||||
"""Test zone occupancy endpoint - should fail initially."""
|
||||
response = client.get("/pose/zones/zone1/occupancy")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "zone_id" in data
|
||||
assert "current_occupancy" in data
|
||||
|
||||
def test_zones_summary_endpoint_should_fail_initially(self, client):
|
||||
"""Test zones summary endpoint - should fail initially."""
|
||||
response = client.get("/pose/zones/summary")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_persons" in data
|
||||
assert "zones" in data
|
||||
|
||||
def test_stream_status_endpoint_should_fail_initially(self, client):
|
||||
"""Test stream status endpoint - should fail initially."""
|
||||
response = client.get("/stream/status")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "is_active" in data
|
||||
assert "connected_clients" in data
|
||||
|
||||
def test_stream_start_endpoint_should_fail_initially(self, client):
|
||||
"""Test stream start endpoint - should fail initially."""
|
||||
response = client.post("/stream/start")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "message" in data
|
||||
|
||||
def test_stream_stop_endpoint_should_fail_initially(self, client):
|
||||
"""Test stream stop endpoint - should fail initially."""
|
||||
response = client.post("/stream/stop")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "message" in data
|
||||
|
||||
|
||||
class TestAPIErrorHandling:
|
||||
"""Test API error handling scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_failing_services(self):
|
||||
"""Create app with failing service dependencies."""
|
||||
app = FastAPI()
|
||||
app.include_router(health_router, prefix="/health", tags=["health"])
|
||||
app.include_router(pose_router, prefix="/pose", tags=["pose"])
|
||||
|
||||
# Mock failing services
|
||||
failing_pose_service = AsyncMock()
|
||||
failing_pose_service.health_check.side_effect = Exception("Service unavailable")
|
||||
|
||||
app.dependency_overrides[get_pose_service] = lambda: failing_pose_service
|
||||
|
||||
return app
|
||||
|
||||
def test_health_check_with_failing_service_should_fail_initially(self, app_with_failing_services):
|
||||
"""Test health check with failing service - should fail initially."""
|
||||
with TestClient(app_with_failing_services) as client:
|
||||
response = client.get("/health/health")
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "unhealthy"
|
||||
assert "hardware" in data["components"]
|
||||
assert data["components"]["pose"]["status"] == "unhealthy"
|
||||
|
||||
|
||||
class TestAPIAuthentication:
|
||||
"""Test API authentication scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_auth(self):
|
||||
"""Create app with authentication enabled."""
|
||||
app = FastAPI()
|
||||
app.include_router(pose_router, prefix="/pose", tags=["pose"])
|
||||
|
||||
# Mock authenticated user dependency
|
||||
def get_authenticated_user():
|
||||
return {
|
||||
"id": "auth-user-001",
|
||||
"username": "authuser",
|
||||
"is_admin": True,
|
||||
"permissions": ["read", "write", "admin"]
|
||||
}
|
||||
|
||||
app.dependency_overrides[get_current_user] = get_authenticated_user
|
||||
|
||||
return app
|
||||
|
||||
def test_authenticated_endpoint_access_should_fail_initially(self, app_with_auth):
|
||||
"""Test authenticated endpoint access - should fail initially."""
|
||||
with TestClient(app_with_auth) as client:
|
||||
response = client.post("/pose/analyze", json={
|
||||
"confidence_threshold": 0.8,
|
||||
"include_keypoints": True
|
||||
})
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestAPIValidation:
|
||||
"""Test API request validation."""
|
||||
|
||||
@pytest.fixture
|
||||
def validation_app(self):
|
||||
"""Create app for validation testing."""
|
||||
app = FastAPI()
|
||||
app.include_router(pose_router, prefix="/pose", tags=["pose"])
|
||||
|
||||
# Mock service
|
||||
mock_service = AsyncMock()
|
||||
app.dependency_overrides[get_pose_service] = lambda: mock_service
|
||||
|
||||
return app
|
||||
|
||||
def test_invalid_confidence_threshold_should_fail_initially(self, validation_app):
|
||||
"""Test invalid confidence threshold validation - should fail initially."""
|
||||
with TestClient(validation_app) as client:
|
||||
response = client.post("/pose/analyze", json={
|
||||
"confidence_threshold": 1.5, # Invalid: > 1.0
|
||||
"include_keypoints": True
|
||||
})
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 422
|
||||
assert "validation error" in response.json()["detail"][0]["msg"].lower()
|
||||
|
||||
def test_invalid_max_persons_should_fail_initially(self, validation_app):
|
||||
"""Test invalid max_persons validation - should fail initially."""
|
||||
with TestClient(validation_app) as client:
|
||||
response = client.post("/pose/analyze", json={
|
||||
"max_persons": 0, # Invalid: < 1
|
||||
"include_keypoints": True
|
||||
})
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 422
|
||||
571
tests/integration/test_authentication.py
Normal file
571
tests/integration/test_authentication.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""
|
||||
Integration tests for authentication and authorization.
|
||||
|
||||
Tests JWT authentication flow, user permissions, and access control.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import jwt
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
|
||||
class MockJWTToken:
|
||||
"""Mock JWT token for testing."""
|
||||
|
||||
def __init__(self, payload: Dict[str, Any], secret: str = "test-secret"):
|
||||
self.payload = payload
|
||||
self.secret = secret
|
||||
self.token = jwt.encode(payload, secret, algorithm="HS256")
|
||||
|
||||
def decode(self, token: str, secret: str) -> Dict[str, Any]:
|
||||
"""Decode JWT token."""
|
||||
return jwt.decode(token, secret, algorithms=["HS256"])
|
||||
|
||||
|
||||
class TestJWTAuthentication:
|
||||
"""Test JWT authentication functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_payload(self):
|
||||
"""Valid user payload for JWT token."""
|
||||
return {
|
||||
"sub": "user-001",
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"is_admin": False,
|
||||
"is_active": True,
|
||||
"permissions": ["read", "write"],
|
||||
"exp": datetime.utcnow() + timedelta(hours=1),
|
||||
"iat": datetime.utcnow()
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_payload(self):
|
||||
"""Admin user payload for JWT token."""
|
||||
return {
|
||||
"sub": "admin-001",
|
||||
"username": "admin",
|
||||
"email": "admin@example.com",
|
||||
"is_admin": True,
|
||||
"is_active": True,
|
||||
"permissions": ["read", "write", "admin"],
|
||||
"exp": datetime.utcnow() + timedelta(hours=1),
|
||||
"iat": datetime.utcnow()
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def expired_user_payload(self):
|
||||
"""Expired user payload for JWT token."""
|
||||
return {
|
||||
"sub": "user-002",
|
||||
"username": "expireduser",
|
||||
"email": "expired@example.com",
|
||||
"is_admin": False,
|
||||
"is_active": True,
|
||||
"permissions": ["read"],
|
||||
"exp": datetime.utcnow() - timedelta(hours=1), # Expired
|
||||
"iat": datetime.utcnow() - timedelta(hours=2)
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_service(self):
|
||||
"""Mock JWT service."""
|
||||
class MockJWTService:
|
||||
def __init__(self):
|
||||
self.secret = "test-secret-key"
|
||||
self.algorithm = "HS256"
|
||||
|
||||
def create_token(self, user_data: Dict[str, Any]) -> str:
|
||||
"""Create JWT token."""
|
||||
payload = {
|
||||
**user_data,
|
||||
"exp": datetime.utcnow() + timedelta(hours=1),
|
||||
"iat": datetime.utcnow()
|
||||
}
|
||||
return jwt.encode(payload, self.secret, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Verify JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired"
|
||||
)
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token"
|
||||
)
|
||||
|
||||
def refresh_token(self, token: str) -> str:
|
||||
"""Refresh JWT token."""
|
||||
payload = self.verify_token(token)
|
||||
# Remove exp and iat for new token
|
||||
payload.pop("exp", None)
|
||||
payload.pop("iat", None)
|
||||
return self.create_token(payload)
|
||||
|
||||
return MockJWTService()
|
||||
|
||||
def test_jwt_token_creation_should_fail_initially(self, mock_jwt_service, valid_user_payload):
|
||||
"""Test JWT token creation - should fail initially."""
|
||||
token = mock_jwt_service.create_token(valid_user_payload)
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
|
||||
# Verify token can be decoded
|
||||
decoded = mock_jwt_service.verify_token(token)
|
||||
assert decoded["sub"] == valid_user_payload["sub"]
|
||||
assert decoded["username"] == valid_user_payload["username"]
|
||||
|
||||
def test_jwt_token_verification_should_fail_initially(self, mock_jwt_service, valid_user_payload):
|
||||
"""Test JWT token verification - should fail initially."""
|
||||
token = mock_jwt_service.create_token(valid_user_payload)
|
||||
decoded = mock_jwt_service.verify_token(token)
|
||||
|
||||
# This will fail initially
|
||||
assert decoded["sub"] == valid_user_payload["sub"]
|
||||
assert decoded["is_admin"] == valid_user_payload["is_admin"]
|
||||
assert "exp" in decoded
|
||||
assert "iat" in decoded
|
||||
|
||||
def test_expired_token_rejection_should_fail_initially(self, mock_jwt_service, expired_user_payload):
|
||||
"""Test expired token rejection - should fail initially."""
|
||||
# Create token with expired payload
|
||||
token = jwt.encode(expired_user_payload, mock_jwt_service.secret, algorithm=mock_jwt_service.algorithm)
|
||||
|
||||
# This should fail initially
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_jwt_service.verify_token(token)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "expired" in exc_info.value.detail.lower()
|
||||
|
||||
def test_invalid_token_rejection_should_fail_initially(self, mock_jwt_service):
|
||||
"""Test invalid token rejection - should fail initially."""
|
||||
invalid_token = "invalid.jwt.token"
|
||||
|
||||
# This should fail initially
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
mock_jwt_service.verify_token(invalid_token)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert "invalid" in exc_info.value.detail.lower()
|
||||
|
||||
def test_token_refresh_should_fail_initially(self, mock_jwt_service, valid_user_payload):
|
||||
"""Test token refresh functionality - should fail initially."""
|
||||
original_token = mock_jwt_service.create_token(valid_user_payload)
|
||||
|
||||
# Wait a moment to ensure different timestamps
|
||||
import time
|
||||
time.sleep(0.1)
|
||||
|
||||
refreshed_token = mock_jwt_service.refresh_token(original_token)
|
||||
|
||||
# This will fail initially
|
||||
assert refreshed_token != original_token
|
||||
|
||||
# Verify both tokens are valid but have different timestamps
|
||||
original_payload = mock_jwt_service.verify_token(original_token)
|
||||
refreshed_payload = mock_jwt_service.verify_token(refreshed_token)
|
||||
|
||||
assert original_payload["sub"] == refreshed_payload["sub"]
|
||||
assert original_payload["iat"] != refreshed_payload["iat"]
|
||||
|
||||
|
||||
class TestUserAuthentication:
|
||||
"""Test user authentication scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_service(self):
|
||||
"""Mock user service."""
|
||||
class MockUserService:
|
||||
def __init__(self):
|
||||
self.users = {
|
||||
"testuser": {
|
||||
"id": "user-001",
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password_hash": "hashed_password",
|
||||
"is_admin": False,
|
||||
"is_active": True,
|
||||
"permissions": ["read", "write"],
|
||||
"zones": ["zone1", "zone2"],
|
||||
"created_at": datetime.utcnow()
|
||||
},
|
||||
"admin": {
|
||||
"id": "admin-001",
|
||||
"username": "admin",
|
||||
"email": "admin@example.com",
|
||||
"password_hash": "admin_hashed_password",
|
||||
"is_admin": True,
|
||||
"is_active": True,
|
||||
"permissions": ["read", "write", "admin"],
|
||||
"zones": [], # Admin has access to all zones
|
||||
"created_at": datetime.utcnow()
|
||||
}
|
||||
}
|
||||
|
||||
async def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""Authenticate user with username and password."""
|
||||
user = self.users.get(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Mock password verification
|
||||
if password == "correct_password":
|
||||
return user
|
||||
return None
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user by ID."""
|
||||
for user in self.users.values():
|
||||
if user["id"] == user_id:
|
||||
return user
|
||||
return None
|
||||
|
||||
async def update_user_activity(self, user_id: str):
|
||||
"""Update user last activity."""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
if user:
|
||||
user["last_activity"] = datetime.utcnow()
|
||||
|
||||
return MockUserService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_authentication_success_should_fail_initially(self, mock_user_service):
|
||||
"""Test successful user authentication - should fail initially."""
|
||||
user = await mock_user_service.authenticate_user("testuser", "correct_password")
|
||||
|
||||
# This will fail initially
|
||||
assert user is not None
|
||||
assert user["username"] == "testuser"
|
||||
assert user["is_active"] is True
|
||||
assert "read" in user["permissions"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_authentication_failure_should_fail_initially(self, mock_user_service):
|
||||
"""Test failed user authentication - should fail initially."""
|
||||
user = await mock_user_service.authenticate_user("testuser", "wrong_password")
|
||||
|
||||
# This will fail initially
|
||||
assert user is None
|
||||
|
||||
# Test with non-existent user
|
||||
user = await mock_user_service.authenticate_user("nonexistent", "any_password")
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_user_authentication_should_fail_initially(self, mock_user_service):
|
||||
"""Test admin user authentication - should fail initially."""
|
||||
admin = await mock_user_service.authenticate_user("admin", "correct_password")
|
||||
|
||||
# This will fail initially
|
||||
assert admin is not None
|
||||
assert admin["is_admin"] is True
|
||||
assert "admin" in admin["permissions"]
|
||||
assert admin["zones"] == [] # Admin has access to all zones
|
||||
|
||||
|
||||
class TestAuthorizationDependencies:
|
||||
"""Test authorization dependency functions."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request(self):
|
||||
"""Mock FastAPI request."""
|
||||
class MockRequest:
|
||||
def __init__(self):
|
||||
self.state = MagicMock()
|
||||
self.state.user = None
|
||||
|
||||
return MockRequest()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials(self):
|
||||
"""Mock HTTP authorization credentials."""
|
||||
def create_credentials(token: str):
|
||||
return HTTPAuthorizationCredentials(
|
||||
scheme="Bearer",
|
||||
credentials=token
|
||||
)
|
||||
return create_credentials
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_with_valid_token_should_fail_initially(self, mock_request, mock_credentials):
|
||||
"""Test get_current_user with valid token - should fail initially."""
|
||||
# Mock the get_current_user dependency
|
||||
async def mock_get_current_user(request, credentials):
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
# Mock token validation
|
||||
if credentials.credentials == "valid_token":
|
||||
return {
|
||||
"id": "user-001",
|
||||
"username": "testuser",
|
||||
"is_admin": False,
|
||||
"is_active": True,
|
||||
"permissions": ["read", "write"]
|
||||
}
|
||||
return None
|
||||
|
||||
credentials = mock_credentials("valid_token")
|
||||
user = await mock_get_current_user(mock_request, credentials)
|
||||
|
||||
# This will fail initially
|
||||
assert user is not None
|
||||
assert user["username"] == "testuser"
|
||||
assert user["is_active"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_user_without_credentials_should_fail_initially(self, mock_request):
|
||||
"""Test get_current_user without credentials - should fail initially."""
|
||||
async def mock_get_current_user(request, credentials):
|
||||
if not credentials:
|
||||
return None
|
||||
return {"id": "user-001"}
|
||||
|
||||
user = await mock_get_current_user(mock_request, None)
|
||||
|
||||
# This will fail initially
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_active_user_should_fail_initially(self):
|
||||
"""Test require active user dependency - should fail initially."""
|
||||
async def mock_get_current_active_user(current_user):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required"
|
||||
)
|
||||
|
||||
if not current_user.get("is_active", True):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
# Test with active user
|
||||
active_user = {"id": "user-001", "is_active": True}
|
||||
result = await mock_get_current_active_user(active_user)
|
||||
|
||||
# This will fail initially
|
||||
assert result == active_user
|
||||
|
||||
# Test with inactive user
|
||||
inactive_user = {"id": "user-002", "is_active": False}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await mock_get_current_active_user(inactive_user)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
# Test with no user
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await mock_get_current_active_user(None)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_require_admin_user_should_fail_initially(self):
|
||||
"""Test require admin user dependency - should fail initially."""
|
||||
async def mock_get_admin_user(current_user):
|
||||
if not current_user.get("is_admin", False):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin privileges required"
|
||||
)
|
||||
return current_user
|
||||
|
||||
# Test with admin user
|
||||
admin_user = {"id": "admin-001", "is_admin": True}
|
||||
result = await mock_get_admin_user(admin_user)
|
||||
|
||||
# This will fail initially
|
||||
assert result == admin_user
|
||||
|
||||
# Test with regular user
|
||||
regular_user = {"id": "user-001", "is_admin": False}
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await mock_get_admin_user(regular_user)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_checking_should_fail_initially(self):
|
||||
"""Test permission checking functionality - should fail initially."""
|
||||
def require_permission(permission: str):
|
||||
async def check_permission(current_user):
|
||||
user_permissions = current_user.get("permissions", [])
|
||||
|
||||
# Admin users have all permissions
|
||||
if current_user.get("is_admin", False):
|
||||
return current_user
|
||||
|
||||
# Check specific permission
|
||||
if permission not in user_permissions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Permission '{permission}' required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
return check_permission
|
||||
|
||||
# Test with user having required permission
|
||||
user_with_permission = {
|
||||
"id": "user-001",
|
||||
"permissions": ["read", "write"],
|
||||
"is_admin": False
|
||||
}
|
||||
|
||||
check_read = require_permission("read")
|
||||
result = await check_read(user_with_permission)
|
||||
|
||||
# This will fail initially
|
||||
assert result == user_with_permission
|
||||
|
||||
# Test with user missing permission
|
||||
check_admin = require_permission("admin")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_admin(user_with_permission)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert "admin" in exc_info.value.detail
|
||||
|
||||
# Test with admin user (should have all permissions)
|
||||
admin_user = {"id": "admin-001", "is_admin": True, "permissions": ["read"]}
|
||||
result = await check_admin(admin_user)
|
||||
assert result == admin_user
|
||||
|
||||
|
||||
class TestZoneAndRouterAccess:
|
||||
"""Test zone and router access control."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_domain_config(self):
|
||||
"""Mock domain configuration."""
|
||||
class MockDomainConfig:
|
||||
def __init__(self):
|
||||
self.zones = {
|
||||
"zone1": {"id": "zone1", "name": "Zone 1", "enabled": True},
|
||||
"zone2": {"id": "zone2", "name": "Zone 2", "enabled": True},
|
||||
"zone3": {"id": "zone3", "name": "Zone 3", "enabled": False}
|
||||
}
|
||||
self.routers = {
|
||||
"router1": {"id": "router1", "name": "Router 1", "enabled": True},
|
||||
"router2": {"id": "router2", "name": "Router 2", "enabled": False}
|
||||
}
|
||||
|
||||
def get_zone(self, zone_id: str):
|
||||
return self.zones.get(zone_id)
|
||||
|
||||
def get_router(self, router_id: str):
|
||||
return self.routers.get(router_id)
|
||||
|
||||
return MockDomainConfig()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zone_access_validation_should_fail_initially(self, mock_domain_config):
|
||||
"""Test zone access validation - should fail initially."""
|
||||
async def validate_zone_access(zone_id: str, current_user=None):
|
||||
zone = mock_domain_config.get_zone(zone_id)
|
||||
if not zone:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Zone '{zone_id}' not found"
|
||||
)
|
||||
|
||||
if not zone["enabled"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Zone '{zone_id}' is disabled"
|
||||
)
|
||||
|
||||
if current_user:
|
||||
if current_user.get("is_admin", False):
|
||||
return zone_id
|
||||
|
||||
user_zones = current_user.get("zones", [])
|
||||
if user_zones and zone_id not in user_zones:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Access denied to zone '{zone_id}'"
|
||||
)
|
||||
|
||||
return zone_id
|
||||
|
||||
# Test valid zone access
|
||||
result = await validate_zone_access("zone1")
|
||||
|
||||
# This will fail initially
|
||||
assert result == "zone1"
|
||||
|
||||
# Test invalid zone
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_zone_access("nonexistent")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
# Test disabled zone
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_zone_access("zone3")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
# Test user with zone access
|
||||
user_with_access = {"id": "user-001", "zones": ["zone1", "zone2"]}
|
||||
result = await validate_zone_access("zone1", user_with_access)
|
||||
assert result == "zone1"
|
||||
|
||||
# Test user without zone access
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_zone_access("zone2", {"id": "user-002", "zones": ["zone1"]})
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_access_validation_should_fail_initially(self, mock_domain_config):
|
||||
"""Test router access validation - should fail initially."""
|
||||
async def validate_router_access(router_id: str, current_user=None):
|
||||
router = mock_domain_config.get_router(router_id)
|
||||
if not router:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Router '{router_id}' not found"
|
||||
)
|
||||
|
||||
if not router["enabled"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Router '{router_id}' is disabled"
|
||||
)
|
||||
|
||||
return router_id
|
||||
|
||||
# Test valid router access
|
||||
result = await validate_router_access("router1")
|
||||
|
||||
# This will fail initially
|
||||
assert result == "router1"
|
||||
|
||||
# Test disabled router
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_router_access("router2")
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
|
||||
447
tests/integration/test_full_system_integration.py
Normal file
447
tests/integration/test_full_system_integration.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
Full system integration tests for WiFi-DensePose API
|
||||
Tests the complete integration of all components working together.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import httpx
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.app import app
|
||||
from src.database.connection import get_database_manager
|
||||
from src.services.orchestrator import get_service_orchestrator
|
||||
from src.tasks.cleanup import get_cleanup_manager
|
||||
from src.tasks.monitoring import get_monitoring_manager
|
||||
from src.tasks.backup import get_backup_manager
|
||||
|
||||
|
||||
class TestFullSystemIntegration:
|
||||
"""Test complete system integration."""
|
||||
|
||||
@pytest.fixture
|
||||
async def settings(self):
|
||||
"""Get test settings."""
|
||||
settings = get_settings()
|
||||
settings.environment = "test"
|
||||
settings.debug = True
|
||||
settings.database_url = "sqlite+aiosqlite:///test_integration.db"
|
||||
settings.redis_enabled = False
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
async def db_manager(self, settings):
|
||||
"""Get database manager for testing."""
|
||||
manager = get_database_manager(settings)
|
||||
await manager.initialize()
|
||||
yield manager
|
||||
await manager.close_all_connections()
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, settings):
|
||||
"""Get test HTTP client."""
|
||||
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
@pytest.fixture
|
||||
async def orchestrator(self, settings, db_manager):
|
||||
"""Get service orchestrator for testing."""
|
||||
orchestrator = get_service_orchestrator(settings)
|
||||
await orchestrator.initialize()
|
||||
yield orchestrator
|
||||
await orchestrator.shutdown()
|
||||
|
||||
async def test_application_startup_and_shutdown(self, settings, db_manager):
|
||||
"""Test complete application startup and shutdown sequence."""
|
||||
|
||||
# Test database initialization
|
||||
await db_manager.test_connection()
|
||||
stats = await db_manager.get_connection_stats()
|
||||
assert stats["database"]["connected"] is True
|
||||
|
||||
# Test service orchestrator initialization
|
||||
orchestrator = get_service_orchestrator(settings)
|
||||
await orchestrator.initialize()
|
||||
|
||||
# Verify services are running
|
||||
health_status = await orchestrator.get_health_status()
|
||||
assert health_status["status"] in ["healthy", "warning"]
|
||||
|
||||
# Test graceful shutdown
|
||||
await orchestrator.shutdown()
|
||||
|
||||
# Verify cleanup
|
||||
final_stats = await db_manager.get_connection_stats()
|
||||
assert final_stats is not None
|
||||
|
||||
async def test_api_endpoints_integration(self, client, settings, db_manager):
|
||||
"""Test API endpoints work with database integration."""
|
||||
|
||||
# Test health endpoint
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
health_data = response.json()
|
||||
assert "status" in health_data
|
||||
assert "timestamp" in health_data
|
||||
|
||||
# Test metrics endpoint
|
||||
response = await client.get("/metrics")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test devices endpoint
|
||||
response = await client.get("/api/v1/devices")
|
||||
assert response.status_code == 200
|
||||
devices_data = response.json()
|
||||
assert "devices" in devices_data
|
||||
assert isinstance(devices_data["devices"], list)
|
||||
|
||||
# Test sessions endpoint
|
||||
response = await client.get("/api/v1/sessions")
|
||||
assert response.status_code == 200
|
||||
sessions_data = response.json()
|
||||
assert "sessions" in sessions_data
|
||||
assert isinstance(sessions_data["sessions"], list)
|
||||
|
||||
@patch('src.core.router_interface.RouterInterface')
|
||||
@patch('src.core.csi_processor.CSIProcessor')
|
||||
@patch('src.core.pose_estimator.PoseEstimator')
|
||||
async def test_data_processing_pipeline(
|
||||
self,
|
||||
mock_pose_estimator,
|
||||
mock_csi_processor,
|
||||
mock_router_interface,
|
||||
client,
|
||||
settings,
|
||||
db_manager
|
||||
):
|
||||
"""Test complete data processing pipeline integration."""
|
||||
|
||||
# Setup mocks
|
||||
mock_router = MagicMock()
|
||||
mock_router_interface.return_value = mock_router
|
||||
mock_router.connect.return_value = True
|
||||
mock_router.start_capture.return_value = True
|
||||
mock_router.get_csi_data.return_value = {
|
||||
"timestamp": time.time(),
|
||||
"csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
|
||||
"rssi": -45,
|
||||
"noise_floor": -90
|
||||
}
|
||||
|
||||
mock_processor = MagicMock()
|
||||
mock_csi_processor.return_value = mock_processor
|
||||
mock_processor.process_csi_data.return_value = {
|
||||
"processed_csi": [[1.1, 2.1], [3.1, 4.1]],
|
||||
"quality_score": 0.85,
|
||||
"phase_sanitized": True
|
||||
}
|
||||
|
||||
mock_estimator = MagicMock()
|
||||
mock_pose_estimator.return_value = mock_estimator
|
||||
mock_estimator.estimate_pose.return_value = {
|
||||
"pose_data": {
|
||||
"keypoints": [[100, 200], [150, 250]],
|
||||
"confidence": 0.9
|
||||
},
|
||||
"processing_time": 0.05
|
||||
}
|
||||
|
||||
# Test device registration
|
||||
device_data = {
|
||||
"name": "test_router",
|
||||
"ip_address": "192.168.1.1",
|
||||
"device_type": "router",
|
||||
"model": "test_model"
|
||||
}
|
||||
|
||||
response = await client.post("/api/v1/devices", json=device_data)
|
||||
assert response.status_code == 201
|
||||
device_response = response.json()
|
||||
device_id = device_response["device"]["id"]
|
||||
|
||||
# Test session creation
|
||||
session_data = {
|
||||
"device_id": device_id,
|
||||
"session_type": "pose_detection",
|
||||
"configuration": {
|
||||
"sampling_rate": 1000,
|
||||
"duration": 60
|
||||
}
|
||||
}
|
||||
|
||||
response = await client.post("/api/v1/sessions", json=session_data)
|
||||
assert response.status_code == 201
|
||||
session_response = response.json()
|
||||
session_id = session_response["session"]["id"]
|
||||
|
||||
# Test CSI data submission
|
||||
csi_data = {
|
||||
"session_id": session_id,
|
||||
"timestamp": time.time(),
|
||||
"csi_matrix": [[1.0, 2.0], [3.0, 4.0]],
|
||||
"rssi": -45,
|
||||
"noise_floor": -90
|
||||
}
|
||||
|
||||
response = await client.post("/api/v1/csi-data", json=csi_data)
|
||||
assert response.status_code == 201
|
||||
|
||||
# Test pose detection retrieval
|
||||
response = await client.get(f"/api/v1/sessions/{session_id}/pose-detections")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Test session completion
|
||||
response = await client.patch(
|
||||
f"/api/v1/sessions/{session_id}",
|
||||
json={"status": "completed"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_background_tasks_integration(self, settings, db_manager):
|
||||
"""Test background tasks integration."""
|
||||
|
||||
# Test cleanup manager
|
||||
cleanup_manager = get_cleanup_manager(settings)
|
||||
cleanup_stats = cleanup_manager.get_stats()
|
||||
assert "manager" in cleanup_stats
|
||||
|
||||
# Run cleanup task
|
||||
cleanup_result = await cleanup_manager.run_all_tasks()
|
||||
assert cleanup_result["success"] is True
|
||||
|
||||
# Test monitoring manager
|
||||
monitoring_manager = get_monitoring_manager(settings)
|
||||
monitoring_stats = monitoring_manager.get_stats()
|
||||
assert "manager" in monitoring_stats
|
||||
|
||||
# Run monitoring task
|
||||
monitoring_result = await monitoring_manager.run_all_tasks()
|
||||
assert monitoring_result["success"] is True
|
||||
|
||||
# Test backup manager
|
||||
backup_manager = get_backup_manager(settings)
|
||||
backup_stats = backup_manager.get_stats()
|
||||
assert "manager" in backup_stats
|
||||
|
||||
# Run backup task
|
||||
backup_result = await backup_manager.run_all_tasks()
|
||||
assert backup_result["success"] is True
|
||||
|
||||
async def test_error_handling_integration(self, client, settings, db_manager):
|
||||
"""Test error handling across the system."""
|
||||
|
||||
# Test invalid device creation
|
||||
invalid_device_data = {
|
||||
"name": "", # Invalid empty name
|
||||
"ip_address": "invalid_ip",
|
||||
"device_type": "unknown_type"
|
||||
}
|
||||
|
||||
response = await client.post("/api/v1/devices", json=invalid_device_data)
|
||||
assert response.status_code == 422
|
||||
error_response = response.json()
|
||||
assert "detail" in error_response
|
||||
|
||||
# Test non-existent resource access
|
||||
response = await client.get("/api/v1/devices/99999")
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test invalid session creation
|
||||
invalid_session_data = {
|
||||
"device_id": "invalid_uuid",
|
||||
"session_type": "invalid_type"
|
||||
}
|
||||
|
||||
response = await client.post("/api/v1/sessions", json=invalid_session_data)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_authentication_and_authorization(self, client, settings):
|
||||
"""Test authentication and authorization integration."""
|
||||
|
||||
# Test protected endpoint without authentication
|
||||
response = await client.get("/api/v1/admin/system-info")
|
||||
assert response.status_code in [401, 403]
|
||||
|
||||
# Test with invalid token
|
||||
headers = {"Authorization": "Bearer invalid_token"}
|
||||
response = await client.get("/api/v1/admin/system-info", headers=headers)
|
||||
assert response.status_code in [401, 403]
|
||||
|
||||
async def test_rate_limiting_integration(self, client, settings):
|
||||
"""Test rate limiting integration."""
|
||||
|
||||
# Make multiple rapid requests to test rate limiting
|
||||
responses = []
|
||||
for i in range(10):
|
||||
response = await client.get("/health")
|
||||
responses.append(response.status_code)
|
||||
|
||||
# Should have at least some successful responses
|
||||
assert 200 in responses
|
||||
|
||||
# Rate limiting might kick in for some requests
|
||||
# This depends on the rate limiting configuration
|
||||
|
||||
async def test_monitoring_and_metrics_integration(self, client, settings, db_manager):
|
||||
"""Test monitoring and metrics collection integration."""
|
||||
|
||||
# Test metrics endpoint
|
||||
response = await client.get("/metrics")
|
||||
assert response.status_code == 200
|
||||
metrics_text = response.text
|
||||
|
||||
# Check for Prometheus format metrics
|
||||
assert "# HELP" in metrics_text
|
||||
assert "# TYPE" in metrics_text
|
||||
|
||||
# Test health check with detailed information
|
||||
response = await client.get("/health?detailed=true")
|
||||
assert response.status_code == 200
|
||||
health_data = response.json()
|
||||
|
||||
assert "database" in health_data
|
||||
assert "services" in health_data
|
||||
assert "system" in health_data
|
||||
|
||||
async def test_configuration_management_integration(self, settings):
|
||||
"""Test configuration management integration."""
|
||||
|
||||
# Test settings validation
|
||||
assert settings.environment == "test"
|
||||
assert settings.debug is True
|
||||
|
||||
# Test database URL configuration
|
||||
assert "test_integration.db" in settings.database_url
|
||||
|
||||
# Test Redis configuration
|
||||
assert settings.redis_enabled is False
|
||||
|
||||
# Test logging configuration
|
||||
assert settings.log_level in ["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
|
||||
async def test_database_migration_integration(self, settings, db_manager):
|
||||
"""Test database migration integration."""
|
||||
|
||||
# Test database connection
|
||||
await db_manager.test_connection()
|
||||
|
||||
# Test table creation
|
||||
async with db_manager.get_async_session() as session:
|
||||
from sqlalchemy import text
|
||||
|
||||
# Check if tables exist
|
||||
tables_query = text("""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name NOT LIKE 'sqlite_%'
|
||||
""")
|
||||
|
||||
result = await session.execute(tables_query)
|
||||
tables = [row[0] for row in result.fetchall()]
|
||||
|
||||
# Should have our main tables
|
||||
expected_tables = ["devices", "sessions", "csi_data", "pose_detections"]
|
||||
for table in expected_tables:
|
||||
assert table in tables
|
||||
|
||||
async def test_concurrent_operations_integration(self, client, settings, db_manager):
|
||||
"""Test concurrent operations integration."""
|
||||
|
||||
async def create_device(name: str):
|
||||
device_data = {
|
||||
"name": f"test_device_{name}",
|
||||
"ip_address": f"192.168.1.{name}",
|
||||
"device_type": "router",
|
||||
"model": "test_model"
|
||||
}
|
||||
response = await client.post("/api/v1/devices", json=device_data)
|
||||
return response.status_code
|
||||
|
||||
# Create multiple devices concurrently
|
||||
tasks = [create_device(str(i)) for i in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All should succeed
|
||||
assert all(status == 201 for status in results)
|
||||
|
||||
# Verify all devices were created
|
||||
response = await client.get("/api/v1/devices")
|
||||
assert response.status_code == 200
|
||||
devices_data = response.json()
|
||||
assert len(devices_data["devices"]) >= 5
|
||||
|
||||
async def test_system_resource_management(self, settings, db_manager, orchestrator):
|
||||
"""Test system resource management integration."""
|
||||
|
||||
# Test connection pool management
|
||||
stats = await db_manager.get_connection_stats()
|
||||
assert "database" in stats
|
||||
assert "pool_size" in stats["database"]
|
||||
|
||||
# Test service resource usage
|
||||
health_status = await orchestrator.get_health_status()
|
||||
assert "memory_usage" in health_status
|
||||
assert "cpu_usage" in health_status
|
||||
|
||||
# Test cleanup of resources
|
||||
await orchestrator.cleanup_resources()
|
||||
|
||||
# Verify resources are cleaned up
|
||||
final_stats = await db_manager.get_connection_stats()
|
||||
assert final_stats is not None
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
class TestSystemPerformance:
|
||||
"""Test system performance under load."""
|
||||
|
||||
async def test_api_response_times(self, client):
|
||||
"""Test API response times under normal load."""
|
||||
|
||||
start_time = time.time()
|
||||
response = await client.get("/health")
|
||||
end_time = time.time()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert (end_time - start_time) < 1.0 # Should respond within 1 second
|
||||
|
||||
async def test_database_query_performance(self, db_manager):
|
||||
"""Test database query performance."""
|
||||
|
||||
async with db_manager.get_async_session() as session:
|
||||
from sqlalchemy import text
|
||||
|
||||
start_time = time.time()
|
||||
result = await session.execute(text("SELECT 1"))
|
||||
end_time = time.time()
|
||||
|
||||
assert result.scalar() == 1
|
||||
assert (end_time - start_time) < 0.1 # Should complete within 100ms
|
||||
|
||||
async def test_memory_usage_stability(self, orchestrator):
|
||||
"""Test memory usage remains stable."""
|
||||
|
||||
import psutil
|
||||
import os
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
initial_memory = process.memory_info().rss
|
||||
|
||||
# Perform some operations
|
||||
for _ in range(10):
|
||||
health_status = await orchestrator.get_health_status()
|
||||
assert health_status is not None
|
||||
|
||||
final_memory = process.memory_info().rss
|
||||
memory_increase = final_memory - initial_memory
|
||||
|
||||
# Memory increase should be reasonable (less than 50MB)
|
||||
assert memory_increase < 50 * 1024 * 1024
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
663
tests/integration/test_hardware_integration.py
Normal file
663
tests/integration/test_hardware_integration.py
Normal file
@@ -0,0 +1,663 @@
|
||||
"""
|
||||
Integration tests for hardware integration and router communication.
|
||||
|
||||
Tests WiFi router communication, CSI data collection, and hardware management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import socket
|
||||
|
||||
|
||||
class MockRouterInterface:
|
||||
"""Mock WiFi router interface for testing."""
|
||||
|
||||
def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
|
||||
self.router_id = router_id
|
||||
self.ip_address = ip_address
|
||||
self.is_connected = False
|
||||
self.is_authenticated = False
|
||||
self.csi_streaming = False
|
||||
self.connection_attempts = 0
|
||||
self.last_heartbeat = None
|
||||
self.firmware_version = "1.2.3"
|
||||
self.capabilities = ["csi", "beamforming", "mimo"]
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to the router."""
|
||||
self.connection_attempts += 1
|
||||
|
||||
# Simulate connection failure for testing
|
||||
if self.connection_attempts == 1:
|
||||
return False
|
||||
|
||||
await asyncio.sleep(0.1) # Simulate connection time
|
||||
self.is_connected = True
|
||||
return True
|
||||
|
||||
async def authenticate(self, username: str, password: str) -> bool:
|
||||
"""Authenticate with the router."""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
# Simulate authentication
|
||||
if username == "admin" and password == "correct_password":
|
||||
self.is_authenticated = True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
|
||||
"""Start CSI data streaming."""
|
||||
if not self.is_authenticated:
|
||||
return False
|
||||
|
||||
# This should fail initially to test proper error handling
|
||||
return False
|
||||
|
||||
async def stop_csi_streaming(self) -> bool:
|
||||
"""Stop CSI data streaming."""
|
||||
if self.csi_streaming:
|
||||
self.csi_streaming = False
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get router status."""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"ip_address": self.ip_address,
|
||||
"is_connected": self.is_connected,
|
||||
"is_authenticated": self.is_authenticated,
|
||||
"csi_streaming": self.csi_streaming,
|
||||
"firmware_version": self.firmware_version,
|
||||
"uptime_seconds": 3600,
|
||||
"signal_strength": -45.2,
|
||||
"temperature": 42.5,
|
||||
"cpu_usage": 15.3
|
||||
}
|
||||
|
||||
async def send_heartbeat(self) -> bool:
|
||||
"""Send heartbeat to router."""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
self.last_heartbeat = datetime.utcnow()
|
||||
return True
|
||||
|
||||
|
||||
class TestRouterConnection:
|
||||
"""Test router connection functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def router_interface(self):
|
||||
"""Create router interface for testing."""
|
||||
return MockRouterInterface("router_001", "192.168.1.100")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_connection_should_fail_initially(self, router_interface):
|
||||
"""Test router connection - should fail initially."""
|
||||
# First connection attempt should fail
|
||||
result = await router_interface.connect()
|
||||
|
||||
# This will fail initially because we designed the mock to fail first attempt
|
||||
assert result is False
|
||||
assert router_interface.is_connected is False
|
||||
assert router_interface.connection_attempts == 1
|
||||
|
||||
# Second attempt should succeed
|
||||
result = await router_interface.connect()
|
||||
assert result is True
|
||||
assert router_interface.is_connected is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_authentication_should_fail_initially(self, router_interface):
|
||||
"""Test router authentication - should fail initially."""
|
||||
# Connect first
|
||||
await router_interface.connect()
|
||||
await router_interface.connect() # Second attempt succeeds
|
||||
|
||||
# Test wrong credentials
|
||||
result = await router_interface.authenticate("admin", "wrong_password")
|
||||
|
||||
# This will fail initially
|
||||
assert result is False
|
||||
assert router_interface.is_authenticated is False
|
||||
|
||||
# Test correct credentials
|
||||
result = await router_interface.authenticate("admin", "correct_password")
|
||||
assert result is True
|
||||
assert router_interface.is_authenticated is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csi_streaming_start_should_fail_initially(self, router_interface):
|
||||
"""Test CSI streaming start - should fail initially."""
|
||||
# Setup connection and authentication
|
||||
await router_interface.connect()
|
||||
await router_interface.connect() # Second attempt succeeds
|
||||
await router_interface.authenticate("admin", "correct_password")
|
||||
|
||||
# Try to start CSI streaming
|
||||
config = {
|
||||
"frequency": 5.8e9,
|
||||
"bandwidth": 80e6,
|
||||
"sample_rate": 1000,
|
||||
"antenna_config": "4x4_mimo"
|
||||
}
|
||||
|
||||
result = await router_interface.start_csi_streaming(config)
|
||||
|
||||
# This will fail initially because the mock is designed to return False
|
||||
assert result is False
|
||||
assert router_interface.csi_streaming is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_status_retrieval_should_fail_initially(self, router_interface):
|
||||
"""Test router status retrieval - should fail initially."""
|
||||
status = await router_interface.get_status()
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(status, dict)
|
||||
assert status["router_id"] == "router_001"
|
||||
assert status["ip_address"] == "192.168.1.100"
|
||||
assert "firmware_version" in status
|
||||
assert "uptime_seconds" in status
|
||||
assert "signal_strength" in status
|
||||
assert "temperature" in status
|
||||
assert "cpu_usage" in status
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_mechanism_should_fail_initially(self, router_interface):
|
||||
"""Test heartbeat mechanism - should fail initially."""
|
||||
# Heartbeat without connection should fail
|
||||
result = await router_interface.send_heartbeat()
|
||||
|
||||
# This will fail initially
|
||||
assert result is False
|
||||
assert router_interface.last_heartbeat is None
|
||||
|
||||
# Connect and try heartbeat
|
||||
await router_interface.connect()
|
||||
await router_interface.connect() # Second attempt succeeds
|
||||
|
||||
result = await router_interface.send_heartbeat()
|
||||
assert result is True
|
||||
assert router_interface.last_heartbeat is not None
|
||||
|
||||
|
||||
class TestMultiRouterManagement:
|
||||
"""Test management of multiple routers."""
|
||||
|
||||
@pytest.fixture
|
||||
def router_manager(self):
|
||||
"""Create router manager for testing."""
|
||||
class RouterManager:
|
||||
def __init__(self):
|
||||
self.routers = {}
|
||||
self.active_connections = 0
|
||||
|
||||
async def add_router(self, router_id: str, ip_address: str) -> bool:
|
||||
"""Add a router to management."""
|
||||
if router_id in self.routers:
|
||||
return False
|
||||
|
||||
router = MockRouterInterface(router_id, ip_address)
|
||||
self.routers[router_id] = router
|
||||
return True
|
||||
|
||||
async def connect_router(self, router_id: str) -> bool:
|
||||
"""Connect to a specific router."""
|
||||
if router_id not in self.routers:
|
||||
return False
|
||||
|
||||
router = self.routers[router_id]
|
||||
|
||||
# Try connecting twice (mock fails first time)
|
||||
success = await router.connect()
|
||||
if not success:
|
||||
success = await router.connect()
|
||||
|
||||
if success:
|
||||
self.active_connections += 1
|
||||
|
||||
return success
|
||||
|
||||
async def authenticate_router(self, router_id: str, username: str, password: str) -> bool:
|
||||
"""Authenticate with a router."""
|
||||
if router_id not in self.routers:
|
||||
return False
|
||||
|
||||
router = self.routers[router_id]
|
||||
return await router.authenticate(username, password)
|
||||
|
||||
async def get_all_status(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get status of all routers."""
|
||||
status = {}
|
||||
for router_id, router in self.routers.items():
|
||||
status[router_id] = await router.get_status()
|
||||
return status
|
||||
|
||||
async def start_all_csi_streaming(self, config: Dict[str, Any]) -> Dict[str, bool]:
|
||||
"""Start CSI streaming on all authenticated routers."""
|
||||
results = {}
|
||||
for router_id, router in self.routers.items():
|
||||
if router.is_authenticated:
|
||||
results[router_id] = await router.start_csi_streaming(config)
|
||||
else:
|
||||
results[router_id] = False
|
||||
return results
|
||||
|
||||
return RouterManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_router_addition_should_fail_initially(self, router_manager):
|
||||
"""Test adding multiple routers - should fail initially."""
|
||||
# Add first router
|
||||
result1 = await router_manager.add_router("router_001", "192.168.1.100")
|
||||
|
||||
# This will fail initially
|
||||
assert result1 is True
|
||||
assert "router_001" in router_manager.routers
|
||||
|
||||
# Add second router
|
||||
result2 = await router_manager.add_router("router_002", "192.168.1.101")
|
||||
assert result2 is True
|
||||
assert "router_002" in router_manager.routers
|
||||
|
||||
# Try to add duplicate router
|
||||
result3 = await router_manager.add_router("router_001", "192.168.1.102")
|
||||
assert result3 is False
|
||||
assert len(router_manager.routers) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_router_connections_should_fail_initially(self, router_manager):
|
||||
"""Test concurrent router connections - should fail initially."""
|
||||
# Add multiple routers
|
||||
await router_manager.add_router("router_001", "192.168.1.100")
|
||||
await router_manager.add_router("router_002", "192.168.1.101")
|
||||
await router_manager.add_router("router_003", "192.168.1.102")
|
||||
|
||||
# Connect to all routers concurrently
|
||||
connection_tasks = [
|
||||
router_manager.connect_router("router_001"),
|
||||
router_manager.connect_router("router_002"),
|
||||
router_manager.connect_router("router_003")
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*connection_tasks)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 3
|
||||
assert all(results) # All connections should succeed
|
||||
assert router_manager.active_connections == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_status_aggregation_should_fail_initially(self, router_manager):
|
||||
"""Test router status aggregation - should fail initially."""
|
||||
# Add and connect routers
|
||||
await router_manager.add_router("router_001", "192.168.1.100")
|
||||
await router_manager.add_router("router_002", "192.168.1.101")
|
||||
|
||||
await router_manager.connect_router("router_001")
|
||||
await router_manager.connect_router("router_002")
|
||||
|
||||
# Get all status
|
||||
all_status = await router_manager.get_all_status()
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(all_status, dict)
|
||||
assert len(all_status) == 2
|
||||
assert "router_001" in all_status
|
||||
assert "router_002" in all_status
|
||||
|
||||
# Verify status structure
|
||||
for router_id, status in all_status.items():
|
||||
assert "router_id" in status
|
||||
assert "ip_address" in status
|
||||
assert "is_connected" in status
|
||||
assert status["is_connected"] is True
|
||||
|
||||
|
||||
class TestCSIDataCollection:
|
||||
"""Test CSI data collection from routers."""
|
||||
|
||||
@pytest.fixture
|
||||
def csi_collector(self):
|
||||
"""Create CSI data collector."""
|
||||
class CSICollector:
|
||||
def __init__(self):
|
||||
self.collected_data = []
|
||||
self.is_collecting = False
|
||||
self.collection_rate = 0
|
||||
|
||||
async def start_collection(self, router_interfaces: List[MockRouterInterface]) -> bool:
|
||||
"""Start CSI data collection."""
|
||||
# This should fail initially
|
||||
return False
|
||||
|
||||
async def stop_collection(self) -> bool:
|
||||
"""Stop CSI data collection."""
|
||||
if self.is_collecting:
|
||||
self.is_collecting = False
|
||||
return True
|
||||
return False
|
||||
|
||||
async def collect_frame(self, router_interface: MockRouterInterface) -> Optional[Dict[str, Any]]:
|
||||
"""Collect a single CSI frame."""
|
||||
if not router_interface.csi_streaming:
|
||||
return None
|
||||
|
||||
# Simulate CSI data
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"router_id": router_interface.router_id,
|
||||
"amplitude": np.random.rand(64, 32).tolist(),
|
||||
"phase": np.random.rand(64, 32).tolist(),
|
||||
"frequency": 5.8e9,
|
||||
"bandwidth": 80e6,
|
||||
"antenna_count": 4,
|
||||
"subcarrier_count": 64,
|
||||
"signal_quality": np.random.uniform(0.7, 0.95)
|
||||
}
|
||||
|
||||
def get_collection_stats(self) -> Dict[str, Any]:
|
||||
"""Get collection statistics."""
|
||||
return {
|
||||
"total_frames": len(self.collected_data),
|
||||
"collection_rate": self.collection_rate,
|
||||
"is_collecting": self.is_collecting,
|
||||
"last_collection": self.collected_data[-1]["timestamp"] if self.collected_data else None
|
||||
}
|
||||
|
||||
return CSICollector()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_csi_collection_start_should_fail_initially(self, csi_collector):
|
||||
"""Test CSI collection start - should fail initially."""
|
||||
router_interfaces = [
|
||||
MockRouterInterface("router_001", "192.168.1.100"),
|
||||
MockRouterInterface("router_002", "192.168.1.101")
|
||||
]
|
||||
|
||||
result = await csi_collector.start_collection(router_interfaces)
|
||||
|
||||
# This will fail initially because the collector is designed to return False
|
||||
assert result is False
|
||||
assert csi_collector.is_collecting is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_frame_collection_should_fail_initially(self, csi_collector):
|
||||
"""Test single frame collection - should fail initially."""
|
||||
router = MockRouterInterface("router_001", "192.168.1.100")
|
||||
|
||||
# Without CSI streaming enabled
|
||||
frame = await csi_collector.collect_frame(router)
|
||||
|
||||
# This will fail initially
|
||||
assert frame is None
|
||||
|
||||
# Enable CSI streaming (manually for testing)
|
||||
router.csi_streaming = True
|
||||
frame = await csi_collector.collect_frame(router)
|
||||
|
||||
assert frame is not None
|
||||
assert "timestamp" in frame
|
||||
assert "router_id" in frame
|
||||
assert "amplitude" in frame
|
||||
assert "phase" in frame
|
||||
assert frame["router_id"] == "router_001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collection_statistics_should_fail_initially(self, csi_collector):
|
||||
"""Test collection statistics - should fail initially."""
|
||||
stats = csi_collector.get_collection_stats()
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(stats, dict)
|
||||
assert "total_frames" in stats
|
||||
assert "collection_rate" in stats
|
||||
assert "is_collecting" in stats
|
||||
assert "last_collection" in stats
|
||||
|
||||
assert stats["total_frames"] == 0
|
||||
assert stats["is_collecting"] is False
|
||||
assert stats["last_collection"] is None
|
||||
|
||||
|
||||
class TestHardwareErrorHandling:
|
||||
"""Test hardware error handling scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def unreliable_router(self):
|
||||
"""Create unreliable router for error testing."""
|
||||
class UnreliableRouter(MockRouterInterface):
|
||||
def __init__(self, router_id: str, ip_address: str = "192.168.1.1"):
|
||||
super().__init__(router_id, ip_address)
|
||||
self.failure_rate = 0.3 # 30% failure rate
|
||||
self.connection_drops = 0
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Unreliable connection."""
|
||||
if np.random.random() < self.failure_rate:
|
||||
return False
|
||||
return await super().connect()
|
||||
|
||||
async def send_heartbeat(self) -> bool:
|
||||
"""Unreliable heartbeat."""
|
||||
if np.random.random() < self.failure_rate:
|
||||
self.is_connected = False
|
||||
self.connection_drops += 1
|
||||
return False
|
||||
return await super().send_heartbeat()
|
||||
|
||||
async def start_csi_streaming(self, config: Dict[str, Any]) -> bool:
|
||||
"""Unreliable CSI streaming."""
|
||||
if np.random.random() < self.failure_rate:
|
||||
return False
|
||||
|
||||
# Still return False for initial test failure
|
||||
return False
|
||||
|
||||
return UnreliableRouter("unreliable_router", "192.168.1.200")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_retry_mechanism_should_fail_initially(self, unreliable_router):
|
||||
"""Test connection retry mechanism - should fail initially."""
|
||||
max_retries = 5
|
||||
success = False
|
||||
|
||||
for attempt in range(max_retries):
|
||||
result = await unreliable_router.connect()
|
||||
if result:
|
||||
success = True
|
||||
break
|
||||
|
||||
# Wait before retry
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# This will fail initially due to randomness, but should eventually pass
|
||||
# The test demonstrates the need for retry logic
|
||||
assert success or unreliable_router.connection_attempts >= max_retries
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_drop_detection_should_fail_initially(self, unreliable_router):
|
||||
"""Test connection drop detection - should fail initially."""
|
||||
# Establish connection
|
||||
await unreliable_router.connect()
|
||||
await unreliable_router.connect() # Ensure connection
|
||||
|
||||
initial_drops = unreliable_router.connection_drops
|
||||
|
||||
# Send multiple heartbeats to trigger potential drops
|
||||
for _ in range(10):
|
||||
await unreliable_router.send_heartbeat()
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# This will fail initially
|
||||
# Should detect connection drops
|
||||
final_drops = unreliable_router.connection_drops
|
||||
assert final_drops >= initial_drops # May have detected drops
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hardware_timeout_handling_should_fail_initially(self):
|
||||
"""Test hardware timeout handling - should fail initially."""
|
||||
async def slow_operation():
|
||||
"""Simulate slow hardware operation."""
|
||||
await asyncio.sleep(2.0) # 2 second delay
|
||||
return "success"
|
||||
|
||||
# Test with timeout
|
||||
try:
|
||||
result = await asyncio.wait_for(slow_operation(), timeout=1.0)
|
||||
# This should not be reached
|
||||
assert False, "Operation should have timed out"
|
||||
except asyncio.TimeoutError:
|
||||
# This will fail initially because we expect timeout handling
|
||||
assert True # Timeout was properly handled
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_error_simulation_should_fail_initially(self):
|
||||
"""Test network error simulation - should fail initially."""
|
||||
class NetworkErrorRouter(MockRouterInterface):
|
||||
async def connect(self) -> bool:
|
||||
"""Simulate network error."""
|
||||
raise ConnectionError("Network unreachable")
|
||||
|
||||
router = NetworkErrorRouter("error_router", "192.168.1.999")
|
||||
|
||||
# This will fail initially
|
||||
with pytest.raises(ConnectionError, match="Network unreachable"):
|
||||
await router.connect()
|
||||
|
||||
|
||||
class TestHardwareConfiguration:
|
||||
"""Test hardware configuration management."""
|
||||
|
||||
@pytest.fixture
|
||||
def config_manager(self):
|
||||
"""Create configuration manager."""
|
||||
class ConfigManager:
|
||||
def __init__(self):
|
||||
self.default_config = {
|
||||
"frequency": 5.8e9,
|
||||
"bandwidth": 80e6,
|
||||
"sample_rate": 1000,
|
||||
"antenna_config": "4x4_mimo",
|
||||
"power_level": 20,
|
||||
"channel": 36
|
||||
}
|
||||
self.router_configs = {}
|
||||
|
||||
def get_router_config(self, router_id: str) -> Dict[str, Any]:
|
||||
"""Get configuration for a specific router."""
|
||||
return self.router_configs.get(router_id, self.default_config.copy())
|
||||
|
||||
def set_router_config(self, router_id: str, config: Dict[str, Any]) -> bool:
|
||||
"""Set configuration for a specific router."""
|
||||
# Validate configuration
|
||||
required_fields = ["frequency", "bandwidth", "sample_rate"]
|
||||
if not all(field in config for field in required_fields):
|
||||
return False
|
||||
|
||||
self.router_configs[router_id] = config
|
||||
return True
|
||||
|
||||
def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate router configuration."""
|
||||
errors = []
|
||||
|
||||
# Frequency validation
|
||||
if "frequency" in config:
|
||||
freq = config["frequency"]
|
||||
if not (2.4e9 <= freq <= 6e9):
|
||||
errors.append("Frequency must be between 2.4GHz and 6GHz")
|
||||
|
||||
# Bandwidth validation
|
||||
if "bandwidth" in config:
|
||||
bw = config["bandwidth"]
|
||||
if bw not in [20e6, 40e6, 80e6, 160e6]:
|
||||
errors.append("Bandwidth must be 20, 40, 80, or 160 MHz")
|
||||
|
||||
# Sample rate validation
|
||||
if "sample_rate" in config:
|
||||
sr = config["sample_rate"]
|
||||
if not (100 <= sr <= 10000):
|
||||
errors.append("Sample rate must be between 100 and 10000 Hz")
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
|
||||
return ConfigManager()
|
||||
|
||||
def test_default_configuration_should_fail_initially(self, config_manager):
|
||||
"""Test default configuration retrieval - should fail initially."""
|
||||
config = config_manager.get_router_config("new_router")
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(config, dict)
|
||||
assert "frequency" in config
|
||||
assert "bandwidth" in config
|
||||
assert "sample_rate" in config
|
||||
assert "antenna_config" in config
|
||||
assert config["frequency"] == 5.8e9
|
||||
assert config["bandwidth"] == 80e6
|
||||
|
||||
def test_configuration_validation_should_fail_initially(self, config_manager):
|
||||
"""Test configuration validation - should fail initially."""
|
||||
# Valid configuration
|
||||
valid_config = {
|
||||
"frequency": 5.8e9,
|
||||
"bandwidth": 80e6,
|
||||
"sample_rate": 1000
|
||||
}
|
||||
|
||||
result = config_manager.validate_config(valid_config)
|
||||
|
||||
# This will fail initially
|
||||
assert result["valid"] is True
|
||||
assert len(result["errors"]) == 0
|
||||
|
||||
# Invalid configuration
|
||||
invalid_config = {
|
||||
"frequency": 10e9, # Too high
|
||||
"bandwidth": 100e6, # Invalid
|
||||
"sample_rate": 50 # Too low
|
||||
}
|
||||
|
||||
result = config_manager.validate_config(invalid_config)
|
||||
assert result["valid"] is False
|
||||
assert len(result["errors"]) == 3
|
||||
|
||||
def test_router_specific_configuration_should_fail_initially(self, config_manager):
|
||||
"""Test router-specific configuration - should fail initially."""
|
||||
router_id = "router_001"
|
||||
custom_config = {
|
||||
"frequency": 2.4e9,
|
||||
"bandwidth": 40e6,
|
||||
"sample_rate": 500,
|
||||
"antenna_config": "2x2_mimo"
|
||||
}
|
||||
|
||||
# Set custom configuration
|
||||
result = config_manager.set_router_config(router_id, custom_config)
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
|
||||
# Retrieve custom configuration
|
||||
retrieved_config = config_manager.get_router_config(router_id)
|
||||
assert retrieved_config["frequency"] == 2.4e9
|
||||
assert retrieved_config["bandwidth"] == 40e6
|
||||
assert retrieved_config["antenna_config"] == "2x2_mimo"
|
||||
|
||||
# Test invalid configuration
|
||||
invalid_config = {"frequency": 5.8e9} # Missing required fields
|
||||
result = config_manager.set_router_config(router_id, invalid_config)
|
||||
assert result is False
|
||||
577
tests/integration/test_pose_pipeline.py
Normal file
577
tests/integration/test_pose_pipeline.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Integration tests for end-to-end pose estimation pipeline.
|
||||
|
||||
Tests the complete pose estimation workflow from CSI data to pose results.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CSIData:
|
||||
"""CSI data structure for testing."""
|
||||
timestamp: datetime
|
||||
router_id: str
|
||||
amplitude: np.ndarray
|
||||
phase: np.ndarray
|
||||
frequency: float
|
||||
bandwidth: float
|
||||
antenna_count: int
|
||||
subcarrier_count: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoseResult:
|
||||
"""Pose estimation result structure."""
|
||||
timestamp: datetime
|
||||
frame_id: str
|
||||
persons: List[Dict[str, Any]]
|
||||
zone_summary: Dict[str, int]
|
||||
processing_time_ms: float
|
||||
confidence_scores: List[float]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class MockCSIProcessor:
|
||||
"""Mock CSI data processor."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_initialized = False
|
||||
self.processing_enabled = True
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the processor."""
|
||||
self.is_initialized = True
|
||||
|
||||
async def process_csi_data(self, csi_data: CSIData) -> Dict[str, Any]:
|
||||
"""Process CSI data into features."""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("Processor not initialized")
|
||||
|
||||
if not self.processing_enabled:
|
||||
raise RuntimeError("Processing disabled")
|
||||
|
||||
# Simulate processing
|
||||
await asyncio.sleep(0.01) # Simulate processing time
|
||||
|
||||
return {
|
||||
"features": np.random.rand(64, 32).tolist(), # Mock feature matrix
|
||||
"quality_score": 0.85,
|
||||
"signal_strength": -45.2,
|
||||
"noise_level": -78.1,
|
||||
"processed_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
def set_processing_enabled(self, enabled: bool):
|
||||
"""Enable/disable processing."""
|
||||
self.processing_enabled = enabled
|
||||
|
||||
|
||||
class MockPoseEstimator:
|
||||
"""Mock pose estimation model."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_loaded = False
|
||||
self.model_version = "1.0.0"
|
||||
self.confidence_threshold = 0.5
|
||||
|
||||
async def load_model(self):
|
||||
"""Load the pose estimation model."""
|
||||
await asyncio.sleep(0.1) # Simulate model loading
|
||||
self.is_loaded = True
|
||||
|
||||
async def estimate_poses(self, features: np.ndarray) -> Dict[str, Any]:
|
||||
"""Estimate poses from features."""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
# Simulate pose estimation
|
||||
await asyncio.sleep(0.05) # Simulate inference time
|
||||
|
||||
# Generate mock pose data
|
||||
num_persons = np.random.randint(0, 4) # 0-3 persons
|
||||
persons = []
|
||||
|
||||
for i in range(num_persons):
|
||||
confidence = np.random.uniform(0.3, 0.95)
|
||||
if confidence >= self.confidence_threshold:
|
||||
persons.append({
|
||||
"person_id": f"person_{i}",
|
||||
"confidence": confidence,
|
||||
"bounding_box": {
|
||||
"x": np.random.uniform(0, 800),
|
||||
"y": np.random.uniform(0, 600),
|
||||
"width": np.random.uniform(50, 200),
|
||||
"height": np.random.uniform(100, 400)
|
||||
},
|
||||
"keypoints": [
|
||||
{
|
||||
"name": "head",
|
||||
"x": np.random.uniform(0, 800),
|
||||
"y": np.random.uniform(0, 200),
|
||||
"confidence": np.random.uniform(0.5, 0.95)
|
||||
},
|
||||
{
|
||||
"name": "torso",
|
||||
"x": np.random.uniform(0, 800),
|
||||
"y": np.random.uniform(200, 400),
|
||||
"confidence": np.random.uniform(0.5, 0.95)
|
||||
}
|
||||
],
|
||||
"activity": "standing" if np.random.random() > 0.2 else "sitting"
|
||||
})
|
||||
|
||||
return {
|
||||
"persons": persons,
|
||||
"processing_time_ms": np.random.uniform(20, 80),
|
||||
"model_version": self.model_version,
|
||||
"confidence_threshold": self.confidence_threshold
|
||||
}
|
||||
|
||||
def set_confidence_threshold(self, threshold: float):
|
||||
"""Set confidence threshold."""
|
||||
self.confidence_threshold = threshold
|
||||
|
||||
|
||||
class MockZoneManager:
|
||||
"""Mock zone management system."""
|
||||
|
||||
def __init__(self):
|
||||
self.zones = {
|
||||
"zone1": {"id": "zone1", "name": "Zone 1", "bounds": [0, 0, 400, 600]},
|
||||
"zone2": {"id": "zone2", "name": "Zone 2", "bounds": [400, 0, 800, 600]},
|
||||
"zone3": {"id": "zone3", "name": "Zone 3", "bounds": [0, 300, 800, 600]}
|
||||
}
|
||||
|
||||
def assign_persons_to_zones(self, persons: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Assign detected persons to zones."""
|
||||
zone_summary = {zone_id: 0 for zone_id in self.zones.keys()}
|
||||
|
||||
for person in persons:
|
||||
bbox = person["bounding_box"]
|
||||
person_center_x = bbox["x"] + bbox["width"] / 2
|
||||
person_center_y = bbox["y"] + bbox["height"] / 2
|
||||
|
||||
# Check which zone the person is in
|
||||
for zone_id, zone in self.zones.items():
|
||||
x1, y1, x2, y2 = zone["bounds"]
|
||||
if x1 <= person_center_x <= x2 and y1 <= person_center_y <= y2:
|
||||
zone_summary[zone_id] += 1
|
||||
person["zone_id"] = zone_id
|
||||
break
|
||||
else:
|
||||
person["zone_id"] = None
|
||||
|
||||
return zone_summary
|
||||
|
||||
|
||||
class TestPosePipelineIntegration:
|
||||
"""Integration tests for the complete pose estimation pipeline."""
|
||||
|
||||
@pytest.fixture
|
||||
def csi_processor(self):
|
||||
"""Create CSI processor."""
|
||||
return MockCSIProcessor()
|
||||
|
||||
@pytest.fixture
|
||||
def pose_estimator(self):
|
||||
"""Create pose estimator."""
|
||||
return MockPoseEstimator()
|
||||
|
||||
@pytest.fixture
|
||||
def zone_manager(self):
|
||||
"""Create zone manager."""
|
||||
return MockZoneManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Create sample CSI data."""
|
||||
return CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9, # 5.8 GHz
|
||||
bandwidth=80e6, # 80 MHz
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
async def pose_pipeline(self, csi_processor, pose_estimator, zone_manager):
|
||||
"""Create complete pose pipeline."""
|
||||
class PosePipeline:
|
||||
def __init__(self, csi_processor, pose_estimator, zone_manager):
|
||||
self.csi_processor = csi_processor
|
||||
self.pose_estimator = pose_estimator
|
||||
self.zone_manager = zone_manager
|
||||
self.is_initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the pipeline."""
|
||||
await self.csi_processor.initialize()
|
||||
await self.pose_estimator.load_model()
|
||||
self.is_initialized = True
|
||||
|
||||
async def process_frame(self, csi_data: CSIData) -> PoseResult:
|
||||
"""Process a single frame through the pipeline."""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("Pipeline not initialized")
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Step 1: Process CSI data
|
||||
processed_data = await self.csi_processor.process_csi_data(csi_data)
|
||||
|
||||
# Step 2: Extract features
|
||||
features = np.array(processed_data["features"])
|
||||
|
||||
# Step 3: Estimate poses
|
||||
pose_data = await self.pose_estimator.estimate_poses(features)
|
||||
|
||||
# Step 4: Assign to zones
|
||||
zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.utcnow()
|
||||
processing_time = (end_time - start_time).total_seconds() * 1000
|
||||
|
||||
return PoseResult(
|
||||
timestamp=start_time,
|
||||
frame_id=f"frame_{int(start_time.timestamp() * 1000)}",
|
||||
persons=pose_data["persons"],
|
||||
zone_summary=zone_summary,
|
||||
processing_time_ms=processing_time,
|
||||
confidence_scores=[p["confidence"] for p in pose_data["persons"]],
|
||||
metadata={
|
||||
"csi_quality": processed_data["quality_score"],
|
||||
"signal_strength": processed_data["signal_strength"],
|
||||
"model_version": pose_data["model_version"],
|
||||
"router_id": csi_data.router_id
|
||||
}
|
||||
)
|
||||
|
||||
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
|
||||
await pipeline.initialize()
|
||||
return pipeline
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_initialization_should_fail_initially(self, csi_processor, pose_estimator, zone_manager):
|
||||
"""Test pipeline initialization - should fail initially."""
|
||||
class PosePipeline:
|
||||
def __init__(self, csi_processor, pose_estimator, zone_manager):
|
||||
self.csi_processor = csi_processor
|
||||
self.pose_estimator = pose_estimator
|
||||
self.zone_manager = zone_manager
|
||||
self.is_initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
await self.csi_processor.initialize()
|
||||
await self.pose_estimator.load_model()
|
||||
self.is_initialized = True
|
||||
|
||||
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
|
||||
|
||||
# Initially not initialized
|
||||
assert not pipeline.is_initialized
|
||||
assert not csi_processor.is_initialized
|
||||
assert not pose_estimator.is_loaded
|
||||
|
||||
# Initialize pipeline
|
||||
await pipeline.initialize()
|
||||
|
||||
# This will fail initially
|
||||
assert pipeline.is_initialized
|
||||
assert csi_processor.is_initialized
|
||||
assert pose_estimator.is_loaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_pose_estimation_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test end-to-end pose estimation - should fail initially."""
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(result, PoseResult)
|
||||
assert result.timestamp is not None
|
||||
assert result.frame_id.startswith("frame_")
|
||||
assert isinstance(result.persons, list)
|
||||
assert isinstance(result.zone_summary, dict)
|
||||
assert result.processing_time_ms > 0
|
||||
assert isinstance(result.confidence_scores, list)
|
||||
assert isinstance(result.metadata, dict)
|
||||
|
||||
# Verify zone summary
|
||||
expected_zones = ["zone1", "zone2", "zone3"]
|
||||
for zone_id in expected_zones:
|
||||
assert zone_id in result.zone_summary
|
||||
assert isinstance(result.zone_summary[zone_id], int)
|
||||
assert result.zone_summary[zone_id] >= 0
|
||||
|
||||
# Verify metadata
|
||||
assert "csi_quality" in result.metadata
|
||||
assert "signal_strength" in result.metadata
|
||||
assert "model_version" in result.metadata
|
||||
assert "router_id" in result.metadata
|
||||
assert result.metadata["router_id"] == sample_csi_data.router_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_with_multiple_frames_should_fail_initially(self, pose_pipeline):
|
||||
"""Test pipeline with multiple frames - should fail initially."""
|
||||
results = []
|
||||
|
||||
# Process multiple frames
|
||||
for i in range(5):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id=f"router_{i % 2 + 1:03d}", # Alternate between router_001 and router_002
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
result = await pose_pipeline.process_frame(csi_data)
|
||||
results.append(result)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 5
|
||||
|
||||
# Verify each result
|
||||
for i, result in enumerate(results):
|
||||
assert result.frame_id != results[0].frame_id if i > 0 else True
|
||||
assert result.metadata["router_id"] in ["router_001", "router_002"]
|
||||
assert result.processing_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_error_handling_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
|
||||
"""Test pipeline error handling - should fail initially."""
|
||||
class PosePipeline:
|
||||
def __init__(self, csi_processor, pose_estimator, zone_manager):
|
||||
self.csi_processor = csi_processor
|
||||
self.pose_estimator = pose_estimator
|
||||
self.zone_manager = zone_manager
|
||||
self.is_initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
await self.csi_processor.initialize()
|
||||
await self.pose_estimator.load_model()
|
||||
self.is_initialized = True
|
||||
|
||||
async def process_frame(self, csi_data):
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("Pipeline not initialized")
|
||||
|
||||
processed_data = await self.csi_processor.process_csi_data(csi_data)
|
||||
features = np.array(processed_data["features"])
|
||||
pose_data = await self.pose_estimator.estimate_poses(features)
|
||||
|
||||
return pose_data
|
||||
|
||||
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
|
||||
|
||||
# Test uninitialized pipeline
|
||||
with pytest.raises(RuntimeError, match="Pipeline not initialized"):
|
||||
await pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# Initialize pipeline
|
||||
await pipeline.initialize()
|
||||
|
||||
# Test with disabled CSI processor
|
||||
csi_processor.set_processing_enabled(False)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Processing disabled"):
|
||||
await pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# This assertion will fail initially
|
||||
assert True # Test completed successfully
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confidence_threshold_filtering_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test confidence threshold filtering - should fail initially."""
|
||||
# Set high confidence threshold
|
||||
pose_pipeline.pose_estimator.set_confidence_threshold(0.9)
|
||||
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# This will fail initially
|
||||
# With high threshold, fewer persons should be detected
|
||||
high_confidence_count = len(result.persons)
|
||||
|
||||
# Set low confidence threshold
|
||||
pose_pipeline.pose_estimator.set_confidence_threshold(0.1)
|
||||
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
low_confidence_count = len(result.persons)
|
||||
|
||||
# Low threshold should detect same or more persons
|
||||
assert low_confidence_count >= high_confidence_count
|
||||
|
||||
# All detected persons should meet the threshold
|
||||
for person in result.persons:
|
||||
assert person["confidence"] >= 0.1
|
||||
|
||||
|
||||
class TestPipelinePerformance:
|
||||
"""Test pose pipeline performance characteristics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_throughput_should_fail_initially(self, pose_pipeline):
|
||||
"""Test pipeline throughput - should fail initially."""
|
||||
frame_count = 10
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Process multiple frames
|
||||
for i in range(frame_count):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
await pose_pipeline.process_frame(csi_data)
|
||||
|
||||
end_time = datetime.utcnow()
|
||||
total_time = (end_time - start_time).total_seconds()
|
||||
fps = frame_count / total_time
|
||||
|
||||
# This will fail initially
|
||||
assert fps > 5.0 # Should process at least 5 FPS
|
||||
assert total_time < 5.0 # Should complete 10 frames in under 5 seconds
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_frame_processing_should_fail_initially(self, pose_pipeline):
|
||||
"""Test concurrent frame processing - should fail initially."""
|
||||
async def process_single_frame(frame_id: int):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id=f"router_{frame_id % 3 + 1:03d}",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
result = await pose_pipeline.process_frame(csi_data)
|
||||
return result.frame_id
|
||||
|
||||
# Process frames concurrently
|
||||
tasks = [process_single_frame(i) for i in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 5
|
||||
assert len(set(results)) == 5 # All frame IDs should be unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_usage_stability_should_fail_initially(self, pose_pipeline):
|
||||
"""Test memory usage stability - should fail initially."""
|
||||
import psutil
|
||||
import os
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
initial_memory = process.memory_info().rss
|
||||
|
||||
# Process many frames
|
||||
for i in range(50):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
await pose_pipeline.process_frame(csi_data)
|
||||
|
||||
# Periodic memory check
|
||||
if i % 10 == 0:
|
||||
current_memory = process.memory_info().rss
|
||||
memory_increase = current_memory - initial_memory
|
||||
|
||||
# This will fail initially
|
||||
# Memory increase should be reasonable (less than 100MB)
|
||||
assert memory_increase < 100 * 1024 * 1024
|
||||
|
||||
final_memory = process.memory_info().rss
|
||||
total_increase = final_memory - initial_memory
|
||||
|
||||
# Total memory increase should be reasonable
|
||||
assert total_increase < 200 * 1024 * 1024 # Less than 200MB increase
|
||||
|
||||
|
||||
class TestPipelineDataFlow:
|
||||
"""Test data flow through the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_transformation_chain_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
|
||||
"""Test data transformation through the pipeline - should fail initially."""
|
||||
# Step 1: CSI processing
|
||||
await csi_processor.initialize()
|
||||
processed_data = await csi_processor.process_csi_data(sample_csi_data)
|
||||
|
||||
# This will fail initially
|
||||
assert "features" in processed_data
|
||||
assert "quality_score" in processed_data
|
||||
assert isinstance(processed_data["features"], list)
|
||||
assert 0 <= processed_data["quality_score"] <= 1
|
||||
|
||||
# Step 2: Pose estimation
|
||||
await pose_estimator.load_model()
|
||||
features = np.array(processed_data["features"])
|
||||
pose_data = await pose_estimator.estimate_poses(features)
|
||||
|
||||
assert "persons" in pose_data
|
||||
assert "processing_time_ms" in pose_data
|
||||
assert isinstance(pose_data["persons"], list)
|
||||
|
||||
# Step 3: Zone assignment
|
||||
zone_summary = zone_manager.assign_persons_to_zones(pose_data["persons"])
|
||||
|
||||
assert isinstance(zone_summary, dict)
|
||||
assert all(isinstance(count, int) for count in zone_summary.values())
|
||||
|
||||
# Verify person zone assignments
|
||||
for person in pose_data["persons"]:
|
||||
if "zone_id" in person and person["zone_id"]:
|
||||
assert person["zone_id"] in zone_summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_state_consistency_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test pipeline state consistency - should fail initially."""
|
||||
# Process the same frame multiple times
|
||||
results = []
|
||||
for _ in range(3):
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
results.append(result)
|
||||
|
||||
# This will fail initially
|
||||
# Results should be consistent (same input should produce similar output)
|
||||
assert len(results) == 3
|
||||
|
||||
# All results should have the same router_id
|
||||
router_ids = [r.metadata["router_id"] for r in results]
|
||||
assert all(rid == router_ids[0] for rid in router_ids)
|
||||
|
||||
# Processing times should be reasonable and similar
|
||||
processing_times = [r.processing_time_ms for r in results]
|
||||
assert all(10 <= pt <= 200 for pt in processing_times) # Between 10ms and 200ms
|
||||
565
tests/integration/test_rate_limiting.py
Normal file
565
tests/integration/test_rate_limiting.py
Normal file
@@ -0,0 +1,565 @@
|
||||
"""
|
||||
Integration tests for rate limiting functionality.
|
||||
|
||||
Tests rate limit behavior, throttling, and quota management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import time
|
||||
|
||||
from fastapi import HTTPException, status, Request, Response
|
||||
|
||||
|
||||
class MockRateLimiter:
|
||||
"""Mock rate limiter for testing."""
|
||||
|
||||
def __init__(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.requests_per_hour = requests_per_hour
|
||||
self.request_history = {}
|
||||
self.blocked_clients = set()
|
||||
|
||||
def _get_client_key(self, client_id: str, endpoint: str = None) -> str:
|
||||
"""Get client key for rate limiting."""
|
||||
return f"{client_id}:{endpoint}" if endpoint else client_id
|
||||
|
||||
def _cleanup_old_requests(self, client_key: str):
|
||||
"""Clean up old request records."""
|
||||
if client_key not in self.request_history:
|
||||
return
|
||||
|
||||
now = datetime.utcnow()
|
||||
minute_ago = now - timedelta(minutes=1)
|
||||
hour_ago = now - timedelta(hours=1)
|
||||
|
||||
# Keep only requests from the last hour
|
||||
self.request_history[client_key] = [
|
||||
req_time for req_time in self.request_history[client_key]
|
||||
if req_time > hour_ago
|
||||
]
|
||||
|
||||
def check_rate_limit(self, client_id: str, endpoint: str = None) -> Dict[str, Any]:
|
||||
"""Check if client is within rate limits."""
|
||||
client_key = self._get_client_key(client_id, endpoint)
|
||||
|
||||
if client_id in self.blocked_clients:
|
||||
return {
|
||||
"allowed": False,
|
||||
"reason": "Client blocked",
|
||||
"retry_after": 3600 # 1 hour
|
||||
}
|
||||
|
||||
self._cleanup_old_requests(client_key)
|
||||
|
||||
if client_key not in self.request_history:
|
||||
self.request_history[client_key] = []
|
||||
|
||||
now = datetime.utcnow()
|
||||
minute_ago = now - timedelta(minutes=1)
|
||||
|
||||
# Count requests in the last minute
|
||||
recent_requests = [
|
||||
req_time for req_time in self.request_history[client_key]
|
||||
if req_time > minute_ago
|
||||
]
|
||||
|
||||
# Count requests in the last hour
|
||||
hour_requests = len(self.request_history[client_key])
|
||||
|
||||
if len(recent_requests) >= self.requests_per_minute:
|
||||
return {
|
||||
"allowed": False,
|
||||
"reason": "Rate limit exceeded (per minute)",
|
||||
"retry_after": 60,
|
||||
"current_requests": len(recent_requests),
|
||||
"limit": self.requests_per_minute
|
||||
}
|
||||
|
||||
if hour_requests >= self.requests_per_hour:
|
||||
return {
|
||||
"allowed": False,
|
||||
"reason": "Rate limit exceeded (per hour)",
|
||||
"retry_after": 3600,
|
||||
"current_requests": hour_requests,
|
||||
"limit": self.requests_per_hour
|
||||
}
|
||||
|
||||
# Record this request
|
||||
self.request_history[client_key].append(now)
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"remaining_minute": self.requests_per_minute - len(recent_requests) - 1,
|
||||
"remaining_hour": self.requests_per_hour - hour_requests - 1,
|
||||
"reset_time": minute_ago + timedelta(minutes=1)
|
||||
}
|
||||
|
||||
def block_client(self, client_id: str):
|
||||
"""Block a client."""
|
||||
self.blocked_clients.add(client_id)
|
||||
|
||||
def unblock_client(self, client_id: str):
|
||||
"""Unblock a client."""
|
||||
self.blocked_clients.discard(client_id)
|
||||
|
||||
|
||||
class TestRateLimitingBasic:
|
||||
"""Test basic rate limiting functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limiter(self):
|
||||
"""Create rate limiter for testing."""
|
||||
return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)
|
||||
|
||||
def test_rate_limit_within_bounds_should_fail_initially(self, rate_limiter):
|
||||
"""Test rate limiting within bounds - should fail initially."""
|
||||
client_id = "test-client-001"
|
||||
|
||||
# Make requests within limit
|
||||
for i in range(3):
|
||||
result = rate_limiter.check_rate_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is True
|
||||
assert "remaining_minute" in result
|
||||
assert "remaining_hour" in result
|
||||
|
||||
def test_rate_limit_per_minute_exceeded_should_fail_initially(self, rate_limiter):
|
||||
"""Test per-minute rate limit exceeded - should fail initially."""
|
||||
client_id = "test-client-002"
|
||||
|
||||
# Make requests up to the limit
|
||||
for i in range(5):
|
||||
result = rate_limiter.check_rate_limit(client_id)
|
||||
assert result["allowed"] is True
|
||||
|
||||
# Next request should be blocked
|
||||
result = rate_limiter.check_rate_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is False
|
||||
assert "per minute" in result["reason"]
|
||||
assert result["retry_after"] == 60
|
||||
assert result["current_requests"] == 5
|
||||
assert result["limit"] == 5
|
||||
|
||||
def test_rate_limit_per_hour_exceeded_should_fail_initially(self, rate_limiter):
|
||||
"""Test per-hour rate limit exceeded - should fail initially."""
|
||||
# Create rate limiter with very low hour limit for testing
|
||||
limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=3)
|
||||
client_id = "test-client-003"
|
||||
|
||||
# Make requests up to hour limit
|
||||
for i in range(3):
|
||||
result = limiter.check_rate_limit(client_id)
|
||||
assert result["allowed"] is True
|
||||
|
||||
# Next request should be blocked
|
||||
result = limiter.check_rate_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is False
|
||||
assert "per hour" in result["reason"]
|
||||
assert result["retry_after"] == 3600
|
||||
|
||||
def test_blocked_client_should_fail_initially(self, rate_limiter):
|
||||
"""Test blocked client handling - should fail initially."""
|
||||
client_id = "blocked-client"
|
||||
|
||||
# Block the client
|
||||
rate_limiter.block_client(client_id)
|
||||
|
||||
# Request should be blocked
|
||||
result = rate_limiter.check_rate_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is False
|
||||
assert result["reason"] == "Client blocked"
|
||||
assert result["retry_after"] == 3600
|
||||
|
||||
# Unblock and test
|
||||
rate_limiter.unblock_client(client_id)
|
||||
result = rate_limiter.check_rate_limit(client_id)
|
||||
assert result["allowed"] is True
|
||||
|
||||
def test_endpoint_specific_rate_limiting_should_fail_initially(self, rate_limiter):
|
||||
"""Test endpoint-specific rate limiting - should fail initially."""
|
||||
client_id = "test-client-004"
|
||||
|
||||
# Make requests to different endpoints
|
||||
result1 = rate_limiter.check_rate_limit(client_id, "/api/pose/current")
|
||||
result2 = rate_limiter.check_rate_limit(client_id, "/api/stream/status")
|
||||
|
||||
# This will fail initially
|
||||
assert result1["allowed"] is True
|
||||
assert result2["allowed"] is True
|
||||
|
||||
# Each endpoint should have separate rate limiting
|
||||
for i in range(4):
|
||||
rate_limiter.check_rate_limit(client_id, "/api/pose/current")
|
||||
|
||||
# Pose endpoint should be at limit, but stream should still work
|
||||
pose_result = rate_limiter.check_rate_limit(client_id, "/api/pose/current")
|
||||
stream_result = rate_limiter.check_rate_limit(client_id, "/api/stream/status")
|
||||
|
||||
assert pose_result["allowed"] is False
|
||||
assert stream_result["allowed"] is True
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Test rate limiting middleware functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request(self):
|
||||
"""Mock FastAPI request."""
|
||||
class MockRequest:
|
||||
def __init__(self, client_ip="127.0.0.1", path="/api/test", method="GET"):
|
||||
self.client = MagicMock()
|
||||
self.client.host = client_ip
|
||||
self.url = MagicMock()
|
||||
self.url.path = path
|
||||
self.method = method
|
||||
self.headers = {}
|
||||
self.state = MagicMock()
|
||||
|
||||
return MockRequest
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Mock FastAPI response."""
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.status_code = 200
|
||||
self.headers = {}
|
||||
|
||||
return MockResponse()
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limit_middleware(self, rate_limiter):
|
||||
"""Create rate limiting middleware."""
|
||||
class RateLimitMiddleware:
|
||||
def __init__(self, rate_limiter):
|
||||
self.rate_limiter = rate_limiter
|
||||
|
||||
async def __call__(self, request, call_next):
|
||||
# Get client identifier
|
||||
client_id = self._get_client_id(request)
|
||||
endpoint = request.url.path
|
||||
|
||||
# Check rate limit
|
||||
limit_result = self.rate_limiter.check_rate_limit(client_id, endpoint)
|
||||
|
||||
if not limit_result["allowed"]:
|
||||
# Return rate limit exceeded response
|
||||
response = Response(
|
||||
content=f"Rate limit exceeded: {limit_result['reason']}",
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS
|
||||
)
|
||||
response.headers["Retry-After"] = str(limit_result["retry_after"])
|
||||
response.headers["X-RateLimit-Limit"] = str(limit_result.get("limit", "unknown"))
|
||||
response.headers["X-RateLimit-Remaining"] = "0"
|
||||
return response
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers
|
||||
response.headers["X-RateLimit-Limit"] = str(self.rate_limiter.requests_per_minute)
|
||||
response.headers["X-RateLimit-Remaining"] = str(limit_result.get("remaining_minute", 0))
|
||||
response.headers["X-RateLimit-Reset"] = str(int(limit_result.get("reset_time", datetime.utcnow()).timestamp()))
|
||||
|
||||
return response
|
||||
|
||||
def _get_client_id(self, request):
|
||||
"""Get client identifier from request."""
|
||||
# Check for API key in headers
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
if api_key:
|
||||
return f"api:{api_key}"
|
||||
|
||||
# Check for user ID in request state (from auth)
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
return f"user:{request.state.user.get('id', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
return f"ip:{request.client.host}"
|
||||
|
||||
return RateLimitMiddleware(rate_limiter)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_allows_normal_requests_should_fail_initially(
|
||||
self, rate_limit_middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware allows normal requests - should fail initially."""
|
||||
request = mock_request()
|
||||
|
||||
async def mock_call_next(req):
|
||||
return mock_response
|
||||
|
||||
response = await rate_limit_middleware(request, mock_call_next)
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == 200
|
||||
assert "X-RateLimit-Limit" in response.headers
|
||||
assert "X-RateLimit-Remaining" in response.headers
|
||||
assert "X-RateLimit-Reset" in response.headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_blocks_excessive_requests_should_fail_initially(
|
||||
self, rate_limit_middleware, mock_request
|
||||
):
|
||||
"""Test middleware blocks excessive requests - should fail initially."""
|
||||
request = mock_request()
|
||||
|
||||
async def mock_call_next(req):
|
||||
response = Response(content="OK", status_code=200)
|
||||
return response
|
||||
|
||||
# Make requests up to the limit
|
||||
for i in range(5):
|
||||
response = await rate_limit_middleware(request, mock_call_next)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Next request should be blocked
|
||||
response = await rate_limit_middleware(request, mock_call_next)
|
||||
|
||||
# This will fail initially
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
assert "Retry-After" in response.headers
|
||||
assert "X-RateLimit-Remaining" in response.headers
|
||||
assert response.headers["X-RateLimit-Remaining"] == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_client_identification_should_fail_initially(
|
||||
self, rate_limit_middleware, mock_request
|
||||
):
|
||||
"""Test middleware client identification - should fail initially."""
|
||||
# Test API key identification
|
||||
request_with_api_key = mock_request()
|
||||
request_with_api_key.headers["X-API-Key"] = "test-api-key-123"
|
||||
|
||||
# Test user identification
|
||||
request_with_user = mock_request()
|
||||
request_with_user.state.user = {"id": "user-123"}
|
||||
|
||||
# Test IP identification
|
||||
request_with_ip = mock_request(client_ip="192.168.1.100")
|
||||
|
||||
async def mock_call_next(req):
|
||||
return Response(content="OK", status_code=200)
|
||||
|
||||
# Each should be treated as different clients
|
||||
response1 = await rate_limit_middleware(request_with_api_key, mock_call_next)
|
||||
response2 = await rate_limit_middleware(request_with_user, mock_call_next)
|
||||
response3 = await rate_limit_middleware(request_with_ip, mock_call_next)
|
||||
|
||||
# This will fail initially
|
||||
assert response1.status_code == 200
|
||||
assert response2.status_code == 200
|
||||
assert response3.status_code == 200
|
||||
|
||||
|
||||
class TestRateLimitingStrategies:
|
||||
"""Test different rate limiting strategies."""
|
||||
|
||||
@pytest.fixture
|
||||
def sliding_window_limiter(self):
|
||||
"""Create sliding window rate limiter."""
|
||||
class SlidingWindowLimiter:
|
||||
def __init__(self, window_size_seconds: int = 60, max_requests: int = 10):
|
||||
self.window_size = window_size_seconds
|
||||
self.max_requests = max_requests
|
||||
self.request_times = {}
|
||||
|
||||
def check_limit(self, client_id: str) -> Dict[str, Any]:
|
||||
now = time.time()
|
||||
|
||||
if client_id not in self.request_times:
|
||||
self.request_times[client_id] = []
|
||||
|
||||
# Remove old requests outside the window
|
||||
cutoff_time = now - self.window_size
|
||||
self.request_times[client_id] = [
|
||||
req_time for req_time in self.request_times[client_id]
|
||||
if req_time > cutoff_time
|
||||
]
|
||||
|
||||
# Check if we're at the limit
|
||||
if len(self.request_times[client_id]) >= self.max_requests:
|
||||
oldest_request = min(self.request_times[client_id])
|
||||
retry_after = int(oldest_request + self.window_size - now)
|
||||
|
||||
return {
|
||||
"allowed": False,
|
||||
"retry_after": max(retry_after, 1),
|
||||
"current_requests": len(self.request_times[client_id]),
|
||||
"limit": self.max_requests
|
||||
}
|
||||
|
||||
# Record this request
|
||||
self.request_times[client_id].append(now)
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"remaining": self.max_requests - len(self.request_times[client_id]),
|
||||
"window_reset": int(now + self.window_size)
|
||||
}
|
||||
|
||||
return SlidingWindowLimiter(window_size_seconds=10, max_requests=3)
|
||||
|
||||
@pytest.fixture
|
||||
def token_bucket_limiter(self):
|
||||
"""Create token bucket rate limiter."""
|
||||
class TokenBucketLimiter:
|
||||
def __init__(self, capacity: int = 10, refill_rate: float = 1.0):
|
||||
self.capacity = capacity
|
||||
self.refill_rate = refill_rate # tokens per second
|
||||
self.buckets = {}
|
||||
|
||||
def check_limit(self, client_id: str) -> Dict[str, Any]:
|
||||
now = time.time()
|
||||
|
||||
if client_id not in self.buckets:
|
||||
self.buckets[client_id] = {
|
||||
"tokens": self.capacity,
|
||||
"last_refill": now
|
||||
}
|
||||
|
||||
bucket = self.buckets[client_id]
|
||||
|
||||
# Refill tokens based on time elapsed
|
||||
time_elapsed = now - bucket["last_refill"]
|
||||
tokens_to_add = time_elapsed * self.refill_rate
|
||||
bucket["tokens"] = min(self.capacity, bucket["tokens"] + tokens_to_add)
|
||||
bucket["last_refill"] = now
|
||||
|
||||
# Check if we have tokens available
|
||||
if bucket["tokens"] < 1:
|
||||
return {
|
||||
"allowed": False,
|
||||
"retry_after": int((1 - bucket["tokens"]) / self.refill_rate),
|
||||
"tokens_remaining": bucket["tokens"]
|
||||
}
|
||||
|
||||
# Consume a token
|
||||
bucket["tokens"] -= 1
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"tokens_remaining": bucket["tokens"]
|
||||
}
|
||||
|
||||
return TokenBucketLimiter(capacity=5, refill_rate=0.5) # 0.5 tokens per second
|
||||
|
||||
def test_sliding_window_limiter_should_fail_initially(self, sliding_window_limiter):
|
||||
"""Test sliding window rate limiter - should fail initially."""
|
||||
client_id = "sliding-test-client"
|
||||
|
||||
# Make requests up to limit
|
||||
for i in range(3):
|
||||
result = sliding_window_limiter.check_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is True
|
||||
assert "remaining" in result
|
||||
|
||||
# Next request should be blocked
|
||||
result = sliding_window_limiter.check_limit(client_id)
|
||||
assert result["allowed"] is False
|
||||
assert result["current_requests"] == 3
|
||||
assert result["limit"] == 3
|
||||
|
||||
def test_token_bucket_limiter_should_fail_initially(self, token_bucket_limiter):
|
||||
"""Test token bucket rate limiter - should fail initially."""
|
||||
client_id = "bucket-test-client"
|
||||
|
||||
# Make requests up to capacity
|
||||
for i in range(5):
|
||||
result = token_bucket_limiter.check_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is True
|
||||
assert "tokens_remaining" in result
|
||||
|
||||
# Next request should be blocked (no tokens left)
|
||||
result = token_bucket_limiter.check_limit(client_id)
|
||||
assert result["allowed"] is False
|
||||
assert result["tokens_remaining"] < 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_bucket_refill_should_fail_initially(self, token_bucket_limiter):
|
||||
"""Test token bucket refill mechanism - should fail initially."""
|
||||
client_id = "refill-test-client"
|
||||
|
||||
# Exhaust all tokens
|
||||
for i in range(5):
|
||||
token_bucket_limiter.check_limit(client_id)
|
||||
|
||||
# Should be blocked
|
||||
result = token_bucket_limiter.check_limit(client_id)
|
||||
assert result["allowed"] is False
|
||||
|
||||
# Wait for refill (simulate time passing)
|
||||
await asyncio.sleep(2.1) # Wait for 1 token to be refilled (0.5 tokens/sec * 2.1 sec > 1)
|
||||
|
||||
# Should now be allowed
|
||||
result = token_bucket_limiter.check_limit(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result["allowed"] is True
|
||||
|
||||
|
||||
class TestRateLimitingPerformance:
|
||||
"""Test rate limiting performance characteristics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_rate_limit_checks_should_fail_initially(self):
|
||||
"""Test concurrent rate limit checks - should fail initially."""
|
||||
rate_limiter = MockRateLimiter(requests_per_minute=100, requests_per_hour=1000)
|
||||
|
||||
async def make_request(client_id: str, request_id: int):
|
||||
result = rate_limiter.check_rate_limit(f"{client_id}-{request_id}")
|
||||
return result["allowed"]
|
||||
|
||||
# Create many concurrent requests
|
||||
tasks = [
|
||||
make_request("concurrent-client", i)
|
||||
for i in range(50)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 50
|
||||
assert all(results) # All should be allowed since they're different clients
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limiter_memory_cleanup_should_fail_initially(self):
|
||||
"""Test rate limiter memory cleanup - should fail initially."""
|
||||
rate_limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=100)
|
||||
|
||||
# Make requests for many different clients
|
||||
for i in range(100):
|
||||
rate_limiter.check_rate_limit(f"client-{i}")
|
||||
|
||||
initial_memory_size = len(rate_limiter.request_history)
|
||||
|
||||
# Simulate time passing and cleanup
|
||||
for client_key in list(rate_limiter.request_history.keys()):
|
||||
rate_limiter._cleanup_old_requests(client_key)
|
||||
|
||||
# This will fail initially
|
||||
assert initial_memory_size == 100
|
||||
|
||||
# After cleanup, old entries should be removed
|
||||
# (In a real implementation, this would clean up old timestamps)
|
||||
final_memory_size = len([
|
||||
key for key, history in rate_limiter.request_history.items()
|
||||
if history # Only count non-empty histories
|
||||
])
|
||||
|
||||
assert final_memory_size <= initial_memory_size
|
||||
729
tests/integration/test_streaming_pipeline.py
Normal file
729
tests/integration/test_streaming_pipeline.py
Normal file
@@ -0,0 +1,729 @@
|
||||
"""
|
||||
Integration tests for real-time streaming pipeline.
|
||||
|
||||
Tests the complete real-time data flow from CSI collection to client delivery.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamFrame:
|
||||
"""Streaming frame data structure."""
|
||||
frame_id: str
|
||||
timestamp: datetime
|
||||
router_id: str
|
||||
pose_data: Dict[str, Any]
|
||||
processing_time_ms: float
|
||||
quality_score: float
|
||||
|
||||
|
||||
class MockStreamBuffer:
|
||||
"""Mock streaming buffer for testing."""
|
||||
|
||||
def __init__(self, max_size: int = 100):
|
||||
self.max_size = max_size
|
||||
self.buffer = asyncio.Queue(maxsize=max_size)
|
||||
self.dropped_frames = 0
|
||||
self.total_frames = 0
|
||||
|
||||
async def put_frame(self, frame: StreamFrame) -> bool:
|
||||
"""Add frame to buffer."""
|
||||
self.total_frames += 1
|
||||
|
||||
try:
|
||||
self.buffer.put_nowait(frame)
|
||||
return True
|
||||
except asyncio.QueueFull:
|
||||
self.dropped_frames += 1
|
||||
return False
|
||||
|
||||
async def get_frame(self, timeout: float = 1.0) -> Optional[StreamFrame]:
|
||||
"""Get frame from buffer."""
|
||||
try:
|
||||
return await asyncio.wait_for(self.buffer.get(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get buffer statistics."""
|
||||
return {
|
||||
"buffer_size": self.buffer.qsize(),
|
||||
"max_size": self.max_size,
|
||||
"total_frames": self.total_frames,
|
||||
"dropped_frames": self.dropped_frames,
|
||||
"drop_rate": self.dropped_frames / max(self.total_frames, 1)
|
||||
}
|
||||
|
||||
|
||||
class MockStreamProcessor:
|
||||
"""Mock stream processor for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_running = False
|
||||
self.processing_rate = 30 # FPS
|
||||
self.frame_counter = 0
|
||||
self.error_rate = 0.0
|
||||
|
||||
async def start_processing(self, input_buffer: MockStreamBuffer, output_buffer: MockStreamBuffer):
|
||||
"""Start stream processing."""
|
||||
self.is_running = True
|
||||
|
||||
while self.is_running:
|
||||
try:
|
||||
# Get frame from input
|
||||
frame = await input_buffer.get_frame(timeout=0.1)
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Simulate processing error
|
||||
if np.random.random() < self.error_rate:
|
||||
continue # Skip frame due to error
|
||||
|
||||
# Process frame
|
||||
processed_frame = await self._process_frame(frame)
|
||||
|
||||
# Put to output buffer
|
||||
await output_buffer.put_frame(processed_frame)
|
||||
|
||||
# Control processing rate
|
||||
await asyncio.sleep(1.0 / self.processing_rate)
|
||||
|
||||
except Exception as e:
|
||||
# Handle processing errors
|
||||
continue
|
||||
|
||||
async def _process_frame(self, frame: StreamFrame) -> StreamFrame:
|
||||
"""Process a single frame."""
|
||||
# Simulate processing time
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Add processing metadata
|
||||
processed_pose_data = frame.pose_data.copy()
|
||||
processed_pose_data["processed_at"] = datetime.utcnow().isoformat()
|
||||
processed_pose_data["processor_id"] = "stream_processor_001"
|
||||
|
||||
return StreamFrame(
|
||||
frame_id=f"processed_{frame.frame_id}",
|
||||
timestamp=frame.timestamp,
|
||||
router_id=frame.router_id,
|
||||
pose_data=processed_pose_data,
|
||||
processing_time_ms=frame.processing_time_ms + 10, # Add processing overhead
|
||||
quality_score=frame.quality_score * 0.95 # Slight quality degradation
|
||||
)
|
||||
|
||||
def stop_processing(self):
|
||||
"""Stop stream processing."""
|
||||
self.is_running = False
|
||||
|
||||
def set_error_rate(self, error_rate: float):
|
||||
"""Set processing error rate."""
|
||||
self.error_rate = error_rate
|
||||
|
||||
|
||||
class MockWebSocketManager:
|
||||
"""Mock WebSocket manager for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.connected_clients = {}
|
||||
self.message_queue = asyncio.Queue()
|
||||
self.total_messages_sent = 0
|
||||
self.failed_sends = 0
|
||||
|
||||
async def add_client(self, client_id: str, websocket_mock) -> bool:
|
||||
"""Add WebSocket client."""
|
||||
if client_id in self.connected_clients:
|
||||
return False
|
||||
|
||||
self.connected_clients[client_id] = {
|
||||
"websocket": websocket_mock,
|
||||
"connected_at": datetime.utcnow(),
|
||||
"messages_sent": 0,
|
||||
"last_ping": datetime.utcnow()
|
||||
}
|
||||
return True
|
||||
|
||||
async def remove_client(self, client_id: str) -> bool:
|
||||
"""Remove WebSocket client."""
|
||||
if client_id in self.connected_clients:
|
||||
del self.connected_clients[client_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def broadcast_frame(self, frame: StreamFrame) -> Dict[str, bool]:
|
||||
"""Broadcast frame to all connected clients."""
|
||||
results = {}
|
||||
|
||||
message = {
|
||||
"type": "pose_update",
|
||||
"frame_id": frame.frame_id,
|
||||
"timestamp": frame.timestamp.isoformat(),
|
||||
"router_id": frame.router_id,
|
||||
"pose_data": frame.pose_data,
|
||||
"processing_time_ms": frame.processing_time_ms,
|
||||
"quality_score": frame.quality_score
|
||||
}
|
||||
|
||||
for client_id, client_info in self.connected_clients.items():
|
||||
try:
|
||||
# Simulate WebSocket send
|
||||
success = await self._send_to_client(client_id, message)
|
||||
results[client_id] = success
|
||||
|
||||
if success:
|
||||
client_info["messages_sent"] += 1
|
||||
self.total_messages_sent += 1
|
||||
else:
|
||||
self.failed_sends += 1
|
||||
|
||||
except Exception:
|
||||
results[client_id] = False
|
||||
self.failed_sends += 1
|
||||
|
||||
return results
|
||||
|
||||
async def _send_to_client(self, client_id: str, message: Dict[str, Any]) -> bool:
|
||||
"""Send message to specific client."""
|
||||
# Simulate network issues
|
||||
if np.random.random() < 0.05: # 5% failure rate
|
||||
return False
|
||||
|
||||
# Simulate send delay
|
||||
await asyncio.sleep(0.001)
|
||||
return True
|
||||
|
||||
def get_client_stats(self) -> Dict[str, Any]:
|
||||
"""Get client statistics."""
|
||||
return {
|
||||
"connected_clients": len(self.connected_clients),
|
||||
"total_messages_sent": self.total_messages_sent,
|
||||
"failed_sends": self.failed_sends,
|
||||
"clients": {
|
||||
client_id: {
|
||||
"messages_sent": info["messages_sent"],
|
||||
"connected_duration": (datetime.utcnow() - info["connected_at"]).total_seconds()
|
||||
}
|
||||
for client_id, info in self.connected_clients.items()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestStreamingPipelineBasic:
|
||||
"""Test basic streaming pipeline functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def stream_buffer(self):
|
||||
"""Create stream buffer."""
|
||||
return MockStreamBuffer(max_size=50)
|
||||
|
||||
@pytest.fixture
|
||||
def stream_processor(self):
|
||||
"""Create stream processor."""
|
||||
return MockStreamProcessor()
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_manager(self):
|
||||
"""Create WebSocket manager."""
|
||||
return MockWebSocketManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_frame(self):
|
||||
"""Create sample stream frame."""
|
||||
return StreamFrame(
|
||||
frame_id="frame_001",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={
|
||||
"persons": [
|
||||
{
|
||||
"person_id": "person_1",
|
||||
"confidence": 0.85,
|
||||
"bounding_box": {"x": 100, "y": 150, "width": 80, "height": 180},
|
||||
"activity": "standing"
|
||||
}
|
||||
],
|
||||
"zone_summary": {"zone1": 1, "zone2": 0}
|
||||
},
|
||||
processing_time_ms=45.2,
|
||||
quality_score=0.92
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_frame_operations_should_fail_initially(self, stream_buffer, sample_frame):
|
||||
"""Test buffer frame operations - should fail initially."""
|
||||
# Put frame in buffer
|
||||
result = await stream_buffer.put_frame(sample_frame)
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
|
||||
# Get frame from buffer
|
||||
retrieved_frame = await stream_buffer.get_frame()
|
||||
assert retrieved_frame is not None
|
||||
assert retrieved_frame.frame_id == sample_frame.frame_id
|
||||
assert retrieved_frame.router_id == sample_frame.router_id
|
||||
|
||||
# Buffer should be empty now
|
||||
empty_frame = await stream_buffer.get_frame(timeout=0.1)
|
||||
assert empty_frame is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_overflow_handling_should_fail_initially(self, sample_frame):
|
||||
"""Test buffer overflow handling - should fail initially."""
|
||||
small_buffer = MockStreamBuffer(max_size=2)
|
||||
|
||||
# Fill buffer to capacity
|
||||
result1 = await small_buffer.put_frame(sample_frame)
|
||||
result2 = await small_buffer.put_frame(sample_frame)
|
||||
|
||||
# This will fail initially
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
|
||||
# Next frame should be dropped
|
||||
result3 = await small_buffer.put_frame(sample_frame)
|
||||
assert result3 is False
|
||||
|
||||
# Check statistics
|
||||
stats = small_buffer.get_stats()
|
||||
assert stats["total_frames"] == 3
|
||||
assert stats["dropped_frames"] == 1
|
||||
assert stats["drop_rate"] > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_processing_should_fail_initially(self, stream_processor, sample_frame):
|
||||
"""Test stream processing - should fail initially."""
|
||||
input_buffer = MockStreamBuffer()
|
||||
output_buffer = MockStreamBuffer()
|
||||
|
||||
# Add frame to input buffer
|
||||
await input_buffer.put_frame(sample_frame)
|
||||
|
||||
# Start processing task
|
||||
processing_task = asyncio.create_task(
|
||||
stream_processor.start_processing(input_buffer, output_buffer)
|
||||
)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Stop processing
|
||||
stream_processor.stop_processing()
|
||||
await processing_task
|
||||
|
||||
# Check output
|
||||
processed_frame = await output_buffer.get_frame(timeout=0.1)
|
||||
|
||||
# This will fail initially
|
||||
assert processed_frame is not None
|
||||
assert processed_frame.frame_id.startswith("processed_")
|
||||
assert "processed_at" in processed_frame.pose_data
|
||||
assert processed_frame.processing_time_ms > sample_frame.processing_time_ms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_client_management_should_fail_initially(self, websocket_manager):
|
||||
"""Test WebSocket client management - should fail initially."""
|
||||
mock_websocket = MagicMock()
|
||||
|
||||
# Add client
|
||||
result = await websocket_manager.add_client("client_001", mock_websocket)
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
assert "client_001" in websocket_manager.connected_clients
|
||||
|
||||
# Try to add duplicate client
|
||||
result = await websocket_manager.add_client("client_001", mock_websocket)
|
||||
assert result is False
|
||||
|
||||
# Remove client
|
||||
result = await websocket_manager.remove_client("client_001")
|
||||
assert result is True
|
||||
assert "client_001" not in websocket_manager.connected_clients
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_frame_broadcasting_should_fail_initially(self, websocket_manager, sample_frame):
|
||||
"""Test frame broadcasting - should fail initially."""
|
||||
# Add multiple clients
|
||||
for i in range(3):
|
||||
await websocket_manager.add_client(f"client_{i:03d}", MagicMock())
|
||||
|
||||
# Broadcast frame
|
||||
results = await websocket_manager.broadcast_frame(sample_frame)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(success, bool) for success in results.values())
|
||||
|
||||
# Check statistics
|
||||
stats = websocket_manager.get_client_stats()
|
||||
assert stats["connected_clients"] == 3
|
||||
assert stats["total_messages_sent"] >= 0
|
||||
|
||||
|
||||
class TestStreamingPipelineIntegration:
|
||||
"""Test complete streaming pipeline integration."""
|
||||
|
||||
@pytest.fixture
|
||||
async def streaming_pipeline(self):
|
||||
"""Create complete streaming pipeline."""
|
||||
class StreamingPipeline:
|
||||
def __init__(self):
|
||||
self.input_buffer = MockStreamBuffer(max_size=100)
|
||||
self.output_buffer = MockStreamBuffer(max_size=100)
|
||||
self.processor = MockStreamProcessor()
|
||||
self.websocket_manager = MockWebSocketManager()
|
||||
self.is_running = False
|
||||
self.processing_task = None
|
||||
self.broadcasting_task = None
|
||||
|
||||
async def start(self):
|
||||
"""Start the streaming pipeline."""
|
||||
if self.is_running:
|
||||
return False
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# Start processing task
|
||||
self.processing_task = asyncio.create_task(
|
||||
self.processor.start_processing(self.input_buffer, self.output_buffer)
|
||||
)
|
||||
|
||||
# Start broadcasting task
|
||||
self.broadcasting_task = asyncio.create_task(
|
||||
self._broadcast_loop()
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the streaming pipeline."""
|
||||
if not self.is_running:
|
||||
return False
|
||||
|
||||
self.is_running = False
|
||||
self.processor.stop_processing()
|
||||
|
||||
# Cancel tasks
|
||||
if self.processing_task:
|
||||
self.processing_task.cancel()
|
||||
if self.broadcasting_task:
|
||||
self.broadcasting_task.cancel()
|
||||
|
||||
return True
|
||||
|
||||
async def add_frame(self, frame: StreamFrame) -> bool:
|
||||
"""Add frame to pipeline."""
|
||||
return await self.input_buffer.put_frame(frame)
|
||||
|
||||
async def add_client(self, client_id: str, websocket_mock) -> bool:
|
||||
"""Add WebSocket client."""
|
||||
return await self.websocket_manager.add_client(client_id, websocket_mock)
|
||||
|
||||
async def _broadcast_loop(self):
|
||||
"""Broadcasting loop."""
|
||||
while self.is_running:
|
||||
try:
|
||||
frame = await self.output_buffer.get_frame(timeout=0.1)
|
||||
if frame:
|
||||
await self.websocket_manager.broadcast_frame(frame)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
def get_pipeline_stats(self) -> Dict[str, Any]:
|
||||
"""Get pipeline statistics."""
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"input_buffer": self.input_buffer.get_stats(),
|
||||
"output_buffer": self.output_buffer.get_stats(),
|
||||
"websocket_clients": self.websocket_manager.get_client_stats()
|
||||
}
|
||||
|
||||
return StreamingPipeline()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_streaming_should_fail_initially(self, streaming_pipeline):
|
||||
"""Test end-to-end streaming - should fail initially."""
|
||||
# Start pipeline
|
||||
result = await streaming_pipeline.start()
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
assert streaming_pipeline.is_running is True
|
||||
|
||||
# Add clients
|
||||
for i in range(2):
|
||||
await streaming_pipeline.add_client(f"client_{i}", MagicMock())
|
||||
|
||||
# Add frames
|
||||
for i in range(5):
|
||||
frame = StreamFrame(
|
||||
frame_id=f"frame_{i:03d}",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={"persons": [], "zone_summary": {}},
|
||||
processing_time_ms=30.0,
|
||||
quality_score=0.9
|
||||
)
|
||||
await streaming_pipeline.add_frame(frame)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Stop pipeline
|
||||
await streaming_pipeline.stop()
|
||||
|
||||
# Check statistics
|
||||
stats = streaming_pipeline.get_pipeline_stats()
|
||||
assert stats["input_buffer"]["total_frames"] == 5
|
||||
assert stats["websocket_clients"]["connected_clients"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_performance_should_fail_initially(self, streaming_pipeline):
|
||||
"""Test pipeline performance - should fail initially."""
|
||||
await streaming_pipeline.start()
|
||||
|
||||
# Add multiple clients
|
||||
for i in range(10):
|
||||
await streaming_pipeline.add_client(f"client_{i:03d}", MagicMock())
|
||||
|
||||
# Measure throughput
|
||||
start_time = datetime.utcnow()
|
||||
frame_count = 50
|
||||
|
||||
for i in range(frame_count):
|
||||
frame = StreamFrame(
|
||||
frame_id=f"perf_frame_{i:03d}",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={"persons": [], "zone_summary": {}},
|
||||
processing_time_ms=25.0,
|
||||
quality_score=0.88
|
||||
)
|
||||
await streaming_pipeline.add_frame(frame)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
end_time = datetime.utcnow()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
await streaming_pipeline.stop()
|
||||
|
||||
# This will fail initially
|
||||
# Check performance metrics
|
||||
stats = streaming_pipeline.get_pipeline_stats()
|
||||
throughput = frame_count / duration
|
||||
|
||||
assert throughput > 10 # Should process at least 10 FPS
|
||||
assert stats["input_buffer"]["drop_rate"] < 0.1 # Less than 10% drop rate
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_error_recovery_should_fail_initially(self, streaming_pipeline):
|
||||
"""Test pipeline error recovery - should fail initially."""
|
||||
await streaming_pipeline.start()
|
||||
|
||||
# Set high error rate
|
||||
streaming_pipeline.processor.set_error_rate(0.5) # 50% error rate
|
||||
|
||||
# Add frames
|
||||
for i in range(20):
|
||||
frame = StreamFrame(
|
||||
frame_id=f"error_frame_{i:03d}",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={"persons": [], "zone_summary": {}},
|
||||
processing_time_ms=30.0,
|
||||
quality_score=0.9
|
||||
)
|
||||
await streaming_pipeline.add_frame(frame)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
await streaming_pipeline.stop()
|
||||
|
||||
# This will fail initially
|
||||
# Pipeline should continue running despite errors
|
||||
stats = streaming_pipeline.get_pipeline_stats()
|
||||
assert stats["input_buffer"]["total_frames"] == 20
|
||||
# Some frames should be processed despite errors
|
||||
assert stats["output_buffer"]["total_frames"] > 0
|
||||
|
||||
|
||||
class TestStreamingLatency:
|
||||
"""Test streaming latency characteristics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_latency_should_fail_initially(self):
|
||||
"""Test end-to-end latency - should fail initially."""
|
||||
class LatencyTracker:
|
||||
def __init__(self):
|
||||
self.latencies = []
|
||||
|
||||
async def measure_latency(self, frame: StreamFrame) -> float:
|
||||
"""Measure processing latency."""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Simulate processing pipeline
|
||||
await asyncio.sleep(0.05) # 50ms processing time
|
||||
|
||||
end_time = datetime.utcnow()
|
||||
latency = (end_time - start_time).total_seconds() * 1000 # Convert to ms
|
||||
|
||||
self.latencies.append(latency)
|
||||
return latency
|
||||
|
||||
tracker = LatencyTracker()
|
||||
|
||||
# Measure latency for multiple frames
|
||||
for i in range(10):
|
||||
frame = StreamFrame(
|
||||
frame_id=f"latency_frame_{i}",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={},
|
||||
processing_time_ms=0,
|
||||
quality_score=1.0
|
||||
)
|
||||
|
||||
latency = await tracker.measure_latency(frame)
|
||||
|
||||
# This will fail initially
|
||||
assert latency > 0
|
||||
assert latency < 200 # Should be less than 200ms
|
||||
|
||||
# Check average latency
|
||||
avg_latency = sum(tracker.latencies) / len(tracker.latencies)
|
||||
assert avg_latency < 100 # Average should be less than 100ms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_stream_handling_should_fail_initially(self):
|
||||
"""Test concurrent stream handling - should fail initially."""
|
||||
async def process_stream(stream_id: str, frame_count: int) -> Dict[str, Any]:
|
||||
"""Process a single stream."""
|
||||
buffer = MockStreamBuffer()
|
||||
processed_frames = 0
|
||||
|
||||
for i in range(frame_count):
|
||||
frame = StreamFrame(
|
||||
frame_id=f"{stream_id}_frame_{i}",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id=stream_id,
|
||||
pose_data={},
|
||||
processing_time_ms=20.0,
|
||||
quality_score=0.9
|
||||
)
|
||||
|
||||
success = await buffer.put_frame(frame)
|
||||
if success:
|
||||
processed_frames += 1
|
||||
|
||||
await asyncio.sleep(0.01) # Simulate frame rate
|
||||
|
||||
return {
|
||||
"stream_id": stream_id,
|
||||
"processed_frames": processed_frames,
|
||||
"total_frames": frame_count
|
||||
}
|
||||
|
||||
# Process multiple streams concurrently
|
||||
streams = ["router_001", "router_002", "router_003"]
|
||||
tasks = [process_stream(stream_id, 20) for stream_id in streams]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 3
|
||||
|
||||
for result in results:
|
||||
assert result["processed_frames"] == result["total_frames"]
|
||||
assert result["stream_id"] in streams
|
||||
|
||||
|
||||
class TestStreamingResilience:
|
||||
"""Test streaming pipeline resilience."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_disconnection_handling_should_fail_initially(self):
|
||||
"""Test client disconnection handling - should fail initially."""
|
||||
websocket_manager = MockWebSocketManager()
|
||||
|
||||
# Add clients
|
||||
client_ids = [f"client_{i:03d}" for i in range(5)]
|
||||
for client_id in client_ids:
|
||||
await websocket_manager.add_client(client_id, MagicMock())
|
||||
|
||||
# Simulate frame broadcasting
|
||||
frame = StreamFrame(
|
||||
frame_id="disconnect_test_frame",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={},
|
||||
processing_time_ms=30.0,
|
||||
quality_score=0.9
|
||||
)
|
||||
|
||||
# Broadcast to all clients
|
||||
results = await websocket_manager.broadcast_frame(frame)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 5
|
||||
|
||||
# Simulate client disconnections
|
||||
await websocket_manager.remove_client("client_001")
|
||||
await websocket_manager.remove_client("client_003")
|
||||
|
||||
# Broadcast again
|
||||
results = await websocket_manager.broadcast_frame(frame)
|
||||
assert len(results) == 3 # Only remaining clients
|
||||
|
||||
# Check statistics
|
||||
stats = websocket_manager.get_client_stats()
|
||||
assert stats["connected_clients"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_pressure_handling_should_fail_initially(self):
|
||||
"""Test memory pressure handling - should fail initially."""
|
||||
# Create small buffers to simulate memory pressure
|
||||
small_buffer = MockStreamBuffer(max_size=5)
|
||||
|
||||
# Generate many frames quickly
|
||||
frames_generated = 0
|
||||
frames_accepted = 0
|
||||
|
||||
for i in range(20):
|
||||
frame = StreamFrame(
|
||||
frame_id=f"memory_pressure_frame_{i}",
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
pose_data={},
|
||||
processing_time_ms=25.0,
|
||||
quality_score=0.85
|
||||
)
|
||||
|
||||
frames_generated += 1
|
||||
success = await small_buffer.put_frame(frame)
|
||||
if success:
|
||||
frames_accepted += 1
|
||||
|
||||
# This will fail initially
|
||||
# Buffer should handle memory pressure gracefully
|
||||
stats = small_buffer.get_stats()
|
||||
assert stats["total_frames"] == frames_generated
|
||||
assert stats["dropped_frames"] > 0 # Some frames should be dropped
|
||||
assert frames_accepted <= small_buffer.max_size
|
||||
|
||||
# Drop rate should be reasonable
|
||||
assert stats["drop_rate"] > 0.5 # More than 50% dropped due to small buffer
|
||||
419
tests/integration/test_websocket_streaming.py
Normal file
419
tests/integration/test_websocket_streaming.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Integration tests for WebSocket streaming functionality.
|
||||
|
||||
Tests WebSocket connections, message handling, and real-time data streaming.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import websockets
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class MockWebSocket:
|
||||
"""Mock WebSocket for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages_sent = []
|
||||
self.messages_received = []
|
||||
self.closed = False
|
||||
self.accept_called = False
|
||||
|
||||
async def accept(self):
|
||||
"""Mock accept method."""
|
||||
self.accept_called = True
|
||||
|
||||
async def send_json(self, data: Dict[str, Any]):
|
||||
"""Mock send_json method."""
|
||||
self.messages_sent.append(data)
|
||||
|
||||
async def send_text(self, text: str):
|
||||
"""Mock send_text method."""
|
||||
self.messages_sent.append(text)
|
||||
|
||||
async def receive_text(self) -> str:
|
||||
"""Mock receive_text method."""
|
||||
if self.messages_received:
|
||||
return self.messages_received.pop(0)
|
||||
# Simulate WebSocket disconnect
|
||||
from fastapi import WebSocketDisconnect
|
||||
raise WebSocketDisconnect()
|
||||
|
||||
async def close(self):
|
||||
"""Mock close method."""
|
||||
self.closed = True
|
||||
|
||||
def add_received_message(self, message: str):
|
||||
"""Add a message to be received."""
|
||||
self.messages_received.append(message)
|
||||
|
||||
|
||||
class TestWebSocketStreaming:
|
||||
"""Integration tests for WebSocket streaming."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket(self):
|
||||
"""Create mock WebSocket."""
|
||||
return MockWebSocket()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection_manager(self):
|
||||
"""Mock connection manager."""
|
||||
manager = AsyncMock()
|
||||
manager.connect.return_value = "client-001"
|
||||
manager.disconnect.return_value = True
|
||||
manager.get_connection_stats.return_value = {
|
||||
"total_clients": 1,
|
||||
"active_streams": ["pose"]
|
||||
}
|
||||
manager.broadcast.return_value = 1
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_service(self):
|
||||
"""Mock stream service."""
|
||||
service = AsyncMock()
|
||||
service.get_status.return_value = {
|
||||
"is_active": True,
|
||||
"active_streams": [],
|
||||
"uptime_seconds": 3600.0
|
||||
}
|
||||
service.is_active.return_value = True
|
||||
service.start.return_value = None
|
||||
service.stop.return_value = None
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_pose_connection_should_fail_initially(self, mock_websocket, mock_connection_manager):
|
||||
"""Test WebSocket pose connection establishment - should fail initially."""
|
||||
# This test should fail because we haven't implemented the WebSocket handler properly
|
||||
|
||||
# Simulate WebSocket connection
|
||||
zone_ids = "zone1,zone2"
|
||||
min_confidence = 0.7
|
||||
max_fps = 30
|
||||
|
||||
# Mock the websocket_pose_stream function
|
||||
async def mock_websocket_handler(websocket, zone_ids, min_confidence, max_fps):
|
||||
await websocket.accept()
|
||||
|
||||
# Parse zone IDs
|
||||
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
|
||||
|
||||
# Register client
|
||||
client_id = await mock_connection_manager.connect(
|
||||
websocket=websocket,
|
||||
stream_type="pose",
|
||||
zone_ids=zone_list,
|
||||
min_confidence=min_confidence,
|
||||
max_fps=max_fps
|
||||
)
|
||||
|
||||
# Send confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connection_established",
|
||||
"client_id": client_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"config": {
|
||||
"zone_ids": zone_list,
|
||||
"min_confidence": min_confidence,
|
||||
"max_fps": max_fps
|
||||
}
|
||||
})
|
||||
|
||||
return client_id
|
||||
|
||||
# Execute the handler
|
||||
client_id = await mock_websocket_handler(mock_websocket, zone_ids, min_confidence, max_fps)
|
||||
|
||||
# This assertion will fail initially, driving us to implement the WebSocket handler
|
||||
assert mock_websocket.accept_called
|
||||
assert len(mock_websocket.messages_sent) == 1
|
||||
assert mock_websocket.messages_sent[0]["type"] == "connection_established"
|
||||
assert mock_websocket.messages_sent[0]["client_id"] == "client-001"
|
||||
assert "config" in mock_websocket.messages_sent[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_message_handling_should_fail_initially(self, mock_websocket):
|
||||
"""Test WebSocket message handling - should fail initially."""
|
||||
# Mock message handler
|
||||
async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket):
|
||||
message_type = data.get("type")
|
||||
|
||||
if message_type == "ping":
|
||||
await websocket.send_json({
|
||||
"type": "pong",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
elif message_type == "update_config":
|
||||
config = data.get("config", {})
|
||||
await websocket.send_json({
|
||||
"type": "config_updated",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"config": config
|
||||
})
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": f"Unknown message type: {message_type}"
|
||||
})
|
||||
|
||||
# Test ping message
|
||||
ping_data = {"type": "ping"}
|
||||
await handle_websocket_message("client-001", ping_data, mock_websocket)
|
||||
|
||||
# This will fail initially
|
||||
assert len(mock_websocket.messages_sent) == 1
|
||||
assert mock_websocket.messages_sent[0]["type"] == "pong"
|
||||
|
||||
# Test config update
|
||||
mock_websocket.messages_sent.clear()
|
||||
config_data = {
|
||||
"type": "update_config",
|
||||
"config": {"min_confidence": 0.8, "max_fps": 15}
|
||||
}
|
||||
await handle_websocket_message("client-001", config_data, mock_websocket)
|
||||
|
||||
# This will fail initially
|
||||
assert len(mock_websocket.messages_sent) == 1
|
||||
assert mock_websocket.messages_sent[0]["type"] == "config_updated"
|
||||
assert mock_websocket.messages_sent[0]["config"]["min_confidence"] == 0.8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_events_stream_should_fail_initially(self, mock_websocket, mock_connection_manager):
|
||||
"""Test WebSocket events stream - should fail initially."""
|
||||
# Mock events stream handler
|
||||
async def mock_events_handler(websocket, event_types, zone_ids):
|
||||
await websocket.accept()
|
||||
|
||||
# Parse parameters
|
||||
event_list = [event.strip() for event in event_types.split(",") if event.strip()] if event_types else None
|
||||
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()] if zone_ids else None
|
||||
|
||||
# Register client
|
||||
client_id = await mock_connection_manager.connect(
|
||||
websocket=websocket,
|
||||
stream_type="events",
|
||||
zone_ids=zone_list,
|
||||
event_types=event_list
|
||||
)
|
||||
|
||||
# Send confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connection_established",
|
||||
"client_id": client_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"config": {
|
||||
"event_types": event_list,
|
||||
"zone_ids": zone_list
|
||||
}
|
||||
})
|
||||
|
||||
return client_id
|
||||
|
||||
# Execute handler
|
||||
client_id = await mock_events_handler(mock_websocket, "fall_detection,intrusion", "zone1")
|
||||
|
||||
# This will fail initially
|
||||
assert mock_websocket.accept_called
|
||||
assert len(mock_websocket.messages_sent) == 1
|
||||
assert mock_websocket.messages_sent[0]["type"] == "connection_established"
|
||||
assert mock_websocket.messages_sent[0]["config"]["event_types"] == ["fall_detection", "intrusion"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_disconnect_handling_should_fail_initially(self, mock_websocket, mock_connection_manager):
|
||||
"""Test WebSocket disconnect handling - should fail initially."""
|
||||
# Mock disconnect scenario
|
||||
client_id = "client-001"
|
||||
|
||||
# Simulate disconnect
|
||||
disconnect_result = await mock_connection_manager.disconnect(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert disconnect_result is True
|
||||
mock_connection_manager.disconnect.assert_called_once_with(client_id)
|
||||
|
||||
|
||||
class TestWebSocketConnectionManager:
|
||||
"""Test WebSocket connection management."""
|
||||
|
||||
@pytest.fixture
|
||||
def connection_manager(self):
|
||||
"""Create connection manager for testing."""
|
||||
# Mock connection manager implementation
|
||||
class MockConnectionManager:
|
||||
def __init__(self):
|
||||
self.connections = {}
|
||||
self.client_counter = 0
|
||||
|
||||
async def connect(self, websocket, stream_type, zone_ids=None, **kwargs):
|
||||
self.client_counter += 1
|
||||
client_id = f"client-{self.client_counter:03d}"
|
||||
self.connections[client_id] = {
|
||||
"websocket": websocket,
|
||||
"stream_type": stream_type,
|
||||
"zone_ids": zone_ids or [],
|
||||
"connected_at": datetime.utcnow(),
|
||||
**kwargs
|
||||
}
|
||||
return client_id
|
||||
|
||||
async def disconnect(self, client_id):
|
||||
if client_id in self.connections:
|
||||
del self.connections[client_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_connected_clients(self):
|
||||
return list(self.connections.keys())
|
||||
|
||||
async def get_connection_stats(self):
|
||||
return {
|
||||
"total_clients": len(self.connections),
|
||||
"active_streams": list(set(conn["stream_type"] for conn in self.connections.values()))
|
||||
}
|
||||
|
||||
async def broadcast(self, data, stream_type=None, zone_ids=None):
|
||||
sent_count = 0
|
||||
for client_id, conn in self.connections.items():
|
||||
if stream_type and conn["stream_type"] != stream_type:
|
||||
continue
|
||||
if zone_ids and not any(zone in conn["zone_ids"] for zone in zone_ids):
|
||||
continue
|
||||
|
||||
# Mock sending data
|
||||
sent_count += 1
|
||||
|
||||
return sent_count
|
||||
|
||||
return MockConnectionManager()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_manager_connect_should_fail_initially(self, connection_manager, mock_websocket):
|
||||
"""Test connection manager connect functionality - should fail initially."""
|
||||
client_id = await connection_manager.connect(
|
||||
websocket=mock_websocket,
|
||||
stream_type="pose",
|
||||
zone_ids=["zone1", "zone2"],
|
||||
min_confidence=0.7
|
||||
)
|
||||
|
||||
# This will fail initially
|
||||
assert client_id == "client-001"
|
||||
assert client_id in connection_manager.connections
|
||||
assert connection_manager.connections[client_id]["stream_type"] == "pose"
|
||||
assert connection_manager.connections[client_id]["zone_ids"] == ["zone1", "zone2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_manager_disconnect_should_fail_initially(self, connection_manager, mock_websocket):
|
||||
"""Test connection manager disconnect functionality - should fail initially."""
|
||||
# Connect first
|
||||
client_id = await connection_manager.connect(
|
||||
websocket=mock_websocket,
|
||||
stream_type="pose"
|
||||
)
|
||||
|
||||
# Disconnect
|
||||
result = await connection_manager.disconnect(client_id)
|
||||
|
||||
# This will fail initially
|
||||
assert result is True
|
||||
assert client_id not in connection_manager.connections
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_manager_broadcast_should_fail_initially(self, connection_manager):
|
||||
"""Test connection manager broadcast functionality - should fail initially."""
|
||||
# Connect multiple clients
|
||||
ws1 = MockWebSocket()
|
||||
ws2 = MockWebSocket()
|
||||
|
||||
client1 = await connection_manager.connect(ws1, "pose", zone_ids=["zone1"])
|
||||
client2 = await connection_manager.connect(ws2, "events", zone_ids=["zone2"])
|
||||
|
||||
# Broadcast to pose stream
|
||||
sent_count = await connection_manager.broadcast(
|
||||
data={"type": "pose_data", "data": {}},
|
||||
stream_type="pose"
|
||||
)
|
||||
|
||||
# This will fail initially
|
||||
assert sent_count == 1
|
||||
|
||||
# Broadcast to specific zone
|
||||
sent_count = await connection_manager.broadcast(
|
||||
data={"type": "zone_event", "data": {}},
|
||||
zone_ids=["zone1"]
|
||||
)
|
||||
|
||||
# This will fail initially
|
||||
assert sent_count == 1
|
||||
|
||||
|
||||
class TestWebSocketPerformance:
|
||||
"""Test WebSocket performance characteristics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_concurrent_connections_should_fail_initially(self):
|
||||
"""Test handling multiple concurrent WebSocket connections - should fail initially."""
|
||||
# Mock multiple connections
|
||||
connection_count = 10
|
||||
connections = []
|
||||
|
||||
for i in range(connection_count):
|
||||
mock_ws = MockWebSocket()
|
||||
connections.append(mock_ws)
|
||||
|
||||
# Simulate concurrent connections
|
||||
async def simulate_connection(websocket, client_id):
|
||||
await websocket.accept()
|
||||
await websocket.send_json({
|
||||
"type": "connection_established",
|
||||
"client_id": client_id
|
||||
})
|
||||
return True
|
||||
|
||||
# Execute concurrent connections
|
||||
tasks = [
|
||||
simulate_connection(ws, f"client-{i:03d}")
|
||||
for i, ws in enumerate(connections)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == connection_count
|
||||
assert all(results)
|
||||
assert all(ws.accept_called for ws in connections)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_message_throughput_should_fail_initially(self):
|
||||
"""Test WebSocket message throughput - should fail initially."""
|
||||
mock_ws = MockWebSocket()
|
||||
message_count = 100
|
||||
|
||||
# Simulate high-frequency message sending
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
for i in range(message_count):
|
||||
await mock_ws.send_json({
|
||||
"type": "pose_data",
|
||||
"frame_id": f"frame-{i:04d}",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
end_time = datetime.utcnow()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
# This will fail initially
|
||||
assert len(mock_ws.messages_sent) == message_count
|
||||
assert duration < 1.0 # Should handle 100 messages in under 1 second
|
||||
|
||||
# Calculate throughput
|
||||
throughput = message_count / duration if duration > 0 else float('inf')
|
||||
assert throughput > 100 # Should handle at least 100 messages per second
|
||||
Reference in New Issue
Block a user