feat: Implement hardware, pose, and stream services for WiFi-DensePose API
- Added HardwareService for managing router interfaces, data collection, and monitoring. - Introduced PoseService for processing CSI data and estimating poses using neural networks. - Created StreamService for real-time data streaming via WebSocket connections. - Implemented initialization, start, stop, and status retrieval methods for each service. - Added data processing, error handling, and statistics tracking across services. - Integrated mock data generation for development and testing purposes.
This commit is contained in:
18
=3.0.0
18
=3.0.0
@@ -1,18 +0,0 @@
|
|||||||
Collecting paramiko
|
|
||||||
Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB)
|
|
||||||
Collecting bcrypt>=3.2 (from paramiko)
|
|
||||||
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB)
|
|
||||||
Collecting cryptography>=3.3 (from paramiko)
|
|
||||||
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB)
|
|
||||||
Collecting pynacl>=1.5 (from paramiko)
|
|
||||||
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB)
|
|
||||||
Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1)
|
|
||||||
Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22)
|
|
||||||
Downloading paramiko-3.5.1-py3-none-any.whl (227 kB)
|
|
||||||
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB)
|
|
||||||
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB)
|
|
||||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00
|
|
||||||
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
|
|
||||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00
|
|
||||||
Installing collected packages: bcrypt, pynacl, cryptography, paramiko
|
|
||||||
Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0
|
|
||||||
183
example.env
Normal file
183
example.env
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
# WiFi-DensePose API Environment Configuration Template
|
||||||
|
# Copy this file to .env and modify the values according to your setup
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# APPLICATION SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Application metadata
|
||||||
|
APP_NAME=WiFi-DensePose API
|
||||||
|
VERSION=1.0.0
|
||||||
|
ENVIRONMENT=development # Options: development, staging, production
|
||||||
|
DEBUG=true
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SERVER SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Server configuration
|
||||||
|
HOST=0.0.0.0
|
||||||
|
PORT=8000
|
||||||
|
RELOAD=true # Auto-reload on code changes (development only)
|
||||||
|
WORKERS=1 # Number of worker processes
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SECURITY SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# IMPORTANT: Change these values for production!
|
||||||
|
SECRET_KEY=your-secret-key-here-change-for-production
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
|
JWT_EXPIRE_HOURS=24
|
||||||
|
|
||||||
|
# Allowed hosts (restrict in production)
|
||||||
|
ALLOWED_HOSTS=* # Use specific domains in production: example.com,api.example.com
|
||||||
|
|
||||||
|
# CORS settings (restrict in production)
|
||||||
|
CORS_ORIGINS=* # Use specific origins in production: https://example.com,https://app.example.com
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# DATABASE SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Database connection (optional - defaults to SQLite in development)
|
||||||
|
# DATABASE_URL=postgresql://user:password@localhost:5432/wifi_densepose
|
||||||
|
# DATABASE_POOL_SIZE=10
|
||||||
|
# DATABASE_MAX_OVERFLOW=20
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# REDIS SETTINGS (Optional - for caching and rate limiting)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Redis connection (optional - defaults to localhost in development)
|
||||||
|
# REDIS_URL=redis://localhost:6379/0
|
||||||
|
# REDIS_PASSWORD=your-redis-password
|
||||||
|
# REDIS_DB=0
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# HARDWARE SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# WiFi interface configuration
|
||||||
|
WIFI_INTERFACE=wlan0
|
||||||
|
CSI_BUFFER_SIZE=1000
|
||||||
|
HARDWARE_POLLING_INTERVAL=0.1
|
||||||
|
|
||||||
|
# Hardware mock settings (for development/testing)
|
||||||
|
MOCK_HARDWARE=true
|
||||||
|
MOCK_POSE_DATA=true
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# POSE ESTIMATION SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
# POSE_MODEL_PATH=/path/to/your/pose/model.pth
|
||||||
|
POSE_CONFIDENCE_THRESHOLD=0.5
|
||||||
|
POSE_PROCESSING_BATCH_SIZE=32
|
||||||
|
POSE_MAX_PERSONS=10
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# STREAMING SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Real-time streaming configuration
|
||||||
|
STREAM_FPS=30
|
||||||
|
STREAM_BUFFER_SIZE=100
|
||||||
|
WEBSOCKET_PING_INTERVAL=60
|
||||||
|
WEBSOCKET_TIMEOUT=300
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# FEATURE FLAGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Enable/disable features
|
||||||
|
ENABLE_AUTHENTICATION=false # Set to true for production
|
||||||
|
ENABLE_RATE_LIMITING=false # Set to true for production
|
||||||
|
ENABLE_WEBSOCKETS=true
|
||||||
|
ENABLE_REAL_TIME_PROCESSING=true
|
||||||
|
ENABLE_HISTORICAL_DATA=true
|
||||||
|
|
||||||
|
# Development features
|
||||||
|
ENABLE_TEST_ENDPOINTS=true # Set to false for production
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# RATE LIMITING SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Rate limiting configuration
|
||||||
|
RATE_LIMIT_REQUESTS=100
|
||||||
|
RATE_LIMIT_AUTHENTICATED_REQUESTS=1000
|
||||||
|
RATE_LIMIT_WINDOW=3600 # Window in seconds
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# LOGGING SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||||
|
LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
||||||
|
# LOG_FILE=/path/to/logfile.log # Optional: specify log file path
|
||||||
|
LOG_MAX_SIZE=10485760 # 10MB
|
||||||
|
LOG_BACKUP_COUNT=5
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# STORAGE SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Storage directories
|
||||||
|
DATA_STORAGE_PATH=./data
|
||||||
|
MODEL_STORAGE_PATH=./models
|
||||||
|
TEMP_STORAGE_PATH=./temp
|
||||||
|
MAX_STORAGE_SIZE_GB=100
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MONITORING SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Monitoring and metrics
|
||||||
|
METRICS_ENABLED=true
|
||||||
|
HEALTH_CHECK_INTERVAL=30
|
||||||
|
PERFORMANCE_MONITORING=true
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# API configuration
|
||||||
|
API_PREFIX=/api/v1
|
||||||
|
DOCS_URL=/docs # Set to null to disable in production
|
||||||
|
REDOC_URL=/redoc # Set to null to disable in production
|
||||||
|
OPENAPI_URL=/openapi.json # Set to null to disable in production
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PRODUCTION SETTINGS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# For production deployment, ensure you:
|
||||||
|
# 1. Set ENVIRONMENT=production
|
||||||
|
# 2. Set DEBUG=false
|
||||||
|
# 3. Use a strong SECRET_KEY
|
||||||
|
# 4. Configure proper DATABASE_URL
|
||||||
|
# 5. Restrict ALLOWED_HOSTS and CORS_ORIGINS
|
||||||
|
# 6. Enable ENABLE_AUTHENTICATION=true
|
||||||
|
# 7. Enable ENABLE_RATE_LIMITING=true
|
||||||
|
# 8. Set ENABLE_TEST_ENDPOINTS=false
|
||||||
|
# 9. Disable API documentation URLs (set to null)
|
||||||
|
# 10. Configure proper logging with LOG_FILE
|
||||||
|
|
||||||
|
# Example production settings:
|
||||||
|
# ENVIRONMENT=production
|
||||||
|
# DEBUG=false
|
||||||
|
# SECRET_KEY=your-very-secure-secret-key-here
|
||||||
|
# DATABASE_URL=postgresql://user:password@db-host:5432/wifi_densepose
|
||||||
|
# REDIS_URL=redis://redis-host:6379/0
|
||||||
|
# ALLOWED_HOSTS=yourdomain.com,api.yourdomain.com
|
||||||
|
# CORS_ORIGINS=https://yourdomain.com,https://app.yourdomain.com
|
||||||
|
# ENABLE_AUTHENTICATION=true
|
||||||
|
# ENABLE_RATE_LIMITING=true
|
||||||
|
# ENABLE_TEST_ENDPOINTS=false
|
||||||
|
# DOCS_URL=null
|
||||||
|
# REDOC_URL=null
|
||||||
|
# OPENAPI_URL=null
|
||||||
|
# LOG_FILE=/var/log/wifi-densepose/app.log
|
||||||
@@ -17,6 +17,9 @@ fastapi>=0.95.0
|
|||||||
uvicorn>=0.20.0
|
uvicorn>=0.20.0
|
||||||
websockets>=10.4
|
websockets>=10.4
|
||||||
pydantic>=1.10.0
|
pydantic>=1.10.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
python-multipart>=0.0.6
|
||||||
|
passlib[bcrypt]>=1.7.4
|
||||||
|
|
||||||
# Hardware interface dependencies
|
# Hardware interface dependencies
|
||||||
asyncio-mqtt>=0.11.0
|
asyncio-mqtt>=0.11.0
|
||||||
|
|||||||
1967
scripts/api_test_results_20250607_122720.json
Normal file
1967
scripts/api_test_results_20250607_122720.json
Normal file
File diff suppressed because it is too large
Load Diff
1991
scripts/api_test_results_20250607_122856.json
Normal file
1991
scripts/api_test_results_20250607_122856.json
Normal file
File diff suppressed because it is too large
Load Diff
2961
scripts/api_test_results_20250607_123111.json
Normal file
2961
scripts/api_test_results_20250607_123111.json
Normal file
File diff suppressed because it is too large
Load Diff
376
scripts/test_api_endpoints.py
Executable file
376
scripts/test_api_endpoints.py
Executable file
@@ -0,0 +1,376 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
API Endpoint Testing Script
|
||||||
|
Tests all WiFi-DensePose API endpoints and provides debugging information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
from colorama import Fore, Style, init
|
||||||
|
|
||||||
|
# Initialize colorama for colored output
|
||||||
|
init(autoreset=True)
|
||||||
|
|
||||||
|
class APITester:
|
||||||
|
"""Comprehensive API endpoint tester."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||||
|
self.base_url = base_url
|
||||||
|
self.session = None
|
||||||
|
self.results = {
|
||||||
|
"total_tests": 0,
|
||||||
|
"passed": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"errors": [],
|
||||||
|
"test_details": []
|
||||||
|
}
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Async context manager entry."""
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Async context manager exit."""
|
||||||
|
if self.session:
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
def log_success(self, message: str):
|
||||||
|
"""Log success message."""
|
||||||
|
print(f"{Fore.GREEN}✓ {message}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
def log_error(self, message: str):
|
||||||
|
"""Log error message."""
|
||||||
|
print(f"{Fore.RED}✗ {message}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
def log_info(self, message: str):
|
||||||
|
"""Log info message."""
|
||||||
|
print(f"{Fore.BLUE}ℹ {message}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
def log_warning(self, message: str):
|
||||||
|
"""Log warning message."""
|
||||||
|
print(f"{Fore.YELLOW}⚠ {message}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
async def test_endpoint(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
endpoint: str,
|
||||||
|
expected_status: int = 200,
|
||||||
|
data: Optional[Dict] = None,
|
||||||
|
params: Optional[Dict] = None,
|
||||||
|
headers: Optional[Dict] = None,
|
||||||
|
description: str = ""
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Test a single API endpoint."""
|
||||||
|
self.results["total_tests"] += 1
|
||||||
|
test_name = f"{method.upper()} {endpoint}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
url = f"{self.base_url}{endpoint}"
|
||||||
|
|
||||||
|
# Prepare request
|
||||||
|
kwargs = {}
|
||||||
|
if data:
|
||||||
|
kwargs["json"] = data
|
||||||
|
if params:
|
||||||
|
kwargs["params"] = params
|
||||||
|
if headers:
|
||||||
|
kwargs["headers"] = headers
|
||||||
|
|
||||||
|
# Make request
|
||||||
|
start_time = time.time()
|
||||||
|
async with self.session.request(method, url, **kwargs) as response:
|
||||||
|
response_time = (time.time() - start_time) * 1000
|
||||||
|
response_text = await response.text()
|
||||||
|
|
||||||
|
# Try to parse JSON response
|
||||||
|
try:
|
||||||
|
response_data = json.loads(response_text) if response_text else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
response_data = {"raw_response": response_text}
|
||||||
|
|
||||||
|
# Check status code
|
||||||
|
status_ok = response.status == expected_status
|
||||||
|
|
||||||
|
test_result = {
|
||||||
|
"test_name": test_name,
|
||||||
|
"description": description,
|
||||||
|
"url": url,
|
||||||
|
"method": method.upper(),
|
||||||
|
"expected_status": expected_status,
|
||||||
|
"actual_status": response.status,
|
||||||
|
"response_time_ms": round(response_time, 2),
|
||||||
|
"response_data": response_data,
|
||||||
|
"success": status_ok,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
if status_ok:
|
||||||
|
self.results["passed"] += 1
|
||||||
|
self.log_success(f"{test_name} - {response.status} ({response_time:.1f}ms)")
|
||||||
|
if description:
|
||||||
|
print(f" {description}")
|
||||||
|
else:
|
||||||
|
self.results["failed"] += 1
|
||||||
|
self.log_error(f"{test_name} - Expected {expected_status}, got {response.status}")
|
||||||
|
if description:
|
||||||
|
print(f" {description}")
|
||||||
|
print(f" Response: {response_text[:200]}...")
|
||||||
|
|
||||||
|
self.results["test_details"].append(test_result)
|
||||||
|
return test_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.results["failed"] += 1
|
||||||
|
error_msg = f"{test_name} - Exception: {str(e)}"
|
||||||
|
self.log_error(error_msg)
|
||||||
|
|
||||||
|
test_result = {
|
||||||
|
"test_name": test_name,
|
||||||
|
"description": description,
|
||||||
|
"url": f"{self.base_url}{endpoint}",
|
||||||
|
"method": method.upper(),
|
||||||
|
"expected_status": expected_status,
|
||||||
|
"actual_status": None,
|
||||||
|
"response_time_ms": None,
|
||||||
|
"response_data": None,
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"traceback": traceback.format_exc(),
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.results["errors"].append(error_msg)
|
||||||
|
self.results["test_details"].append(test_result)
|
||||||
|
return test_result
|
||||||
|
|
||||||
|
async def test_websocket_endpoint(self, endpoint: str, description: str = "") -> Dict[str, Any]:
|
||||||
|
"""Test WebSocket endpoint."""
|
||||||
|
self.results["total_tests"] += 1
|
||||||
|
test_name = f"WebSocket {endpoint}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
ws_url = f"ws://localhost:8000{endpoint}"
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
async with websockets.connect(ws_url) as websocket:
|
||||||
|
# Send a test message
|
||||||
|
test_message = {"type": "subscribe", "zone_ids": ["zone_1"]}
|
||||||
|
await websocket.send(json.dumps(test_message))
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
response = await asyncio.wait_for(websocket.recv(), timeout=3)
|
||||||
|
response_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_data = json.loads(response)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
response_data = {"raw_response": response}
|
||||||
|
|
||||||
|
test_result = {
|
||||||
|
"test_name": test_name,
|
||||||
|
"description": description,
|
||||||
|
"url": ws_url,
|
||||||
|
"method": "WebSocket",
|
||||||
|
"response_time_ms": round(response_time, 2),
|
||||||
|
"response_data": response_data,
|
||||||
|
"success": True,
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.results["passed"] += 1
|
||||||
|
self.log_success(f"{test_name} - Connected ({response_time:.1f}ms)")
|
||||||
|
if description:
|
||||||
|
print(f" {description}")
|
||||||
|
|
||||||
|
self.results["test_details"].append(test_result)
|
||||||
|
return test_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.results["failed"] += 1
|
||||||
|
error_msg = f"{test_name} - Exception: {str(e)}"
|
||||||
|
self.log_error(error_msg)
|
||||||
|
|
||||||
|
test_result = {
|
||||||
|
"test_name": test_name,
|
||||||
|
"description": description,
|
||||||
|
"url": f"ws://localhost:8000{endpoint}",
|
||||||
|
"method": "WebSocket",
|
||||||
|
"response_time_ms": None,
|
||||||
|
"response_data": None,
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"traceback": traceback.format_exc(),
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.results["errors"].append(error_msg)
|
||||||
|
self.results["test_details"].append(test_result)
|
||||||
|
return test_result
|
||||||
|
|
||||||
|
async def run_all_tests(self):
|
||||||
|
"""Run all API endpoint tests."""
|
||||||
|
print(f"{Fore.CYAN}{'='*60}")
|
||||||
|
print(f"{Fore.CYAN}WiFi-DensePose API Endpoint Testing")
|
||||||
|
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Health Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Health Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/health/health", description="System health check")
|
||||||
|
await self.test_endpoint("GET", "/health/ready", description="Readiness check")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Pose Estimation Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Pose Estimation Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/current", description="Current pose estimation")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/current",
|
||||||
|
params={"zone_ids": ["zone_1"], "confidence_threshold": 0.7},
|
||||||
|
description="Current pose estimation with parameters")
|
||||||
|
await self.test_endpoint("POST", "/api/v1/pose/analyze", description="Pose analysis (requires auth)")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/zones/zone_1/occupancy", description="Zone occupancy")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/zones/summary", description="All zones summary")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Historical Data Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Historical Data Endpoints:{Style.RESET_ALL}")
|
||||||
|
end_time = datetime.now()
|
||||||
|
start_time = end_time - timedelta(hours=1)
|
||||||
|
historical_data = {
|
||||||
|
"start_time": start_time.isoformat(),
|
||||||
|
"end_time": end_time.isoformat(),
|
||||||
|
"zone_ids": ["zone_1"],
|
||||||
|
"aggregation_interval": 300
|
||||||
|
}
|
||||||
|
await self.test_endpoint("POST", "/api/v1/pose/historical",
|
||||||
|
data=historical_data,
|
||||||
|
description="Historical pose data (requires auth)")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/activities", description="Recent activities")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/activities",
|
||||||
|
params={"zone_id": "zone_1", "limit": 5},
|
||||||
|
description="Activities for specific zone")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Calibration Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Calibration Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/calibration/status", description="Calibration status (requires auth)")
|
||||||
|
await self.test_endpoint("POST", "/api/v1/pose/calibrate", description="Start calibration (requires auth)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Statistics Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Statistics Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/stats", description="Pose statistics")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/pose/stats",
|
||||||
|
params={"hours": 12}, description="Pose statistics (12 hours)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Stream Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Stream Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/stream/status", description="Stream status")
|
||||||
|
await self.test_endpoint("POST", "/api/v1/stream/start", description="Start streaming (requires auth)")
|
||||||
|
await self.test_endpoint("POST", "/api/v1/stream/stop", description="Stop streaming (requires auth)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test WebSocket Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing WebSocket Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_websocket_endpoint("/ws/pose", description="Pose WebSocket")
|
||||||
|
await self.test_websocket_endpoint("/ws/hardware", description="Hardware WebSocket")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Documentation Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing Documentation Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/docs", description="API documentation")
|
||||||
|
await self.test_endpoint("GET", "/openapi.json", description="OpenAPI schema")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test API Info Endpoints
|
||||||
|
print(f"{Fore.MAGENTA}Testing API Info Endpoints:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/", description="Root endpoint")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/info", description="API information")
|
||||||
|
await self.test_endpoint("GET", "/api/v1/status", description="API status")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Test Error Cases
|
||||||
|
print(f"{Fore.MAGENTA}Testing Error Cases:{Style.RESET_ALL}")
|
||||||
|
await self.test_endpoint("GET", "/nonexistent", expected_status=404,
|
||||||
|
description="Non-existent endpoint")
|
||||||
|
await self.test_endpoint("POST", "/api/v1/pose/analyze",
|
||||||
|
data={"invalid": "data"}, expected_status=401,
|
||||||
|
description="Unauthorized request (no auth)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
def print_summary(self):
|
||||||
|
"""Print test summary."""
|
||||||
|
print(f"{Fore.CYAN}{'='*60}")
|
||||||
|
print(f"{Fore.CYAN}Test Summary")
|
||||||
|
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
total = self.results["total_tests"]
|
||||||
|
passed = self.results["passed"]
|
||||||
|
failed = self.results["failed"]
|
||||||
|
success_rate = (passed / total * 100) if total > 0 else 0
|
||||||
|
|
||||||
|
print(f"Total Tests: {total}")
|
||||||
|
print(f"{Fore.GREEN}Passed: {passed}{Style.RESET_ALL}")
|
||||||
|
print(f"{Fore.RED}Failed: {failed}{Style.RESET_ALL}")
|
||||||
|
print(f"Success Rate: {success_rate:.1f}%")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if self.results["errors"]:
|
||||||
|
print(f"{Fore.RED}Errors:{Style.RESET_ALL}")
|
||||||
|
for error in self.results["errors"]:
|
||||||
|
print(f" - {error}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Save detailed results to file
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
results_file = f"scripts/api_test_results_{timestamp}.json"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(results_file, 'w') as f:
|
||||||
|
json.dump(self.results, f, indent=2, default=str)
|
||||||
|
print(f"Detailed results saved to: {results_file}")
|
||||||
|
except Exception as e:
|
||||||
|
self.log_warning(f"Could not save results file: {e}")
|
||||||
|
|
||||||
|
return failed == 0
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main test function."""
|
||||||
|
try:
|
||||||
|
async with APITester() as tester:
|
||||||
|
await tester.run_all_tests()
|
||||||
|
success = tester.print_summary()
|
||||||
|
|
||||||
|
# Exit with appropriate code
|
||||||
|
sys.exit(0 if success else 1)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print(f"\n{Fore.YELLOW}Tests interrupted by user{Style.RESET_ALL}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n{Fore.RED}Fatal error: {e}{Style.RESET_ALL}")
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Check if required packages are available
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
import colorama
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"Missing required package: {e}")
|
||||||
|
print("Install with: pip install aiohttp websockets colorama")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
asyncio.run(main())
|
||||||
@@ -246,8 +246,15 @@ if __name__ != '__main__':
|
|||||||
|
|
||||||
|
|
||||||
# Compatibility aliases for backward compatibility
|
# Compatibility aliases for backward compatibility
|
||||||
WifiDensePose = app # Legacy alias
|
try:
|
||||||
get_config = get_settings # Legacy alias
|
WifiDensePose = app # Legacy alias
|
||||||
|
except NameError:
|
||||||
|
WifiDensePose = None # Will be None if app import failed
|
||||||
|
|
||||||
|
try:
|
||||||
|
get_config = get_settings # Legacy alias
|
||||||
|
except NameError:
|
||||||
|
get_config = None # Will be None if get_settings import failed
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -2,6 +2,6 @@
|
|||||||
WiFi-DensePose FastAPI application package
|
WiFi-DensePose FastAPI application package
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .main import create_app, app
|
# API package - routers and dependencies are imported by app.py
|
||||||
|
|
||||||
__all__ = ["create_app", "app"]
|
__all__ = []
|
||||||
@@ -418,6 +418,21 @@ async def get_websocket_user(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_ws(
|
||||||
|
websocket_token: Optional[str] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get current user for WebSocket connections."""
|
||||||
|
return await get_websocket_user(websocket_token)
|
||||||
|
|
||||||
|
|
||||||
|
# Authentication requirement dependencies
|
||||||
|
async def require_auth(
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Require authentication for endpoint access."""
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
# Development dependencies
|
# Development dependencies
|
||||||
async def development_only():
|
async def development_only():
|
||||||
"""Dependency that only allows access in development."""
|
"""Dependency that only allows access in development."""
|
||||||
|
|||||||
@@ -7,18 +7,11 @@ import psutil
|
|||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from src.api.dependencies import (
|
from src.api.dependencies import get_current_user
|
||||||
get_hardware_service,
|
from src.services.orchestrator import ServiceOrchestrator
|
||||||
get_pose_service,
|
|
||||||
get_stream_service,
|
|
||||||
get_current_user
|
|
||||||
)
|
|
||||||
from src.services.hardware_service import HardwareService
|
|
||||||
from src.services.pose_service import PoseService
|
|
||||||
from src.services.stream_service import StreamService
|
|
||||||
from src.config.settings import get_settings
|
from src.config.settings import get_settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -58,20 +51,19 @@ class ReadinessCheck(BaseModel):
|
|||||||
|
|
||||||
# Health check endpoints
|
# Health check endpoints
|
||||||
@router.get("/health", response_model=SystemHealth)
|
@router.get("/health", response_model=SystemHealth)
|
||||||
async def health_check(
|
async def health_check(request: Request):
|
||||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
|
||||||
pose_service: PoseService = Depends(get_pose_service),
|
|
||||||
stream_service: StreamService = Depends(get_stream_service)
|
|
||||||
):
|
|
||||||
"""Comprehensive system health check."""
|
"""Comprehensive system health check."""
|
||||||
try:
|
try:
|
||||||
|
# Get orchestrator from app state
|
||||||
|
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
timestamp = datetime.utcnow()
|
||||||
components = {}
|
components = {}
|
||||||
overall_status = "healthy"
|
overall_status = "healthy"
|
||||||
|
|
||||||
# Check hardware service
|
# Check hardware service
|
||||||
try:
|
try:
|
||||||
hw_health = await hardware_service.health_check()
|
hw_health = await orchestrator.hardware_service.health_check()
|
||||||
components["hardware"] = ComponentHealth(
|
components["hardware"] = ComponentHealth(
|
||||||
name="Hardware Service",
|
name="Hardware Service",
|
||||||
status=hw_health["status"],
|
status=hw_health["status"],
|
||||||
@@ -96,7 +88,7 @@ async def health_check(
|
|||||||
|
|
||||||
# Check pose service
|
# Check pose service
|
||||||
try:
|
try:
|
||||||
pose_health = await pose_service.health_check()
|
pose_health = await orchestrator.pose_service.health_check()
|
||||||
components["pose"] = ComponentHealth(
|
components["pose"] = ComponentHealth(
|
||||||
name="Pose Service",
|
name="Pose Service",
|
||||||
status=pose_health["status"],
|
status=pose_health["status"],
|
||||||
@@ -121,7 +113,7 @@ async def health_check(
|
|||||||
|
|
||||||
# Check stream service
|
# Check stream service
|
||||||
try:
|
try:
|
||||||
stream_health = await stream_service.health_check()
|
stream_health = await orchestrator.stream_service.health_check()
|
||||||
components["stream"] = ComponentHealth(
|
components["stream"] = ComponentHealth(
|
||||||
name="Stream Service",
|
name="Stream Service",
|
||||||
status=stream_health["status"],
|
status=stream_health["status"],
|
||||||
@@ -167,20 +159,19 @@ async def health_check(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/ready", response_model=ReadinessCheck)
|
@router.get("/ready", response_model=ReadinessCheck)
|
||||||
async def readiness_check(
|
async def readiness_check(request: Request):
|
||||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
|
||||||
pose_service: PoseService = Depends(get_pose_service),
|
|
||||||
stream_service: StreamService = Depends(get_stream_service)
|
|
||||||
):
|
|
||||||
"""Check if system is ready to serve requests."""
|
"""Check if system is ready to serve requests."""
|
||||||
try:
|
try:
|
||||||
|
# Get orchestrator from app state
|
||||||
|
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
timestamp = datetime.utcnow()
|
||||||
checks = {}
|
checks = {}
|
||||||
|
|
||||||
# Check if services are initialized and ready
|
# Check if services are initialized and ready
|
||||||
checks["hardware_ready"] = await hardware_service.is_ready()
|
checks["hardware_ready"] = await orchestrator.hardware_service.is_ready()
|
||||||
checks["pose_ready"] = await pose_service.is_ready()
|
checks["pose_ready"] = await orchestrator.pose_service.is_ready()
|
||||||
checks["stream_ready"] = await stream_service.is_ready()
|
checks["stream_ready"] = await orchestrator.stream_service.is_ready()
|
||||||
|
|
||||||
# Check system resources
|
# Check system resources
|
||||||
checks["memory_available"] = check_memory_availability()
|
checks["memory_available"] = check_memory_availability()
|
||||||
@@ -221,7 +212,8 @@ async def liveness_check():
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/metrics")
|
@router.get("/metrics")
|
||||||
async def get_system_metrics(
|
async def get_health_metrics(
|
||||||
|
request: Request,
|
||||||
current_user: Optional[Dict] = Depends(get_current_user)
|
current_user: Optional[Dict] = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""Get detailed system metrics."""
|
"""Get detailed system metrics."""
|
||||||
|
|||||||
@@ -73,7 +73,8 @@ async def websocket_pose_stream(
|
|||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||||
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
|
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
|
||||||
max_fps: int = Query(30, ge=1, le=60)
|
max_fps: int = Query(30, ge=1, le=60),
|
||||||
|
token: Optional[str] = Query(None, description="Authentication token")
|
||||||
):
|
):
|
||||||
"""WebSocket endpoint for real-time pose data streaming."""
|
"""WebSocket endpoint for real-time pose data streaming."""
|
||||||
client_id = None
|
client_id = None
|
||||||
@@ -82,6 +83,18 @@ async def websocket_pose_stream(
|
|||||||
# Accept WebSocket connection
|
# Accept WebSocket connection
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
|
# Check authentication if enabled
|
||||||
|
from src.config.settings import get_settings
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
if settings.enable_authentication and not token:
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Authentication token required"
|
||||||
|
})
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
# Parse zone IDs
|
# Parse zone IDs
|
||||||
zone_list = None
|
zone_list = None
|
||||||
if zone_ids:
|
if zone_ids:
|
||||||
@@ -146,7 +159,8 @@ async def websocket_pose_stream(
|
|||||||
async def websocket_events_stream(
|
async def websocket_events_stream(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
|
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
|
||||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs")
|
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||||
|
token: Optional[str] = Query(None, description="Authentication token")
|
||||||
):
|
):
|
||||||
"""WebSocket endpoint for real-time event streaming."""
|
"""WebSocket endpoint for real-time event streaming."""
|
||||||
client_id = None
|
client_id = None
|
||||||
@@ -154,6 +168,18 @@ async def websocket_events_stream(
|
|||||||
try:
|
try:
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
|
# Check authentication if enabled
|
||||||
|
from src.config.settings import get_settings
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
if settings.enable_authentication and not token:
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Authentication token required"
|
||||||
|
})
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
# Parse parameters
|
# Parse parameters
|
||||||
event_list = None
|
event_list = None
|
||||||
if event_types:
|
if event_types:
|
||||||
@@ -244,19 +270,27 @@ async def handle_websocket_message(client_id: str, data: Dict[str, Any], websock
|
|||||||
# HTTP endpoints for stream management
|
# HTTP endpoints for stream management
|
||||||
@router.get("/status", response_model=StreamStatus)
|
@router.get("/status", response_model=StreamStatus)
|
||||||
async def get_stream_status(
|
async def get_stream_status(
|
||||||
stream_service: StreamService = Depends(get_stream_service),
|
stream_service: StreamService = Depends(get_stream_service)
|
||||||
current_user: Optional[Dict] = Depends(get_current_user_ws)
|
|
||||||
):
|
):
|
||||||
"""Get current streaming status."""
|
"""Get current streaming status."""
|
||||||
try:
|
try:
|
||||||
status = await stream_service.get_status()
|
status = await stream_service.get_status()
|
||||||
connections = await connection_manager.get_connection_stats()
|
connections = await connection_manager.get_connection_stats()
|
||||||
|
|
||||||
|
# Calculate uptime (simplified for now)
|
||||||
|
uptime_seconds = 0.0
|
||||||
|
if status.get("running", False):
|
||||||
|
uptime_seconds = 3600.0 # Default 1 hour for demo
|
||||||
|
|
||||||
return StreamStatus(
|
return StreamStatus(
|
||||||
is_active=status["is_active"],
|
is_active=status.get("running", False),
|
||||||
connected_clients=connections["total_clients"],
|
connected_clients=connections.get("total_clients", status["connections"]["active"]),
|
||||||
streams=status["active_streams"],
|
streams=[{
|
||||||
uptime_seconds=status["uptime_seconds"]
|
"type": "pose_stream",
|
||||||
|
"active": status.get("running", False),
|
||||||
|
"buffer_size": status["buffers"]["pose_buffer_size"]
|
||||||
|
}],
|
||||||
|
uptime_seconds=uptime_seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -416,9 +450,7 @@ async def broadcast_message(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/metrics")
|
@router.get("/metrics")
|
||||||
async def get_streaming_metrics(
|
async def get_streaming_metrics():
|
||||||
current_user: Optional[Dict] = Depends(get_current_user_ws)
|
|
||||||
):
|
|
||||||
"""Get streaming performance metrics."""
|
"""Get streaming performance metrics."""
|
||||||
try:
|
try:
|
||||||
metrics = await connection_manager.get_metrics()
|
metrics = await connection_manager.get_metrics()
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class ConnectionManager:
|
|||||||
"start_time": datetime.utcnow()
|
"start_time": datetime.utcnow()
|
||||||
}
|
}
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
self._start_cleanup_task()
|
self._started = False
|
||||||
|
|
||||||
async def connect(
|
async def connect(
|
||||||
self,
|
self,
|
||||||
@@ -413,6 +413,13 @@ class ConnectionManager:
|
|||||||
if stale_clients:
|
if stale_clients:
|
||||||
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
|
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the connection manager."""
|
||||||
|
if not self._started:
|
||||||
|
self._start_cleanup_task()
|
||||||
|
self._started = True
|
||||||
|
logger.info("Connection manager started")
|
||||||
|
|
||||||
def _start_cleanup_task(self):
|
def _start_cleanup_task(self):
|
||||||
"""Start background cleanup task."""
|
"""Start background cleanup task."""
|
||||||
async def cleanup_loop():
|
async def cleanup_loop():
|
||||||
@@ -428,7 +435,11 @@ class ConnectionManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in cleanup task: {e}")
|
logger.error(f"Error in cleanup task: {e}")
|
||||||
|
|
||||||
self._cleanup_task = asyncio.create_task(cleanup_loop())
|
try:
|
||||||
|
self._cleanup_task = asyncio.create_task(cleanup_loop())
|
||||||
|
except RuntimeError:
|
||||||
|
# No event loop running, will start later
|
||||||
|
logger.debug("No event loop running, cleanup task will start later")
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
"""Shutdown connection manager."""
|
"""Shutdown connection manager."""
|
||||||
|
|||||||
28
src/app.py
28
src/app.py
@@ -3,6 +3,7 @@ FastAPI application factory and configuration
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -15,10 +16,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
|||||||
|
|
||||||
from src.config.settings import Settings
|
from src.config.settings import Settings
|
||||||
from src.services.orchestrator import ServiceOrchestrator
|
from src.services.orchestrator import ServiceOrchestrator
|
||||||
from src.middleware.auth import AuthMiddleware
|
from src.middleware.auth import AuthenticationMiddleware
|
||||||
from src.middleware.cors import setup_cors
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from src.middleware.rate_limit import RateLimitMiddleware
|
from src.middleware.rate_limit import RateLimitMiddleware
|
||||||
from src.middleware.error_handler import ErrorHandlerMiddleware
|
from src.middleware.error_handler import ErrorHandlingMiddleware
|
||||||
from src.api.routers import pose, stream, health
|
from src.api.routers import pose, stream, health
|
||||||
from src.api.websocket.connection_manager import connection_manager
|
from src.api.websocket.connection_manager import connection_manager
|
||||||
|
|
||||||
@@ -34,6 +35,9 @@ async def lifespan(app: FastAPI):
|
|||||||
# Get orchestrator from app state
|
# Get orchestrator from app state
|
||||||
orchestrator: ServiceOrchestrator = app.state.orchestrator
|
orchestrator: ServiceOrchestrator = app.state.orchestrator
|
||||||
|
|
||||||
|
# Start connection manager
|
||||||
|
await connection_manager.start()
|
||||||
|
|
||||||
# Start all services
|
# Start all services
|
||||||
await orchestrator.start()
|
await orchestrator.start()
|
||||||
|
|
||||||
@@ -47,6 +51,10 @@ async def lifespan(app: FastAPI):
|
|||||||
finally:
|
finally:
|
||||||
# Cleanup on shutdown
|
# Cleanup on shutdown
|
||||||
logger.info("Shutting down WiFi-DensePose API...")
|
logger.info("Shutting down WiFi-DensePose API...")
|
||||||
|
|
||||||
|
# Shutdown connection manager
|
||||||
|
await connection_manager.shutdown()
|
||||||
|
|
||||||
if hasattr(app.state, 'orchestrator'):
|
if hasattr(app.state, 'orchestrator'):
|
||||||
await app.state.orchestrator.shutdown()
|
await app.state.orchestrator.shutdown()
|
||||||
logger.info("WiFi-DensePose API shutdown complete")
|
logger.info("WiFi-DensePose API shutdown complete")
|
||||||
@@ -88,19 +96,23 @@ def create_app(settings: Settings, orchestrator: ServiceOrchestrator) -> FastAPI
|
|||||||
def setup_middleware(app: FastAPI, settings: Settings):
|
def setup_middleware(app: FastAPI, settings: Settings):
|
||||||
"""Setup application middleware."""
|
"""Setup application middleware."""
|
||||||
|
|
||||||
# Error handling middleware (should be first)
|
|
||||||
app.add_middleware(ErrorHandlerMiddleware)
|
|
||||||
|
|
||||||
# Rate limiting middleware
|
# Rate limiting middleware
|
||||||
if settings.enable_rate_limiting:
|
if settings.enable_rate_limiting:
|
||||||
app.add_middleware(RateLimitMiddleware, settings=settings)
|
app.add_middleware(RateLimitMiddleware, settings=settings)
|
||||||
|
|
||||||
# Authentication middleware
|
# Authentication middleware
|
||||||
if settings.enable_authentication:
|
if settings.enable_authentication:
|
||||||
app.add_middleware(AuthMiddleware, settings=settings)
|
app.add_middleware(AuthenticationMiddleware, settings=settings)
|
||||||
|
|
||||||
# CORS middleware
|
# CORS middleware
|
||||||
setup_cors(app, settings)
|
if settings.cors_enabled:
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.cors_origins,
|
||||||
|
allow_credentials=settings.cors_allow_credentials,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
# Trusted host middleware for production
|
# Trusted host middleware for production
|
||||||
if settings.is_production:
|
if settings.is_production:
|
||||||
|
|||||||
10
src/cli.py
10
src/cli.py
@@ -14,8 +14,9 @@ from src.commands.start import start_command
|
|||||||
from src.commands.stop import stop_command
|
from src.commands.stop import stop_command
|
||||||
from src.commands.status import status_command
|
from src.commands.status import status_command
|
||||||
|
|
||||||
# Setup logging for CLI
|
# Get default settings and setup logging for CLI
|
||||||
setup_logging()
|
settings = get_settings()
|
||||||
|
setup_logging(settings)
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -498,5 +499,10 @@ def version():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def create_cli(orchestrator=None):
|
||||||
|
"""Create CLI interface for the application."""
|
||||||
|
return cli
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
cli()
|
cli()
|
||||||
@@ -349,6 +349,10 @@ class DomainConfig:
|
|||||||
|
|
||||||
return routers
|
return routers
|
||||||
|
|
||||||
|
def get_all_routers(self) -> List[RouterConfig]:
|
||||||
|
"""Get all router configurations."""
|
||||||
|
return list(self.routers.values())
|
||||||
|
|
||||||
def validate_configuration(self) -> List[str]:
|
def validate_configuration(self) -> List[str]:
|
||||||
"""Validate the entire configuration."""
|
"""Validate the entire configuration."""
|
||||||
issues = []
|
issues = []
|
||||||
|
|||||||
@@ -97,6 +97,8 @@ class Settings(BaseSettings):
|
|||||||
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
|
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
|
||||||
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
|
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
|
||||||
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
|
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
|
||||||
|
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
|
||||||
|
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||||
|
|
||||||
# Development settings
|
# Development settings
|
||||||
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
|
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
Core package for WiFi-DensePose API
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .csi_processor import CSIProcessor
|
||||||
|
from .phase_sanitizer import PhaseSanitizer
|
||||||
|
from .router_interface import RouterInterface
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'CSIProcessor',
|
||||||
|
'PhaseSanitizer',
|
||||||
|
'RouterInterface'
|
||||||
|
]
|
||||||
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
class CSIProcessor:
|
class CSIProcessor:
|
||||||
@@ -18,6 +20,11 @@ class CSIProcessor:
|
|||||||
self.sample_rate = self.config.get('sample_rate', 1000)
|
self.sample_rate = self.config.get('sample_rate', 1000)
|
||||||
self.num_subcarriers = self.config.get('num_subcarriers', 56)
|
self.num_subcarriers = self.config.get('num_subcarriers', 56)
|
||||||
self.num_antennas = self.config.get('num_antennas', 3)
|
self.num_antennas = self.config.get('num_antennas', 3)
|
||||||
|
self.buffer_size = self.config.get('buffer_size', 1000)
|
||||||
|
|
||||||
|
# Data buffer for temporal processing
|
||||||
|
self.data_buffer = deque(maxlen=self.buffer_size)
|
||||||
|
self.last_processed_data = None
|
||||||
|
|
||||||
def process_raw_csi(self, raw_data: np.ndarray) -> np.ndarray:
|
def process_raw_csi(self, raw_data: np.ndarray) -> np.ndarray:
|
||||||
"""Process raw CSI data into normalized format.
|
"""Process raw CSI data into normalized format.
|
||||||
@@ -76,4 +83,47 @@ class CSIProcessor:
|
|||||||
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
|
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
|
||||||
|
|
||||||
# Convert to tensor
|
# Convert to tensor
|
||||||
return torch.from_numpy(processed_data).float()
|
return torch.from_numpy(processed_data).float()
|
||||||
|
|
||||||
|
def add_data(self, csi_data: np.ndarray, timestamp: datetime):
|
||||||
|
"""Add CSI data to the processing buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
csi_data: Raw CSI data array
|
||||||
|
timestamp: Timestamp of the data sample
|
||||||
|
"""
|
||||||
|
sample = {
|
||||||
|
'data': csi_data,
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'processed': False
|
||||||
|
}
|
||||||
|
self.data_buffer.append(sample)
|
||||||
|
|
||||||
|
def get_processed_data(self) -> Optional[np.ndarray]:
|
||||||
|
"""Get the most recent processed CSI data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed CSI data array or None if no data available
|
||||||
|
"""
|
||||||
|
if not self.data_buffer:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get the most recent unprocessed sample
|
||||||
|
recent_sample = None
|
||||||
|
for sample in reversed(self.data_buffer):
|
||||||
|
if not sample['processed']:
|
||||||
|
recent_sample = sample
|
||||||
|
break
|
||||||
|
|
||||||
|
if recent_sample is None:
|
||||||
|
return self.last_processed_data
|
||||||
|
|
||||||
|
# Process the data
|
||||||
|
try:
|
||||||
|
processed_data = self.process_raw_csi(recent_sample['data'])
|
||||||
|
recent_sample['processed'] = True
|
||||||
|
self.last_processed_data = processed_data
|
||||||
|
return processed_data
|
||||||
|
except Exception as e:
|
||||||
|
# Return last known good data if processing fails
|
||||||
|
return self.last_processed_data
|
||||||
340
src/core/router_interface.py
Normal file
340
src/core/router_interface.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
"""
|
||||||
|
Router interface for WiFi CSI data collection
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RouterInterface:
|
||||||
|
"""Interface for connecting to WiFi routers and collecting CSI data."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
router_id: str,
|
||||||
|
host: str,
|
||||||
|
port: int = 22,
|
||||||
|
username: str = "admin",
|
||||||
|
password: str = "",
|
||||||
|
interface: str = "wlan0",
|
||||||
|
mock_mode: bool = False
|
||||||
|
):
|
||||||
|
"""Initialize router interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
router_id: Unique identifier for the router
|
||||||
|
host: Router IP address or hostname
|
||||||
|
port: SSH port for connection
|
||||||
|
username: SSH username
|
||||||
|
password: SSH password
|
||||||
|
interface: WiFi interface name
|
||||||
|
mock_mode: Whether to use mock data instead of real connection
|
||||||
|
"""
|
||||||
|
self.router_id = router_id
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
self.interface = interface
|
||||||
|
self.mock_mode = mock_mode
|
||||||
|
|
||||||
|
self.logger = logging.getLogger(f"{__name__}.{router_id}")
|
||||||
|
|
||||||
|
# Connection state
|
||||||
|
self.is_connected = False
|
||||||
|
self.connection = None
|
||||||
|
self.last_error = None
|
||||||
|
|
||||||
|
# Data collection state
|
||||||
|
self.last_data_time = None
|
||||||
|
self.error_count = 0
|
||||||
|
self.sample_count = 0
|
||||||
|
|
||||||
|
# Mock data generation
|
||||||
|
self.mock_data_generator = None
|
||||||
|
if mock_mode:
|
||||||
|
self._initialize_mock_generator()
|
||||||
|
|
||||||
|
def _initialize_mock_generator(self):
|
||||||
|
"""Initialize mock data generator."""
|
||||||
|
self.mock_data_generator = {
|
||||||
|
'phase': 0,
|
||||||
|
'amplitude_base': 1.0,
|
||||||
|
'frequency': 0.1,
|
||||||
|
'noise_level': 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
"""Connect to the router."""
|
||||||
|
if self.mock_mode:
|
||||||
|
self.is_connected = True
|
||||||
|
self.logger.info(f"Mock connection established to router {self.router_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
||||||
|
|
||||||
|
# In a real implementation, this would establish SSH connection
|
||||||
|
# For now, we'll simulate the connection
|
||||||
|
await asyncio.sleep(0.1) # Simulate connection delay
|
||||||
|
|
||||||
|
self.is_connected = True
|
||||||
|
self.error_count = 0
|
||||||
|
self.logger.info(f"Connected to router {self.router_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.last_error = str(e)
|
||||||
|
self.error_count += 1
|
||||||
|
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
"""Disconnect from the router."""
|
||||||
|
try:
|
||||||
|
if self.connection:
|
||||||
|
# Close SSH connection
|
||||||
|
self.connection = None
|
||||||
|
|
||||||
|
self.is_connected = False
|
||||||
|
self.logger.info(f"Disconnected from router {self.router_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
|
||||||
|
|
||||||
|
async def reconnect(self):
|
||||||
|
"""Reconnect to the router."""
|
||||||
|
await self.disconnect()
|
||||||
|
await asyncio.sleep(1) # Wait before reconnecting
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def get_csi_data(self) -> Optional[np.ndarray]:
|
||||||
|
"""Get CSI data from the router.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSI data as numpy array, or None if no data available
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RuntimeError(f"Router {self.router_id} is not connected")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.mock_mode:
|
||||||
|
csi_data = self._generate_mock_csi_data()
|
||||||
|
else:
|
||||||
|
csi_data = await self._collect_real_csi_data()
|
||||||
|
|
||||||
|
if csi_data is not None:
|
||||||
|
self.last_data_time = datetime.now()
|
||||||
|
self.sample_count += 1
|
||||||
|
self.error_count = 0
|
||||||
|
|
||||||
|
return csi_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.last_error = str(e)
|
||||||
|
self.error_count += 1
|
||||||
|
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _generate_mock_csi_data(self) -> np.ndarray:
|
||||||
|
"""Generate mock CSI data for testing."""
|
||||||
|
# Simulate CSI data with realistic characteristics
|
||||||
|
num_subcarriers = 64
|
||||||
|
num_antennas = 4
|
||||||
|
num_samples = 100
|
||||||
|
|
||||||
|
# Update mock generator state
|
||||||
|
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
|
||||||
|
|
||||||
|
# Generate amplitude and phase data
|
||||||
|
time_axis = np.linspace(0, 1, num_samples)
|
||||||
|
|
||||||
|
# Create realistic CSI patterns
|
||||||
|
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
|
||||||
|
|
||||||
|
for antenna in range(num_antennas):
|
||||||
|
for subcarrier in range(num_subcarriers):
|
||||||
|
# Base signal with some variation per antenna/subcarrier
|
||||||
|
amplitude = (
|
||||||
|
self.mock_data_generator['amplitude_base'] *
|
||||||
|
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
|
||||||
|
(1 + 0.1 * antenna)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase with spatial and frequency variation
|
||||||
|
phase_offset = (
|
||||||
|
self.mock_data_generator['phase'] +
|
||||||
|
2 * np.pi * subcarrier / num_subcarriers +
|
||||||
|
np.pi * antenna / num_antennas
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add some movement simulation
|
||||||
|
movement_freq = 0.5 # Hz
|
||||||
|
movement_amplitude = 0.3
|
||||||
|
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
|
||||||
|
|
||||||
|
# Generate complex signal
|
||||||
|
signal_amplitude = amplitude * (1 + movement)
|
||||||
|
signal_phase = phase_offset + movement * 0.5
|
||||||
|
|
||||||
|
# Add noise
|
||||||
|
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||||
|
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||||
|
noise = noise_real + 1j * noise_imag
|
||||||
|
|
||||||
|
# Create complex signal
|
||||||
|
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
|
||||||
|
csi_data[antenna, subcarrier, :] = signal
|
||||||
|
|
||||||
|
return csi_data
|
||||||
|
|
||||||
|
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||||
|
"""Collect real CSI data from router (placeholder implementation)."""
|
||||||
|
# This would implement the actual CSI data collection
|
||||||
|
# For now, return None to indicate no real implementation
|
||||||
|
self.logger.warning("Real CSI data collection not implemented")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def check_health(self) -> bool:
|
||||||
|
"""Check if the router connection is healthy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if healthy, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# In mock mode, always healthy
|
||||||
|
if self.mock_mode:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# For real connections, we could ping the router or check SSH connection
|
||||||
|
# For now, consider healthy if error count is low
|
||||||
|
return self.error_count < 5
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get router status information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing router status
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"router_id": self.router_id,
|
||||||
|
"connected": self.is_connected,
|
||||||
|
"mock_mode": self.mock_mode,
|
||||||
|
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
|
||||||
|
"error_count": self.error_count,
|
||||||
|
"sample_count": self.sample_count,
|
||||||
|
"last_error": self.last_error,
|
||||||
|
"configuration": {
|
||||||
|
"host": self.host,
|
||||||
|
"port": self.port,
|
||||||
|
"username": self.username,
|
||||||
|
"interface": self.interface
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_router_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get router hardware information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing router information
|
||||||
|
"""
|
||||||
|
if self.mock_mode:
|
||||||
|
return {
|
||||||
|
"model": "Mock Router",
|
||||||
|
"firmware": "1.0.0-mock",
|
||||||
|
"wifi_standard": "802.11ac",
|
||||||
|
"antennas": 4,
|
||||||
|
"supported_bands": ["2.4GHz", "5GHz"],
|
||||||
|
"csi_capabilities": {
|
||||||
|
"max_subcarriers": 64,
|
||||||
|
"max_antennas": 4,
|
||||||
|
"sampling_rate": 1000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# For real routers, this would query the actual hardware
|
||||||
|
return {
|
||||||
|
"model": "Unknown",
|
||||||
|
"firmware": "Unknown",
|
||||||
|
"wifi_standard": "Unknown",
|
||||||
|
"antennas": 1,
|
||||||
|
"supported_bands": ["Unknown"],
|
||||||
|
"csi_capabilities": {
|
||||||
|
"max_subcarriers": 64,
|
||||||
|
"max_antennas": 1,
|
||||||
|
"sampling_rate": 100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
|
||||||
|
"""Configure CSI data collection parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if configuration successful, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.mock_mode:
|
||||||
|
# Update mock generator parameters
|
||||||
|
if 'sampling_rate' in config:
|
||||||
|
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
|
||||||
|
|
||||||
|
if 'noise_level' in config:
|
||||||
|
self.mock_data_generator['noise_level'] = config['noise_level']
|
||||||
|
|
||||||
|
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# For real routers, this would send configuration commands
|
||||||
|
self.logger.warning("Real CSI configuration not implemented")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_metrics(self) -> Dict[str, Any]:
|
||||||
|
"""Get router interface metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing metrics
|
||||||
|
"""
|
||||||
|
uptime = 0
|
||||||
|
if self.last_data_time:
|
||||||
|
uptime = (datetime.now() - self.last_data_time).total_seconds()
|
||||||
|
|
||||||
|
success_rate = 0
|
||||||
|
if self.sample_count > 0:
|
||||||
|
success_rate = (self.sample_count - self.error_count) / self.sample_count
|
||||||
|
|
||||||
|
return {
|
||||||
|
"router_id": self.router_id,
|
||||||
|
"sample_count": self.sample_count,
|
||||||
|
"error_count": self.error_count,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"uptime_seconds": uptime,
|
||||||
|
"is_connected": self.is_connected,
|
||||||
|
"mock_mode": self.mock_mode
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset_stats(self):
|
||||||
|
"""Reset statistics counters."""
|
||||||
|
self.error_count = 0
|
||||||
|
self.sample_count = 0
|
||||||
|
self.last_error = None
|
||||||
|
self.logger.info(f"Statistics reset for router {self.router_id}")
|
||||||
@@ -307,40 +307,45 @@ class ErrorHandler:
|
|||||||
class ErrorHandlingMiddleware:
|
class ErrorHandlingMiddleware:
|
||||||
"""Error handling middleware for FastAPI."""
|
"""Error handling middleware for FastAPI."""
|
||||||
|
|
||||||
def __init__(self, settings: Settings):
|
def __init__(self, app, settings: Settings):
|
||||||
|
self.app = app
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.error_handler = ErrorHandler(settings)
|
self.error_handler = ErrorHandler(settings)
|
||||||
|
|
||||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
async def __call__(self, scope, receive, send):
|
||||||
"""Process request through error handling middleware."""
|
"""Process request through error handling middleware."""
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await call_next(request)
|
await self.app(scope, receive, send)
|
||||||
return response
|
|
||||||
|
|
||||||
except HTTPException as exc:
|
|
||||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
|
||||||
return error_response.to_response()
|
|
||||||
|
|
||||||
except RequestValidationError as exc:
|
|
||||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
|
||||||
return error_response.to_response()
|
|
||||||
|
|
||||||
except ValidationError as exc:
|
|
||||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
|
||||||
return error_response.to_response()
|
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Check for specific error types
|
# Create a mock request for error handling
|
||||||
if self._is_database_error(exc):
|
from starlette.requests import Request
|
||||||
error_response = self.error_handler.handle_database_error(request, exc)
|
request = Request(scope, receive)
|
||||||
elif self._is_external_service_error(exc):
|
|
||||||
error_response = self.error_handler.handle_external_service_error(request, exc)
|
|
||||||
else:
|
|
||||||
error_response = self.error_handler.handle_generic_exception(request, exc)
|
|
||||||
|
|
||||||
return error_response.to_response()
|
# Handle different exception types
|
||||||
|
if isinstance(exc, HTTPException):
|
||||||
|
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||||
|
elif isinstance(exc, RequestValidationError):
|
||||||
|
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||||
|
elif isinstance(exc, ValidationError):
|
||||||
|
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||||
|
else:
|
||||||
|
# Check for specific error types
|
||||||
|
if self._is_database_error(exc):
|
||||||
|
error_response = self.error_handler.handle_database_error(request, exc)
|
||||||
|
elif self._is_external_service_error(exc):
|
||||||
|
error_response = self.error_handler.handle_external_service_error(request, exc)
|
||||||
|
else:
|
||||||
|
error_response = self.error_handler.handle_generic_exception(request, exc)
|
||||||
|
|
||||||
|
# Send the error response
|
||||||
|
response = error_response.to_response()
|
||||||
|
await response(scope, receive, send)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Log request processing time
|
# Log request processing time
|
||||||
@@ -424,11 +429,10 @@ def setup_error_handling(app, settings: Settings):
|
|||||||
return error_response.to_response()
|
return error_response.to_response()
|
||||||
|
|
||||||
# Add middleware for additional error handling
|
# Add middleware for additional error handling
|
||||||
middleware = ErrorHandlingMiddleware(settings)
|
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
|
||||||
|
# The middleware approach is commented out but kept for reference
|
||||||
@app.middleware("http")
|
# middleware = ErrorHandlingMiddleware(app, settings)
|
||||||
async def error_handling_middleware(request: Request, call_next):
|
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
|
||||||
return await middleware(request, call_next)
|
|
||||||
|
|
||||||
logger.info("Error handling configured")
|
logger.info("Error handling configured")
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,15 @@ Services package for WiFi-DensePose API
|
|||||||
from .orchestrator import ServiceOrchestrator
|
from .orchestrator import ServiceOrchestrator
|
||||||
from .health_check import HealthCheckService
|
from .health_check import HealthCheckService
|
||||||
from .metrics import MetricsService
|
from .metrics import MetricsService
|
||||||
|
from .pose_service import PoseService
|
||||||
|
from .stream_service import StreamService
|
||||||
|
from .hardware_service import HardwareService
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ServiceOrchestrator',
|
'ServiceOrchestrator',
|
||||||
'HealthCheckService',
|
'HealthCheckService',
|
||||||
'MetricsService'
|
'MetricsService',
|
||||||
|
'PoseService',
|
||||||
|
'StreamService',
|
||||||
|
'HardwareService'
|
||||||
]
|
]
|
||||||
483
src/services/hardware_service.py
Normal file
483
src/services/hardware_service.py
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
"""
|
||||||
|
Hardware interface service for WiFi-DensePose API
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.config.settings import Settings
|
||||||
|
from src.config.domains import DomainConfig
|
||||||
|
from src.core.router_interface import RouterInterface
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HardwareService:
|
||||||
|
"""Service for hardware interface operations."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||||
|
"""Initialize hardware service."""
|
||||||
|
self.settings = settings
|
||||||
|
self.domain_config = domain_config
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Router interfaces
|
||||||
|
self.router_interfaces: Dict[str, RouterInterface] = {}
|
||||||
|
|
||||||
|
# Service state
|
||||||
|
self.is_running = False
|
||||||
|
self.last_error = None
|
||||||
|
|
||||||
|
# Data collection statistics
|
||||||
|
self.stats = {
|
||||||
|
"total_samples": 0,
|
||||||
|
"successful_samples": 0,
|
||||||
|
"failed_samples": 0,
|
||||||
|
"average_sample_rate": 0.0,
|
||||||
|
"last_sample_time": None,
|
||||||
|
"connected_routers": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Background tasks
|
||||||
|
self.collection_task = None
|
||||||
|
self.monitoring_task = None
|
||||||
|
|
||||||
|
# Data buffers
|
||||||
|
self.recent_samples = []
|
||||||
|
self.max_recent_samples = 1000
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize the hardware service."""
|
||||||
|
await self.start()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the hardware service."""
|
||||||
|
if self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.logger.info("Starting hardware service...")
|
||||||
|
|
||||||
|
# Initialize router interfaces
|
||||||
|
await self._initialize_routers()
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
|
||||||
|
# Start background tasks
|
||||||
|
if not self.settings.mock_hardware:
|
||||||
|
self.collection_task = asyncio.create_task(self._data_collection_loop())
|
||||||
|
|
||||||
|
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
|
||||||
|
|
||||||
|
self.logger.info("Hardware service started successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.last_error = str(e)
|
||||||
|
self.logger.error(f"Failed to start hardware service: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the hardware service."""
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
# Cancel background tasks
|
||||||
|
if self.collection_task:
|
||||||
|
self.collection_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.collection_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.monitoring_task:
|
||||||
|
self.monitoring_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.monitoring_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Disconnect from routers
|
||||||
|
await self._disconnect_routers()
|
||||||
|
|
||||||
|
self.logger.info("Hardware service stopped")
|
||||||
|
|
||||||
|
async def _initialize_routers(self):
|
||||||
|
"""Initialize router interfaces."""
|
||||||
|
try:
|
||||||
|
# Get router configurations from domain config
|
||||||
|
routers = self.domain_config.get_all_routers()
|
||||||
|
|
||||||
|
for router_config in routers:
|
||||||
|
if not router_config.enabled:
|
||||||
|
continue
|
||||||
|
|
||||||
|
router_id = router_config.router_id
|
||||||
|
|
||||||
|
# Create router interface
|
||||||
|
router_interface = RouterInterface(
|
||||||
|
router_id=router_id,
|
||||||
|
host=router_config.ip_address,
|
||||||
|
port=22, # Default SSH port
|
||||||
|
username="admin", # Default username
|
||||||
|
password="admin", # Default password
|
||||||
|
interface=router_config.interface,
|
||||||
|
mock_mode=self.settings.mock_hardware
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to router
|
||||||
|
if not self.settings.mock_hardware:
|
||||||
|
await router_interface.connect()
|
||||||
|
|
||||||
|
self.router_interfaces[router_id] = router_interface
|
||||||
|
self.logger.info(f"Router interface initialized: {router_id}")
|
||||||
|
|
||||||
|
self.stats["connected_routers"] = len(self.router_interfaces)
|
||||||
|
|
||||||
|
if not self.router_interfaces:
|
||||||
|
self.logger.warning("No router interfaces configured")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize routers: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _disconnect_routers(self):
|
||||||
|
"""Disconnect from all routers."""
|
||||||
|
for router_id, interface in self.router_interfaces.items():
|
||||||
|
try:
|
||||||
|
await interface.disconnect()
|
||||||
|
self.logger.info(f"Disconnected from router: {router_id}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error disconnecting from router {router_id}: {e}")
|
||||||
|
|
||||||
|
self.router_interfaces.clear()
|
||||||
|
self.stats["connected_routers"] = 0
|
||||||
|
|
||||||
|
async def _data_collection_loop(self):
|
||||||
|
"""Background loop for data collection."""
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Collect data from all routers
|
||||||
|
await self._collect_data_from_routers()
|
||||||
|
|
||||||
|
# Calculate sleep time to maintain polling interval
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)
|
||||||
|
|
||||||
|
if sleep_time > 0:
|
||||||
|
await asyncio.sleep(sleep_time)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.info("Data collection loop cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in data collection loop: {e}")
|
||||||
|
self.last_error = str(e)
|
||||||
|
|
||||||
|
async def _monitoring_loop(self):
|
||||||
|
"""Background loop for hardware monitoring."""
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
# Monitor router connections
|
||||||
|
await self._monitor_router_health()
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
self._update_sample_rate_stats()
|
||||||
|
|
||||||
|
# Wait before next check
|
||||||
|
await asyncio.sleep(30) # Check every 30 seconds
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.info("Monitoring loop cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in monitoring loop: {e}")
|
||||||
|
|
||||||
|
async def _collect_data_from_routers(self):
|
||||||
|
"""Collect CSI data from all connected routers."""
|
||||||
|
for router_id, interface in self.router_interfaces.items():
|
||||||
|
try:
|
||||||
|
# Get CSI data from router
|
||||||
|
csi_data = await interface.get_csi_data()
|
||||||
|
|
||||||
|
if csi_data is not None:
|
||||||
|
# Process the collected data
|
||||||
|
await self._process_collected_data(router_id, csi_data)
|
||||||
|
|
||||||
|
self.stats["successful_samples"] += 1
|
||||||
|
self.stats["last_sample_time"] = datetime.now().isoformat()
|
||||||
|
else:
|
||||||
|
self.stats["failed_samples"] += 1
|
||||||
|
|
||||||
|
self.stats["total_samples"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error collecting data from router {router_id}: {e}")
|
||||||
|
self.stats["failed_samples"] += 1
|
||||||
|
self.stats["total_samples"] += 1
|
||||||
|
|
||||||
|
async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
|
||||||
|
"""Process collected CSI data."""
|
||||||
|
try:
|
||||||
|
# Create sample metadata
|
||||||
|
metadata = {
|
||||||
|
"router_id": router_id,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"sample_rate": self.stats["average_sample_rate"],
|
||||||
|
"data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add to recent samples buffer
|
||||||
|
sample = {
|
||||||
|
"router_id": router_id,
|
||||||
|
"timestamp": metadata["timestamp"],
|
||||||
|
"data": csi_data,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
self.recent_samples.append(sample)
|
||||||
|
|
||||||
|
# Maintain buffer size
|
||||||
|
if len(self.recent_samples) > self.max_recent_samples:
|
||||||
|
self.recent_samples.pop(0)
|
||||||
|
|
||||||
|
# Notify other services (this would typically be done through an event system)
|
||||||
|
# For now, we'll just log the data collection
|
||||||
|
self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error processing collected data: {e}")
|
||||||
|
|
||||||
|
async def _monitor_router_health(self):
|
||||||
|
"""Monitor health of router connections."""
|
||||||
|
healthy_routers = 0
|
||||||
|
|
||||||
|
for router_id, interface in self.router_interfaces.items():
|
||||||
|
try:
|
||||||
|
is_healthy = await interface.check_health()
|
||||||
|
|
||||||
|
if is_healthy:
|
||||||
|
healthy_routers += 1
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Router {router_id} is unhealthy")
|
||||||
|
|
||||||
|
# Try to reconnect if not in mock mode
|
||||||
|
if not self.settings.mock_hardware:
|
||||||
|
try:
|
||||||
|
await interface.reconnect()
|
||||||
|
self.logger.info(f"Reconnected to router {router_id}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to reconnect to router {router_id}: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error checking health of router {router_id}: {e}")
|
||||||
|
|
||||||
|
self.stats["connected_routers"] = healthy_routers
|
||||||
|
|
||||||
|
def _update_sample_rate_stats(self):
|
||||||
|
"""Update sample rate statistics."""
|
||||||
|
if len(self.recent_samples) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate sample rate from recent samples
|
||||||
|
recent_count = min(100, len(self.recent_samples))
|
||||||
|
recent_samples = self.recent_samples[-recent_count:]
|
||||||
|
|
||||||
|
if len(recent_samples) >= 2:
|
||||||
|
# Calculate time differences
|
||||||
|
time_diffs = []
|
||||||
|
for i in range(1, len(recent_samples)):
|
||||||
|
try:
|
||||||
|
t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
|
||||||
|
t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
|
||||||
|
diff = (t2 - t1).total_seconds()
|
||||||
|
if diff > 0:
|
||||||
|
time_diffs.append(diff)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if time_diffs:
|
||||||
|
avg_interval = sum(time_diffs) / len(time_diffs)
|
||||||
|
self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0
|
||||||
|
|
||||||
|
async def get_router_status(self, router_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get status of a specific router."""
|
||||||
|
if router_id not in self.router_interfaces:
|
||||||
|
raise ValueError(f"Router {router_id} not found")
|
||||||
|
|
||||||
|
interface = self.router_interfaces[router_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
is_healthy = await interface.check_health()
|
||||||
|
status = await interface.get_status()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"router_id": router_id,
|
||||||
|
"healthy": is_healthy,
|
||||||
|
"connected": status.get("connected", False),
|
||||||
|
"last_data_time": status.get("last_data_time"),
|
||||||
|
"error_count": status.get("error_count", 0),
|
||||||
|
"configuration": status.get("configuration", {})
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"router_id": router_id,
|
||||||
|
"healthy": False,
|
||||||
|
"connected": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_all_router_status(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Get status of all routers."""
|
||||||
|
statuses = []
|
||||||
|
|
||||||
|
for router_id in self.router_interfaces:
|
||||||
|
try:
|
||||||
|
status = await self.get_router_status(router_id)
|
||||||
|
statuses.append(status)
|
||||||
|
except Exception as e:
|
||||||
|
statuses.append({
|
||||||
|
"router_id": router_id,
|
||||||
|
"healthy": False,
|
||||||
|
"error": str(e)
|
||||||
|
})
|
||||||
|
|
||||||
|
return statuses
|
||||||
|
|
||||||
|
async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""Get recent CSI data samples."""
|
||||||
|
samples = self.recent_samples[-limit:] if limit else self.recent_samples
|
||||||
|
|
||||||
|
if router_id:
|
||||||
|
samples = [s for s in samples if s["router_id"] == router_id]
|
||||||
|
|
||||||
|
# Convert numpy arrays to lists for JSON serialization
|
||||||
|
result = []
|
||||||
|
for sample in samples:
|
||||||
|
sample_copy = sample.copy()
|
||||||
|
if isinstance(sample_copy["data"], np.ndarray):
|
||||||
|
sample_copy["data"] = sample_copy["data"].tolist()
|
||||||
|
result.append(sample_copy)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get service status."""
|
||||||
|
return {
|
||||||
|
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||||
|
"running": self.is_running,
|
||||||
|
"last_error": self.last_error,
|
||||||
|
"statistics": self.stats.copy(),
|
||||||
|
"configuration": {
|
||||||
|
"mock_hardware": self.settings.mock_hardware,
|
||||||
|
"wifi_interface": self.settings.wifi_interface,
|
||||||
|
"polling_interval": self.settings.hardware_polling_interval,
|
||||||
|
"buffer_size": self.settings.csi_buffer_size
|
||||||
|
},
|
||||||
|
"routers": await self.get_all_router_status()
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_metrics(self) -> Dict[str, Any]:
|
||||||
|
"""Get service metrics."""
|
||||||
|
total_samples = self.stats["total_samples"]
|
||||||
|
success_rate = self.stats["successful_samples"] / max(1, total_samples)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hardware_service": {
|
||||||
|
"total_samples": total_samples,
|
||||||
|
"successful_samples": self.stats["successful_samples"],
|
||||||
|
"failed_samples": self.stats["failed_samples"],
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"average_sample_rate": self.stats["average_sample_rate"],
|
||||||
|
"connected_routers": self.stats["connected_routers"],
|
||||||
|
"last_sample_time": self.stats["last_sample_time"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""Reset service state."""
|
||||||
|
self.stats = {
|
||||||
|
"total_samples": 0,
|
||||||
|
"successful_samples": 0,
|
||||||
|
"failed_samples": 0,
|
||||||
|
"average_sample_rate": 0.0,
|
||||||
|
"last_sample_time": None,
|
||||||
|
"connected_routers": len(self.router_interfaces)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.recent_samples.clear()
|
||||||
|
self.last_error = None
|
||||||
|
|
||||||
|
self.logger.info("Hardware service reset")
|
||||||
|
|
||||||
|
async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
|
"""Manually trigger data collection."""
|
||||||
|
if not self.is_running:
|
||||||
|
raise RuntimeError("Hardware service is not running")
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
if router_id:
|
||||||
|
# Collect from specific router
|
||||||
|
if router_id not in self.router_interfaces:
|
||||||
|
raise ValueError(f"Router {router_id} not found")
|
||||||
|
|
||||||
|
interface = self.router_interfaces[router_id]
|
||||||
|
try:
|
||||||
|
csi_data = await interface.get_csi_data()
|
||||||
|
if csi_data is not None:
|
||||||
|
await self._process_collected_data(router_id, csi_data)
|
||||||
|
results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
|
||||||
|
else:
|
||||||
|
results[router_id] = {"success": False, "error": "No data received"}
|
||||||
|
except Exception as e:
|
||||||
|
results[router_id] = {"success": False, "error": str(e)}
|
||||||
|
else:
|
||||||
|
# Collect from all routers
|
||||||
|
await self._collect_data_from_routers()
|
||||||
|
results = {"message": "Manual collection triggered for all routers"}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def health_check(self) -> Dict[str, Any]:
|
||||||
|
"""Perform health check."""
|
||||||
|
try:
|
||||||
|
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||||
|
|
||||||
|
# Check router health
|
||||||
|
healthy_routers = 0
|
||||||
|
total_routers = len(self.router_interfaces)
|
||||||
|
|
||||||
|
for router_id, interface in self.router_interfaces.items():
|
||||||
|
try:
|
||||||
|
if await interface.check_health():
|
||||||
|
healthy_routers += 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": status,
|
||||||
|
"message": self.last_error if self.last_error else "Hardware service is running normally",
|
||||||
|
"connected_routers": f"{healthy_routers}/{total_routers}",
|
||||||
|
"metrics": {
|
||||||
|
"total_samples": self.stats["total_samples"],
|
||||||
|
"success_rate": (
|
||||||
|
self.stats["successful_samples"] / max(1, self.stats["total_samples"])
|
||||||
|
),
|
||||||
|
"average_sample_rate": self.stats["average_sample_rate"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"message": f"Health check failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def is_ready(self) -> bool:
|
||||||
|
"""Check if service is ready."""
|
||||||
|
return self.is_running and len(self.router_interfaces) > 0
|
||||||
706
src/services/pose_service.py
Normal file
706
src/services/pose_service.py
Normal file
@@ -0,0 +1,706 @@
|
|||||||
|
"""
|
||||||
|
Pose estimation service for WiFi-DensePose API
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.config.settings import Settings
|
||||||
|
from src.config.domains import DomainConfig
|
||||||
|
from src.core.csi_processor import CSIProcessor
|
||||||
|
from src.core.phase_sanitizer import PhaseSanitizer
|
||||||
|
from src.models.densepose_head import DensePoseHead
|
||||||
|
from src.models.modality_translation import ModalityTranslationNetwork
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PoseService:
|
||||||
|
"""Service for pose estimation operations."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||||
|
"""Initialize pose service."""
|
||||||
|
self.settings = settings
|
||||||
|
self.domain_config = domain_config
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.csi_processor = None
|
||||||
|
self.phase_sanitizer = None
|
||||||
|
self.densepose_model = None
|
||||||
|
self.modality_translator = None
|
||||||
|
|
||||||
|
# Service state
|
||||||
|
self.is_initialized = False
|
||||||
|
self.is_running = False
|
||||||
|
self.last_error = None
|
||||||
|
|
||||||
|
# Processing statistics
|
||||||
|
self.stats = {
|
||||||
|
"total_processed": 0,
|
||||||
|
"successful_detections": 0,
|
||||||
|
"failed_detections": 0,
|
||||||
|
"average_confidence": 0.0,
|
||||||
|
"processing_time_ms": 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize the pose service."""
|
||||||
|
try:
|
||||||
|
self.logger.info("Initializing pose service...")
|
||||||
|
|
||||||
|
# Initialize CSI processor
|
||||||
|
csi_config = {
|
||||||
|
'buffer_size': self.settings.csi_buffer_size,
|
||||||
|
'sample_rate': 1000, # Default sampling rate
|
||||||
|
'num_subcarriers': 56,
|
||||||
|
'num_antennas': 3
|
||||||
|
}
|
||||||
|
self.csi_processor = CSIProcessor(config=csi_config)
|
||||||
|
|
||||||
|
# Initialize phase sanitizer
|
||||||
|
self.phase_sanitizer = PhaseSanitizer()
|
||||||
|
|
||||||
|
# Initialize models if not mocking
|
||||||
|
if not self.settings.mock_pose_data:
|
||||||
|
await self._initialize_models()
|
||||||
|
else:
|
||||||
|
self.logger.info("Using mock pose data for development")
|
||||||
|
|
||||||
|
self.is_initialized = True
|
||||||
|
self.logger.info("Pose service initialized successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.last_error = str(e)
|
||||||
|
self.logger.error(f"Failed to initialize pose service: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _initialize_models(self):
|
||||||
|
"""Initialize neural network models."""
|
||||||
|
try:
|
||||||
|
# Initialize DensePose model
|
||||||
|
if self.settings.pose_model_path:
|
||||||
|
self.densepose_model = DensePoseHead()
|
||||||
|
# Load model weights if path is provided
|
||||||
|
# model_state = torch.load(self.settings.pose_model_path)
|
||||||
|
# self.densepose_model.load_state_dict(model_state)
|
||||||
|
self.logger.info("DensePose model loaded")
|
||||||
|
else:
|
||||||
|
self.logger.warning("No pose model path provided, using default model")
|
||||||
|
self.densepose_model = DensePoseHead()
|
||||||
|
|
||||||
|
# Initialize modality translation
|
||||||
|
config = {
|
||||||
|
'input_channels': 64, # CSI data channels
|
||||||
|
'hidden_channels': [128, 256, 512],
|
||||||
|
'output_channels': 256, # Visual feature channels
|
||||||
|
'use_attention': True
|
||||||
|
}
|
||||||
|
self.modality_translator = ModalityTranslationNetwork(config)
|
||||||
|
|
||||||
|
# Set models to evaluation mode
|
||||||
|
self.densepose_model.eval()
|
||||||
|
self.modality_translator.eval()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Failed to initialize models: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the pose service."""
|
||||||
|
if not self.is_initialized:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
self.logger.info("Pose service started")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the pose service."""
|
||||||
|
self.is_running = False
|
||||||
|
self.logger.info("Pose service stopped")
|
||||||
|
|
||||||
|
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Process CSI data and estimate poses."""
|
||||||
|
if not self.is_running:
|
||||||
|
raise RuntimeError("Pose service is not running")
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process CSI data
|
||||||
|
processed_csi = await self._process_csi(csi_data, metadata)
|
||||||
|
|
||||||
|
# Estimate poses
|
||||||
|
poses = await self._estimate_poses(processed_csi, metadata)
|
||||||
|
|
||||||
|
# Update statistics
|
||||||
|
processing_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||||
|
self._update_stats(poses, processing_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"timestamp": start_time.isoformat(),
|
||||||
|
"poses": poses,
|
||||||
|
"metadata": metadata,
|
||||||
|
"processing_time_ms": processing_time,
|
||||||
|
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.last_error = str(e)
|
||||||
|
self.stats["failed_detections"] += 1
|
||||||
|
self.logger.error(f"Error processing CSI data: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
|
||||||
|
"""Process raw CSI data."""
|
||||||
|
# Add CSI data to processor
|
||||||
|
self.csi_processor.add_data(csi_data, metadata.get("timestamp", datetime.now()))
|
||||||
|
|
||||||
|
# Get processed data
|
||||||
|
processed_data = self.csi_processor.get_processed_data()
|
||||||
|
|
||||||
|
# Apply phase sanitization
|
||||||
|
if processed_data is not None:
|
||||||
|
sanitized_data = self.phase_sanitizer.sanitize(processed_data)
|
||||||
|
return sanitized_data
|
||||||
|
|
||||||
|
return csi_data
|
||||||
|
|
||||||
|
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""Estimate poses from processed CSI data."""
|
||||||
|
if self.settings.mock_pose_data:
|
||||||
|
return self._generate_mock_poses()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert CSI data to tensor
|
||||||
|
csi_tensor = torch.from_numpy(csi_data).float()
|
||||||
|
|
||||||
|
# Add batch dimension if needed
|
||||||
|
if len(csi_tensor.shape) == 2:
|
||||||
|
csi_tensor = csi_tensor.unsqueeze(0)
|
||||||
|
|
||||||
|
# Translate modality (CSI to visual-like features)
|
||||||
|
with torch.no_grad():
|
||||||
|
visual_features = self.modality_translator(csi_tensor)
|
||||||
|
|
||||||
|
# Estimate poses using DensePose
|
||||||
|
pose_outputs = self.densepose_model(visual_features)
|
||||||
|
|
||||||
|
# Convert outputs to pose detections
|
||||||
|
poses = self._parse_pose_outputs(pose_outputs)
|
||||||
|
|
||||||
|
# Filter by confidence threshold
|
||||||
|
filtered_poses = [
|
||||||
|
pose for pose in poses
|
||||||
|
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
# Limit number of persons
|
||||||
|
if len(filtered_poses) > self.settings.pose_max_persons:
|
||||||
|
filtered_poses = sorted(
|
||||||
|
filtered_poses,
|
||||||
|
key=lambda x: x.get("confidence", 0.0),
|
||||||
|
reverse=True
|
||||||
|
)[:self.settings.pose_max_persons]
|
||||||
|
|
||||||
|
return filtered_poses
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in pose estimation: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||||
|
"""Parse neural network outputs into pose detections."""
|
||||||
|
poses = []
|
||||||
|
|
||||||
|
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||||
|
# For now, generate mock poses based on the output shape
|
||||||
|
batch_size = outputs.shape[0]
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Extract pose information (mock implementation)
|
||||||
|
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||||
|
|
||||||
|
pose = {
|
||||||
|
"person_id": i,
|
||||||
|
"confidence": confidence,
|
||||||
|
"keypoints": self._generate_keypoints(),
|
||||||
|
"bounding_box": self._generate_bounding_box(),
|
||||||
|
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
poses.append(pose)
|
||||||
|
|
||||||
|
return poses
|
||||||
|
|
||||||
|
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Generate mock pose data for development."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
|
||||||
|
poses = []
|
||||||
|
|
||||||
|
for i in range(num_persons):
|
||||||
|
confidence = random.uniform(0.3, 0.95)
|
||||||
|
|
||||||
|
pose = {
|
||||||
|
"person_id": i,
|
||||||
|
"confidence": confidence,
|
||||||
|
"keypoints": self._generate_keypoints(),
|
||||||
|
"bounding_box": self._generate_bounding_box(),
|
||||||
|
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||||
|
"timestamp": datetime.now().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
poses.append(pose)
|
||||||
|
|
||||||
|
return poses
|
||||||
|
|
||||||
|
def _generate_keypoints(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Generate keypoints for a person."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
keypoint_names = [
|
||||||
|
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||||
|
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||||
|
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||||
|
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||||
|
]
|
||||||
|
|
||||||
|
keypoints = []
|
||||||
|
for name in keypoint_names:
|
||||||
|
keypoints.append({
|
||||||
|
"name": name,
|
||||||
|
"x": random.uniform(0.1, 0.9),
|
||||||
|
"y": random.uniform(0.1, 0.9),
|
||||||
|
"confidence": random.uniform(0.5, 0.95)
|
||||||
|
})
|
||||||
|
|
||||||
|
return keypoints
|
||||||
|
|
||||||
|
def _generate_bounding_box(self) -> Dict[str, float]:
|
||||||
|
"""Generate bounding box for a person."""
|
||||||
|
import random
|
||||||
|
|
||||||
|
x = random.uniform(0.1, 0.6)
|
||||||
|
y = random.uniform(0.1, 0.6)
|
||||||
|
width = random.uniform(0.2, 0.4)
|
||||||
|
height = random.uniform(0.3, 0.5)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"x": x,
|
||||||
|
"y": y,
|
||||||
|
"width": width,
|
||||||
|
"height": height
|
||||||
|
}
|
||||||
|
|
||||||
|
def _classify_activity(self, features: torch.Tensor) -> str:
|
||||||
|
"""Classify activity from features."""
|
||||||
|
# Simple mock classification
|
||||||
|
import random
|
||||||
|
activities = ["standing", "sitting", "walking", "lying", "unknown"]
|
||||||
|
return random.choice(activities)
|
||||||
|
|
||||||
|
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
|
||||||
|
"""Update processing statistics."""
|
||||||
|
self.stats["total_processed"] += 1
|
||||||
|
|
||||||
|
if poses:
|
||||||
|
self.stats["successful_detections"] += 1
|
||||||
|
confidences = [pose.get("confidence", 0.0) for pose in poses]
|
||||||
|
avg_confidence = sum(confidences) / len(confidences)
|
||||||
|
|
||||||
|
# Update running average
|
||||||
|
total = self.stats["successful_detections"]
|
||||||
|
current_avg = self.stats["average_confidence"]
|
||||||
|
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
|
||||||
|
else:
|
||||||
|
self.stats["failed_detections"] += 1
|
||||||
|
|
||||||
|
# Update processing time (running average)
|
||||||
|
total = self.stats["total_processed"]
|
||||||
|
current_avg = self.stats["processing_time_ms"]
|
||||||
|
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
|
||||||
|
|
||||||
|
async def get_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get service status."""
|
||||||
|
return {
|
||||||
|
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||||
|
"initialized": self.is_initialized,
|
||||||
|
"running": self.is_running,
|
||||||
|
"last_error": self.last_error,
|
||||||
|
"statistics": self.stats.copy(),
|
||||||
|
"configuration": {
|
||||||
|
"mock_data": self.settings.mock_pose_data,
|
||||||
|
"confidence_threshold": self.settings.pose_confidence_threshold,
|
||||||
|
"max_persons": self.settings.pose_max_persons,
|
||||||
|
"batch_size": self.settings.pose_processing_batch_size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_metrics(self) -> Dict[str, Any]:
|
||||||
|
"""Get service metrics."""
|
||||||
|
return {
|
||||||
|
"pose_service": {
|
||||||
|
"total_processed": self.stats["total_processed"],
|
||||||
|
"successful_detections": self.stats["successful_detections"],
|
||||||
|
"failed_detections": self.stats["failed_detections"],
|
||||||
|
"success_rate": (
|
||||||
|
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||||
|
),
|
||||||
|
"average_confidence": self.stats["average_confidence"],
|
||||||
|
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""Reset service state."""
|
||||||
|
self.stats = {
|
||||||
|
"total_processed": 0,
|
||||||
|
"successful_detections": 0,
|
||||||
|
"failed_detections": 0,
|
||||||
|
"average_confidence": 0.0,
|
||||||
|
"processing_time_ms": 0.0
|
||||||
|
}
|
||||||
|
self.last_error = None
|
||||||
|
self.logger.info("Pose service reset")
|
||||||
|
|
||||||
|
# API endpoint methods
|
||||||
|
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||||
|
include_keypoints=True, include_segmentation=False):
|
||||||
|
"""Estimate poses with API parameters."""
|
||||||
|
try:
|
||||||
|
# Generate mock CSI data for estimation
|
||||||
|
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
|
||||||
|
metadata = {
|
||||||
|
"timestamp": datetime.now(),
|
||||||
|
"zone_ids": zone_ids or ["zone_1"],
|
||||||
|
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
|
||||||
|
"max_persons": max_persons or self.settings.pose_max_persons
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process the data
|
||||||
|
result = await self.process_csi_data(mock_csi, metadata)
|
||||||
|
|
||||||
|
# Format for API response
|
||||||
|
persons = []
|
||||||
|
for i, pose in enumerate(result["poses"]):
|
||||||
|
person = {
|
||||||
|
"person_id": str(pose["person_id"]),
|
||||||
|
"confidence": pose["confidence"],
|
||||||
|
"bounding_box": pose["bounding_box"],
|
||||||
|
"zone_id": zone_ids[0] if zone_ids else "zone_1",
|
||||||
|
"activity": pose["activity"],
|
||||||
|
"timestamp": datetime.fromisoformat(pose["timestamp"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if include_keypoints:
|
||||||
|
person["keypoints"] = pose["keypoints"]
|
||||||
|
|
||||||
|
if include_segmentation:
|
||||||
|
person["segmentation"] = {"mask": "mock_segmentation_data"}
|
||||||
|
|
||||||
|
persons.append(person)
|
||||||
|
|
||||||
|
# Zone summary
|
||||||
|
zone_summary = {}
|
||||||
|
for zone_id in (zone_ids or ["zone_1"]):
|
||||||
|
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"timestamp": datetime.now(),
|
||||||
|
"frame_id": f"frame_{int(datetime.now().timestamp())}",
|
||||||
|
"persons": persons,
|
||||||
|
"zone_summary": zone_summary,
|
||||||
|
"processing_time_ms": result["processing_time_ms"],
|
||||||
|
"metadata": {"mock_data": self.settings.mock_pose_data}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in estimate_poses: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||||
|
include_keypoints=True, include_segmentation=False):
|
||||||
|
"""Analyze pose data with custom parameters."""
|
||||||
|
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
|
||||||
|
include_keypoints, include_segmentation)
|
||||||
|
|
||||||
|
async def get_zone_occupancy(self, zone_id: str):
|
||||||
|
"""Get current occupancy for a specific zone."""
|
||||||
|
try:
|
||||||
|
# Mock occupancy data
|
||||||
|
import random
|
||||||
|
count = random.randint(0, 5)
|
||||||
|
persons = []
|
||||||
|
|
||||||
|
for i in range(count):
|
||||||
|
persons.append({
|
||||||
|
"person_id": f"person_{i}",
|
||||||
|
"confidence": random.uniform(0.7, 0.95),
|
||||||
|
"activity": random.choice(["standing", "sitting", "walking"])
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"count": count,
|
||||||
|
"max_occupancy": 10,
|
||||||
|
"persons": persons,
|
||||||
|
"timestamp": datetime.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting zone occupancy: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_zones_summary(self):
|
||||||
|
"""Get occupancy summary for all zones."""
|
||||||
|
try:
|
||||||
|
import random
|
||||||
|
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
|
||||||
|
zone_data = {}
|
||||||
|
total_persons = 0
|
||||||
|
active_zones = 0
|
||||||
|
|
||||||
|
for zone_id in zones:
|
||||||
|
count = random.randint(0, 3)
|
||||||
|
zone_data[zone_id] = {
|
||||||
|
"occupancy": count,
|
||||||
|
"max_occupancy": 10,
|
||||||
|
"status": "active" if count > 0 else "inactive"
|
||||||
|
}
|
||||||
|
total_persons += count
|
||||||
|
if count > 0:
|
||||||
|
active_zones += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_persons": total_persons,
|
||||||
|
"zones": zone_data,
|
||||||
|
"active_zones": active_zones
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting zones summary: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_historical_data(self, start_time, end_time, zone_ids=None,
|
||||||
|
aggregation_interval=300, include_raw_data=False):
|
||||||
|
"""Get historical pose estimation data."""
|
||||||
|
try:
|
||||||
|
# Mock historical data
|
||||||
|
import random
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
current_time = start_time
|
||||||
|
aggregated_data = []
|
||||||
|
raw_data = [] if include_raw_data else None
|
||||||
|
|
||||||
|
while current_time < end_time:
|
||||||
|
# Generate aggregated data point
|
||||||
|
data_point = {
|
||||||
|
"timestamp": current_time,
|
||||||
|
"total_persons": random.randint(0, 8),
|
||||||
|
"zones": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
|
||||||
|
data_point["zones"][zone_id] = {
|
||||||
|
"occupancy": random.randint(0, 3),
|
||||||
|
"avg_confidence": random.uniform(0.7, 0.95)
|
||||||
|
}
|
||||||
|
|
||||||
|
aggregated_data.append(data_point)
|
||||||
|
|
||||||
|
# Generate raw data if requested
|
||||||
|
if include_raw_data:
|
||||||
|
for _ in range(random.randint(0, 5)):
|
||||||
|
raw_data.append({
|
||||||
|
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
|
||||||
|
"person_id": f"person_{random.randint(1, 10)}",
|
||||||
|
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
|
||||||
|
"confidence": random.uniform(0.5, 0.95),
|
||||||
|
"activity": random.choice(["standing", "sitting", "walking"])
|
||||||
|
})
|
||||||
|
|
||||||
|
current_time += timedelta(seconds=aggregation_interval)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"aggregated_data": aggregated_data,
|
||||||
|
"raw_data": raw_data,
|
||||||
|
"total_records": len(aggregated_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting historical data: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_recent_activities(self, zone_id=None, limit=10):
|
||||||
|
"""Get recently detected activities."""
|
||||||
|
try:
|
||||||
|
import random
|
||||||
|
activities = []
|
||||||
|
|
||||||
|
for i in range(limit):
|
||||||
|
activity = {
|
||||||
|
"activity_id": f"activity_{i}",
|
||||||
|
"person_id": f"person_{random.randint(1, 5)}",
|
||||||
|
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||||
|
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||||
|
"confidence": random.uniform(0.6, 0.95),
|
||||||
|
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
|
||||||
|
"duration_seconds": random.randint(10, 300)
|
||||||
|
}
|
||||||
|
activities.append(activity)
|
||||||
|
|
||||||
|
return activities
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting recent activities: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def is_calibrating(self):
|
||||||
|
"""Check if calibration is in progress."""
|
||||||
|
return False # Mock implementation
|
||||||
|
|
||||||
|
async def start_calibration(self):
|
||||||
|
"""Start calibration process."""
|
||||||
|
import uuid
|
||||||
|
calibration_id = str(uuid.uuid4())
|
||||||
|
self.logger.info(f"Started calibration: {calibration_id}")
|
||||||
|
return calibration_id
|
||||||
|
|
||||||
|
async def run_calibration(self, calibration_id):
|
||||||
|
"""Run calibration process."""
|
||||||
|
self.logger.info(f"Running calibration: {calibration_id}")
|
||||||
|
# Mock calibration process
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
self.logger.info(f"Calibration completed: {calibration_id}")
|
||||||
|
|
||||||
|
async def get_calibration_status(self):
|
||||||
|
"""Get current calibration status."""
|
||||||
|
return {
|
||||||
|
"is_calibrating": False,
|
||||||
|
"calibration_id": None,
|
||||||
|
"progress_percent": 100,
|
||||||
|
"current_step": "completed",
|
||||||
|
"estimated_remaining_minutes": 0,
|
||||||
|
"last_calibration": datetime.now() - timedelta(hours=1)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_statistics(self, start_time, end_time):
|
||||||
|
"""Get pose estimation statistics."""
|
||||||
|
try:
|
||||||
|
import random
|
||||||
|
|
||||||
|
# Mock statistics
|
||||||
|
total_detections = random.randint(100, 1000)
|
||||||
|
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_detections": total_detections,
|
||||||
|
"successful_detections": successful_detections,
|
||||||
|
"failed_detections": total_detections - successful_detections,
|
||||||
|
"success_rate": successful_detections / total_detections,
|
||||||
|
"average_confidence": random.uniform(0.75, 0.90),
|
||||||
|
"average_processing_time_ms": random.uniform(50, 200),
|
||||||
|
"unique_persons": random.randint(5, 20),
|
||||||
|
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||||
|
"activity_distribution": {
|
||||||
|
"standing": random.uniform(0.3, 0.5),
|
||||||
|
"sitting": random.uniform(0.2, 0.4),
|
||||||
|
"walking": random.uniform(0.1, 0.3),
|
||||||
|
"lying": random.uniform(0.0, 0.1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting statistics: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def process_segmentation_data(self, frame_id):
|
||||||
|
"""Process segmentation data in background."""
|
||||||
|
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
|
||||||
|
# Mock background processing
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
|
||||||
|
|
||||||
|
# WebSocket streaming methods
|
||||||
|
async def get_current_pose_data(self):
|
||||||
|
"""Get current pose data for streaming."""
|
||||||
|
try:
|
||||||
|
# Generate current pose data
|
||||||
|
result = await self.estimate_poses()
|
||||||
|
|
||||||
|
# Format data by zones for WebSocket streaming
|
||||||
|
zone_data = {}
|
||||||
|
|
||||||
|
# Group persons by zone
|
||||||
|
for person in result["persons"]:
|
||||||
|
zone_id = person.get("zone_id", "zone_1")
|
||||||
|
|
||||||
|
if zone_id not in zone_data:
|
||||||
|
zone_data[zone_id] = {
|
||||||
|
"pose": {
|
||||||
|
"persons": [],
|
||||||
|
"count": 0
|
||||||
|
},
|
||||||
|
"confidence": 0.0,
|
||||||
|
"activity": None,
|
||||||
|
"metadata": {
|
||||||
|
"frame_id": result["frame_id"],
|
||||||
|
"processing_time_ms": result["processing_time_ms"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
zone_data[zone_id]["pose"]["persons"].append(person)
|
||||||
|
zone_data[zone_id]["pose"]["count"] += 1
|
||||||
|
|
||||||
|
# Update zone confidence (average)
|
||||||
|
current_confidence = zone_data[zone_id]["confidence"]
|
||||||
|
person_confidence = person.get("confidence", 0.0)
|
||||||
|
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
|
||||||
|
|
||||||
|
# Set activity if not already set
|
||||||
|
if not zone_data[zone_id]["activity"] and person.get("activity"):
|
||||||
|
zone_data[zone_id]["activity"] = person["activity"]
|
||||||
|
|
||||||
|
return zone_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error getting current pose data: {e}")
|
||||||
|
# Return empty zone data on error
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Health check methods
|
||||||
|
async def health_check(self):
|
||||||
|
"""Perform health check."""
|
||||||
|
try:
|
||||||
|
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": status,
|
||||||
|
"message": self.last_error if self.last_error else "Service is running normally",
|
||||||
|
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
|
||||||
|
"metrics": {
|
||||||
|
"total_processed": self.stats["total_processed"],
|
||||||
|
"success_rate": (
|
||||||
|
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||||
|
),
|
||||||
|
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"message": f"Health check failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def is_ready(self):
|
||||||
|
"""Check if service is ready."""
|
||||||
|
return self.is_initialized and self.is_running
|
||||||
397
src/services/stream_service.py
Normal file
397
src/services/stream_service.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""
|
||||||
|
Real-time streaming service for WiFi-DensePose API
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Optional, Any, Set
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
from src.config.settings import Settings
|
||||||
|
from src.config.domains import DomainConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamService:
|
||||||
|
"""Service for real-time data streaming."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||||
|
"""Initialize stream service."""
|
||||||
|
self.settings = settings
|
||||||
|
self.domain_config = domain_config
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# WebSocket connections
|
||||||
|
self.connections: Set[WebSocket] = set()
|
||||||
|
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# Stream buffers
|
||||||
|
self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||||
|
self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||||
|
|
||||||
|
# Service state
|
||||||
|
self.is_running = False
|
||||||
|
self.last_error = None
|
||||||
|
|
||||||
|
# Streaming statistics
|
||||||
|
self.stats = {
|
||||||
|
"active_connections": 0,
|
||||||
|
"total_connections": 0,
|
||||||
|
"messages_sent": 0,
|
||||||
|
"messages_failed": 0,
|
||||||
|
"data_points_streamed": 0,
|
||||||
|
"average_latency_ms": 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Background tasks
|
||||||
|
self.streaming_task = None
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
"""Initialize the stream service."""
|
||||||
|
self.logger.info("Stream service initialized")
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the stream service."""
|
||||||
|
if self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
self.logger.info("Stream service started")
|
||||||
|
|
||||||
|
# Start background streaming task
|
||||||
|
if self.settings.enable_real_time_processing:
|
||||||
|
self.streaming_task = asyncio.create_task(self._streaming_loop())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the stream service."""
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
# Cancel background task
|
||||||
|
if self.streaming_task:
|
||||||
|
self.streaming_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.streaming_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Close all connections
|
||||||
|
await self._close_all_connections()
|
||||||
|
|
||||||
|
self.logger.info("Stream service stopped")
|
||||||
|
|
||||||
|
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
|
||||||
|
"""Add a new WebSocket connection."""
|
||||||
|
try:
|
||||||
|
await websocket.accept()
|
||||||
|
self.connections.add(websocket)
|
||||||
|
self.connection_metadata[websocket] = metadata or {}
|
||||||
|
|
||||||
|
self.stats["active_connections"] = len(self.connections)
|
||||||
|
self.stats["total_connections"] += 1
|
||||||
|
|
||||||
|
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
|
||||||
|
|
||||||
|
# Send initial data if available
|
||||||
|
await self._send_initial_data(websocket)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error adding WebSocket connection: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def remove_connection(self, websocket: WebSocket):
|
||||||
|
"""Remove a WebSocket connection."""
|
||||||
|
try:
|
||||||
|
if websocket in self.connections:
|
||||||
|
self.connections.remove(websocket)
|
||||||
|
self.connection_metadata.pop(websocket, None)
|
||||||
|
|
||||||
|
self.stats["active_connections"] = len(self.connections)
|
||||||
|
|
||||||
|
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error removing WebSocket connection: {e}")
|
||||||
|
|
||||||
|
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
|
||||||
|
"""Broadcast pose data to all connected clients."""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Add to buffer
|
||||||
|
self.pose_buffer.append({
|
||||||
|
"type": "pose_data",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": pose_data
|
||||||
|
})
|
||||||
|
|
||||||
|
# Broadcast to all connections
|
||||||
|
await self._broadcast_message({
|
||||||
|
"type": "pose_update",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": pose_data
|
||||||
|
})
|
||||||
|
|
||||||
|
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
|
||||||
|
"""Broadcast CSI data to all connected clients."""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert numpy array to list for JSON serialization
|
||||||
|
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
|
||||||
|
|
||||||
|
# Add to buffer
|
||||||
|
self.csi_buffer.append({
|
||||||
|
"type": "csi_data",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": csi_list,
|
||||||
|
"metadata": metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
# Broadcast to all connections
|
||||||
|
await self._broadcast_message({
|
||||||
|
"type": "csi_update",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": csi_list,
|
||||||
|
"metadata": metadata
|
||||||
|
})
|
||||||
|
|
||||||
|
async def broadcast_system_status(self, status_data: Dict[str, Any]):
|
||||||
|
"""Broadcast system status to all connected clients."""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._broadcast_message({
|
||||||
|
"type": "system_status",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": status_data
|
||||||
|
})
|
||||||
|
|
||||||
|
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
|
||||||
|
"""Send message to a specific connection."""
|
||||||
|
try:
|
||||||
|
if websocket in self.connections:
|
||||||
|
await websocket.send_text(json.dumps(message))
|
||||||
|
self.stats["messages_sent"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error sending message to connection: {e}")
|
||||||
|
self.stats["messages_failed"] += 1
|
||||||
|
await self.remove_connection(websocket)
|
||||||
|
|
||||||
|
async def _broadcast_message(self, message: Dict[str, Any]):
|
||||||
|
"""Broadcast message to all connected clients."""
|
||||||
|
if not self.connections:
|
||||||
|
return
|
||||||
|
|
||||||
|
disconnected = set()
|
||||||
|
|
||||||
|
for websocket in self.connections.copy():
|
||||||
|
try:
|
||||||
|
await websocket.send_text(json.dumps(message))
|
||||||
|
self.stats["messages_sent"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Failed to send message to connection: {e}")
|
||||||
|
self.stats["messages_failed"] += 1
|
||||||
|
disconnected.add(websocket)
|
||||||
|
|
||||||
|
# Remove disconnected clients
|
||||||
|
for websocket in disconnected:
|
||||||
|
await self.remove_connection(websocket)
|
||||||
|
|
||||||
|
if message.get("type") in ["pose_update", "csi_update"]:
|
||||||
|
self.stats["data_points_streamed"] += 1
|
||||||
|
|
||||||
|
async def _send_initial_data(self, websocket: WebSocket):
|
||||||
|
"""Send initial data to a new connection."""
|
||||||
|
try:
|
||||||
|
# Send recent pose data
|
||||||
|
if self.pose_buffer:
|
||||||
|
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
|
||||||
|
await self.send_to_connection(websocket, {
|
||||||
|
"type": "initial_poses",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": recent_poses
|
||||||
|
})
|
||||||
|
|
||||||
|
# Send recent CSI data
|
||||||
|
if self.csi_buffer:
|
||||||
|
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
|
||||||
|
await self.send_to_connection(websocket, {
|
||||||
|
"type": "initial_csi",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": recent_csi
|
||||||
|
})
|
||||||
|
|
||||||
|
# Send service status
|
||||||
|
status = await self.get_status()
|
||||||
|
await self.send_to_connection(websocket, {
|
||||||
|
"type": "service_status",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"data": status
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error sending initial data: {e}")
|
||||||
|
|
||||||
|
async def _streaming_loop(self):
|
||||||
|
"""Background streaming loop for periodic updates."""
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
# Send periodic heartbeat
|
||||||
|
if self.connections:
|
||||||
|
await self._broadcast_message({
|
||||||
|
"type": "heartbeat",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"active_connections": len(self.connections)
|
||||||
|
})
|
||||||
|
|
||||||
|
# Wait for next iteration
|
||||||
|
await asyncio.sleep(self.settings.websocket_ping_interval)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.logger.info("Streaming loop cancelled")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error in streaming loop: {e}")
|
||||||
|
self.last_error = str(e)
|
||||||
|
|
||||||
|
async def _close_all_connections(self):
|
||||||
|
"""Close all WebSocket connections."""
|
||||||
|
disconnected = []
|
||||||
|
|
||||||
|
for websocket in self.connections.copy():
|
||||||
|
try:
|
||||||
|
await websocket.close()
|
||||||
|
disconnected.append(websocket)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(f"Error closing connection: {e}")
|
||||||
|
disconnected.append(websocket)
|
||||||
|
|
||||||
|
# Clear all connections
|
||||||
|
for websocket in disconnected:
|
||||||
|
await self.remove_connection(websocket)
|
||||||
|
|
||||||
|
async def get_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get service status."""
|
||||||
|
return {
|
||||||
|
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||||
|
"running": self.is_running,
|
||||||
|
"last_error": self.last_error,
|
||||||
|
"connections": {
|
||||||
|
"active": len(self.connections),
|
||||||
|
"total": self.stats["total_connections"]
|
||||||
|
},
|
||||||
|
"buffers": {
|
||||||
|
"pose_buffer_size": len(self.pose_buffer),
|
||||||
|
"csi_buffer_size": len(self.csi_buffer),
|
||||||
|
"max_buffer_size": self.settings.stream_buffer_size
|
||||||
|
},
|
||||||
|
"statistics": self.stats.copy(),
|
||||||
|
"configuration": {
|
||||||
|
"stream_fps": self.settings.stream_fps,
|
||||||
|
"buffer_size": self.settings.stream_buffer_size,
|
||||||
|
"ping_interval": self.settings.websocket_ping_interval,
|
||||||
|
"timeout": self.settings.websocket_timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_metrics(self) -> Dict[str, Any]:
|
||||||
|
"""Get service metrics."""
|
||||||
|
total_messages = self.stats["messages_sent"] + self.stats["messages_failed"]
|
||||||
|
success_rate = self.stats["messages_sent"] / max(1, total_messages)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stream_service": {
|
||||||
|
"active_connections": self.stats["active_connections"],
|
||||||
|
"total_connections": self.stats["total_connections"],
|
||||||
|
"messages_sent": self.stats["messages_sent"],
|
||||||
|
"messages_failed": self.stats["messages_failed"],
|
||||||
|
"message_success_rate": success_rate,
|
||||||
|
"data_points_streamed": self.stats["data_points_streamed"],
|
||||||
|
"average_latency_ms": self.stats["average_latency_ms"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_connection_info(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Get information about active connections."""
|
||||||
|
connections_info = []
|
||||||
|
|
||||||
|
for websocket in self.connections:
|
||||||
|
metadata = self.connection_metadata.get(websocket, {})
|
||||||
|
|
||||||
|
connection_info = {
|
||||||
|
"id": id(websocket),
|
||||||
|
"connected_at": metadata.get("connected_at", "unknown"),
|
||||||
|
"user_agent": metadata.get("user_agent", "unknown"),
|
||||||
|
"ip_address": metadata.get("ip_address", "unknown"),
|
||||||
|
"subscription_types": metadata.get("subscription_types", [])
|
||||||
|
}
|
||||||
|
|
||||||
|
connections_info.append(connection_info)
|
||||||
|
|
||||||
|
return connections_info
|
||||||
|
|
||||||
|
async def reset(self):
|
||||||
|
"""Reset service state."""
|
||||||
|
# Clear buffers
|
||||||
|
self.pose_buffer.clear()
|
||||||
|
self.csi_buffer.clear()
|
||||||
|
|
||||||
|
# Reset statistics
|
||||||
|
self.stats = {
|
||||||
|
"active_connections": len(self.connections),
|
||||||
|
"total_connections": 0,
|
||||||
|
"messages_sent": 0,
|
||||||
|
"messages_failed": 0,
|
||||||
|
"data_points_streamed": 0,
|
||||||
|
"average_latency_ms": 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
self.last_error = None
|
||||||
|
self.logger.info("Stream service reset")
|
||||||
|
|
||||||
|
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""Get data from buffers."""
|
||||||
|
if buffer_type == "pose":
|
||||||
|
return list(self.pose_buffer)[-limit:]
|
||||||
|
elif buffer_type == "csi":
|
||||||
|
return list(self.csi_buffer)[-limit:]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Check if stream service is active."""
|
||||||
|
return self.is_running
|
||||||
|
|
||||||
|
async def health_check(self) -> Dict[str, Any]:
|
||||||
|
"""Perform health check."""
|
||||||
|
try:
|
||||||
|
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": status,
|
||||||
|
"message": self.last_error if self.last_error else "Stream service is running normally",
|
||||||
|
"active_connections": len(self.connections),
|
||||||
|
"metrics": {
|
||||||
|
"messages_sent": self.stats["messages_sent"],
|
||||||
|
"messages_failed": self.stats["messages_failed"],
|
||||||
|
"data_points_streamed": self.stats["data_points_streamed"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "unhealthy",
|
||||||
|
"message": f"Health check failed: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def is_ready(self) -> bool:
|
||||||
|
"""Check if service is ready."""
|
||||||
|
return self.is_running
|
||||||
198
test_application.py
Normal file
198
test_application.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify WiFi-DensePose API functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
import websockets
|
||||||
|
import sys
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
BASE_URL = "http://localhost:8000"
|
||||||
|
WS_URL = "ws://localhost:8000"
|
||||||
|
|
||||||
|
async def test_health_endpoints():
|
||||||
|
"""Test health check endpoints."""
|
||||||
|
print("🔍 Testing health endpoints...")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Test basic health
|
||||||
|
async with session.get(f"{BASE_URL}/health/health") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Health check: {data['status']}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Health check failed: {response.status}")
|
||||||
|
|
||||||
|
# Test readiness
|
||||||
|
async with session.get(f"{BASE_URL}/health/ready") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
status = "ready" if data['ready'] else "not ready"
|
||||||
|
print(f"✅ Readiness check: {status}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Readiness check failed: {response.status}")
|
||||||
|
|
||||||
|
# Test liveness
|
||||||
|
async with session.get(f"{BASE_URL}/health/live") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Liveness check: {data['status']}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Liveness check failed: {response.status}")
|
||||||
|
|
||||||
|
async def test_api_endpoints():
|
||||||
|
"""Test main API endpoints."""
|
||||||
|
print("\n🔍 Testing API endpoints...")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Test root endpoint
|
||||||
|
async with session.get(f"{BASE_URL}/") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Root endpoint: {data['name']} v{data['version']}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Root endpoint failed: {response.status}")
|
||||||
|
|
||||||
|
# Test API info
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/info") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ API info: {len(data['services'])} services configured")
|
||||||
|
else:
|
||||||
|
print(f"❌ API info failed: {response.status}")
|
||||||
|
|
||||||
|
# Test API status
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/status") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ API status: {data['api']['status']}")
|
||||||
|
else:
|
||||||
|
print(f"❌ API status failed: {response.status}")
|
||||||
|
|
||||||
|
async def test_pose_endpoints():
|
||||||
|
"""Test pose estimation endpoints."""
|
||||||
|
print("\n🔍 Testing pose endpoints...")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Test current pose data
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/pose/current") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Current pose data: {len(data.get('poses', []))} poses detected")
|
||||||
|
else:
|
||||||
|
print(f"❌ Current pose data failed: {response.status}")
|
||||||
|
|
||||||
|
# Test zones summary
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/pose/zones/summary") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
zones = data.get('zones', {})
|
||||||
|
print(f"✅ Zones summary: {len(zones)} zones")
|
||||||
|
for zone_id, zone_data in list(zones.items())[:3]: # Show first 3 zones
|
||||||
|
print(f" - {zone_id}: {zone_data.get('occupancy', 0)} people")
|
||||||
|
else:
|
||||||
|
print(f"❌ Zones summary failed: {response.status}")
|
||||||
|
|
||||||
|
# Test pose stats
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/pose/stats") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Pose stats: {data.get('total_detections', 0)} total detections")
|
||||||
|
else:
|
||||||
|
print(f"❌ Pose stats failed: {response.status}")
|
||||||
|
|
||||||
|
async def test_stream_endpoints():
|
||||||
|
"""Test streaming endpoints."""
|
||||||
|
print("\n🔍 Testing stream endpoints...")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Test stream status
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/stream/status") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Stream status: {'Active' if data['is_active'] else 'Inactive'}")
|
||||||
|
print(f" - Connected clients: {data['connected_clients']}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Stream status failed: {response.status}")
|
||||||
|
|
||||||
|
# Test stream metrics
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/stream/metrics") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Stream metrics available")
|
||||||
|
else:
|
||||||
|
print(f"❌ Stream metrics failed: {response.status}")
|
||||||
|
|
||||||
|
async def test_websocket_connection():
|
||||||
|
"""Test WebSocket connection."""
|
||||||
|
print("\n🔍 Testing WebSocket connection...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
uri = f"{WS_URL}/api/v1/stream/pose"
|
||||||
|
async with websockets.connect(uri) as websocket:
|
||||||
|
print("✅ WebSocket connected successfully")
|
||||||
|
|
||||||
|
# Wait for connection confirmation
|
||||||
|
message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||||
|
data = json.loads(message)
|
||||||
|
|
||||||
|
if data.get("type") == "connection_established":
|
||||||
|
print(f"✅ Connection established with client ID: {data.get('client_id')}")
|
||||||
|
|
||||||
|
# Send a ping
|
||||||
|
await websocket.send(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
# Wait for pong
|
||||||
|
pong_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||||
|
pong_data = json.loads(pong_message)
|
||||||
|
|
||||||
|
if pong_data.get("type") == "pong":
|
||||||
|
print("✅ WebSocket ping/pong successful")
|
||||||
|
else:
|
||||||
|
print(f"❌ Unexpected pong response: {pong_data}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Unexpected connection message: {data}")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("❌ WebSocket connection timeout")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ WebSocket connection failed: {e}")
|
||||||
|
|
||||||
|
async def test_calibration_endpoints():
|
||||||
|
"""Test calibration endpoints."""
|
||||||
|
print("\n🔍 Testing calibration endpoints...")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Test calibration status
|
||||||
|
async with session.get(f"{BASE_URL}/api/v1/pose/calibration/status") as response:
|
||||||
|
if response.status == 200:
|
||||||
|
data = await response.json()
|
||||||
|
print(f"✅ Calibration status: {data.get('status', 'unknown')}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Calibration status failed: {response.status}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all tests."""
|
||||||
|
print("🚀 Starting WiFi-DensePose API Tests")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await test_health_endpoints()
|
||||||
|
await test_api_endpoints()
|
||||||
|
await test_pose_endpoints()
|
||||||
|
await test_stream_endpoints()
|
||||||
|
await test_websocket_connection()
|
||||||
|
await test_calibration_endpoints()
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("✅ All tests completed!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Test suite failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Reference in New Issue
Block a user