feat: Implement hardware, pose, and stream services for WiFi-DensePose API
- Added HardwareService for managing router interfaces, data collection, and monitoring. - Introduced PoseService for processing CSI data and estimating poses using neural networks. - Created StreamService for real-time data streaming via WebSocket connections. - Implemented initialization, start, stop, and status retrieval methods for each service. - Added data processing, error handling, and statistics tracking across services. - Integrated mock data generation for development and testing purposes.
This commit is contained in:
18
=3.0.0
18
=3.0.0
@@ -1,18 +0,0 @@
|
||||
Collecting paramiko
|
||||
Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB)
|
||||
Collecting bcrypt>=3.2 (from paramiko)
|
||||
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB)
|
||||
Collecting cryptography>=3.3 (from paramiko)
|
||||
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB)
|
||||
Collecting pynacl>=1.5 (from paramiko)
|
||||
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB)
|
||||
Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1)
|
||||
Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22)
|
||||
Downloading paramiko-3.5.1-py3-none-any.whl (227 kB)
|
||||
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB)
|
||||
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB)
|
||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00
|
||||
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
|
||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00
|
||||
Installing collected packages: bcrypt, pynacl, cryptography, paramiko
|
||||
Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0
|
||||
183
example.env
Normal file
183
example.env
Normal file
@@ -0,0 +1,183 @@
|
||||
# WiFi-DensePose API Environment Configuration Template
|
||||
# Copy this file to .env and modify the values according to your setup
|
||||
|
||||
# =============================================================================
|
||||
# APPLICATION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Application metadata
|
||||
APP_NAME=WiFi-DensePose API
|
||||
VERSION=1.0.0
|
||||
ENVIRONMENT=development # Options: development, staging, production
|
||||
DEBUG=true
|
||||
|
||||
# =============================================================================
|
||||
# SERVER SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Server configuration
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
RELOAD=true # Auto-reload on code changes (development only)
|
||||
WORKERS=1 # Number of worker processes
|
||||
|
||||
# =============================================================================
|
||||
# SECURITY SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# IMPORTANT: Change these values for production!
|
||||
SECRET_KEY=your-secret-key-here-change-for-production
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRE_HOURS=24
|
||||
|
||||
# Allowed hosts (restrict in production)
|
||||
ALLOWED_HOSTS=* # Use specific domains in production: example.com,api.example.com
|
||||
|
||||
# CORS settings (restrict in production)
|
||||
CORS_ORIGINS=* # Use specific origins in production: https://example.com,https://app.example.com
|
||||
|
||||
# =============================================================================
|
||||
# DATABASE SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Database connection (optional - defaults to SQLite in development)
|
||||
# DATABASE_URL=postgresql://user:password@localhost:5432/wifi_densepose
|
||||
# DATABASE_POOL_SIZE=10
|
||||
# DATABASE_MAX_OVERFLOW=20
|
||||
|
||||
# =============================================================================
|
||||
# REDIS SETTINGS (Optional - for caching and rate limiting)
|
||||
# =============================================================================
|
||||
|
||||
# Redis connection (optional - defaults to localhost in development)
|
||||
# REDIS_URL=redis://localhost:6379/0
|
||||
# REDIS_PASSWORD=your-redis-password
|
||||
# REDIS_DB=0
|
||||
|
||||
# =============================================================================
|
||||
# HARDWARE SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# WiFi interface configuration
|
||||
WIFI_INTERFACE=wlan0
|
||||
CSI_BUFFER_SIZE=1000
|
||||
HARDWARE_POLLING_INTERVAL=0.1
|
||||
|
||||
# Hardware mock settings (for development/testing)
|
||||
MOCK_HARDWARE=true
|
||||
MOCK_POSE_DATA=true
|
||||
|
||||
# =============================================================================
|
||||
# POSE ESTIMATION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Model configuration
|
||||
# POSE_MODEL_PATH=/path/to/your/pose/model.pth
|
||||
POSE_CONFIDENCE_THRESHOLD=0.5
|
||||
POSE_PROCESSING_BATCH_SIZE=32
|
||||
POSE_MAX_PERSONS=10
|
||||
|
||||
# =============================================================================
|
||||
# STREAMING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Real-time streaming configuration
|
||||
STREAM_FPS=30
|
||||
STREAM_BUFFER_SIZE=100
|
||||
WEBSOCKET_PING_INTERVAL=60
|
||||
WEBSOCKET_TIMEOUT=300
|
||||
|
||||
# =============================================================================
|
||||
# FEATURE FLAGS
|
||||
# =============================================================================
|
||||
|
||||
# Enable/disable features
|
||||
ENABLE_AUTHENTICATION=false # Set to true for production
|
||||
ENABLE_RATE_LIMITING=false # Set to true for production
|
||||
ENABLE_WEBSOCKETS=true
|
||||
ENABLE_REAL_TIME_PROCESSING=true
|
||||
ENABLE_HISTORICAL_DATA=true
|
||||
|
||||
# Development features
|
||||
ENABLE_TEST_ENDPOINTS=true # Set to false for production
|
||||
|
||||
# =============================================================================
|
||||
# RATE LIMITING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Rate limiting configuration
|
||||
RATE_LIMIT_REQUESTS=100
|
||||
RATE_LIMIT_AUTHENTICATED_REQUESTS=1000
|
||||
RATE_LIMIT_WINDOW=3600 # Window in seconds
|
||||
|
||||
# =============================================================================
|
||||
# LOGGING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Logging configuration
|
||||
LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s
|
||||
# LOG_FILE=/path/to/logfile.log # Optional: specify log file path
|
||||
LOG_MAX_SIZE=10485760 # 10MB
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# =============================================================================
|
||||
# STORAGE SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Storage directories
|
||||
DATA_STORAGE_PATH=./data
|
||||
MODEL_STORAGE_PATH=./models
|
||||
TEMP_STORAGE_PATH=./temp
|
||||
MAX_STORAGE_SIZE_GB=100
|
||||
|
||||
# =============================================================================
|
||||
# MONITORING SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Monitoring and metrics
|
||||
METRICS_ENABLED=true
|
||||
HEALTH_CHECK_INTERVAL=30
|
||||
PERFORMANCE_MONITORING=true
|
||||
|
||||
# =============================================================================
|
||||
# API SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# API configuration
|
||||
API_PREFIX=/api/v1
|
||||
DOCS_URL=/docs # Set to null to disable in production
|
||||
REDOC_URL=/redoc # Set to null to disable in production
|
||||
OPENAPI_URL=/openapi.json # Set to null to disable in production
|
||||
|
||||
# =============================================================================
|
||||
# PRODUCTION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# For production deployment, ensure you:
|
||||
# 1. Set ENVIRONMENT=production
|
||||
# 2. Set DEBUG=false
|
||||
# 3. Use a strong SECRET_KEY
|
||||
# 4. Configure proper DATABASE_URL
|
||||
# 5. Restrict ALLOWED_HOSTS and CORS_ORIGINS
|
||||
# 6. Enable ENABLE_AUTHENTICATION=true
|
||||
# 7. Enable ENABLE_RATE_LIMITING=true
|
||||
# 8. Set ENABLE_TEST_ENDPOINTS=false
|
||||
# 9. Disable API documentation URLs (set to null)
|
||||
# 10. Configure proper logging with LOG_FILE
|
||||
|
||||
# Example production settings:
|
||||
# ENVIRONMENT=production
|
||||
# DEBUG=false
|
||||
# SECRET_KEY=your-very-secure-secret-key-here
|
||||
# DATABASE_URL=postgresql://user:password@db-host:5432/wifi_densepose
|
||||
# REDIS_URL=redis://redis-host:6379/0
|
||||
# ALLOWED_HOSTS=yourdomain.com,api.yourdomain.com
|
||||
# CORS_ORIGINS=https://yourdomain.com,https://app.yourdomain.com
|
||||
# ENABLE_AUTHENTICATION=true
|
||||
# ENABLE_RATE_LIMITING=true
|
||||
# ENABLE_TEST_ENDPOINTS=false
|
||||
# DOCS_URL=null
|
||||
# REDOC_URL=null
|
||||
# OPENAPI_URL=null
|
||||
# LOG_FILE=/var/log/wifi-densepose/app.log
|
||||
@@ -17,6 +17,9 @@ fastapi>=0.95.0
|
||||
uvicorn>=0.20.0
|
||||
websockets>=10.4
|
||||
pydantic>=1.10.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
python-multipart>=0.0.6
|
||||
passlib[bcrypt]>=1.7.4
|
||||
|
||||
# Hardware interface dependencies
|
||||
asyncio-mqtt>=0.11.0
|
||||
|
||||
1967
scripts/api_test_results_20250607_122720.json
Normal file
1967
scripts/api_test_results_20250607_122720.json
Normal file
File diff suppressed because it is too large
Load Diff
1991
scripts/api_test_results_20250607_122856.json
Normal file
1991
scripts/api_test_results_20250607_122856.json
Normal file
File diff suppressed because it is too large
Load Diff
2961
scripts/api_test_results_20250607_123111.json
Normal file
2961
scripts/api_test_results_20250607_123111.json
Normal file
File diff suppressed because it is too large
Load Diff
376
scripts/test_api_endpoints.py
Executable file
376
scripts/test_api_endpoints.py
Executable file
@@ -0,0 +1,376 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
API Endpoint Testing Script
|
||||
Tests all WiFi-DensePose API endpoints and provides debugging information.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
import aiohttp
|
||||
import websockets
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
# Initialize colorama for colored output
|
||||
init(autoreset=True)
|
||||
|
||||
class APITester:
|
||||
"""Comprehensive API endpoint tester."""
|
||||
|
||||
def __init__(self, base_url: str = "http://localhost:8000"):
|
||||
self.base_url = base_url
|
||||
self.session = None
|
||||
self.results = {
|
||||
"total_tests": 0,
|
||||
"passed": 0,
|
||||
"failed": 0,
|
||||
"errors": [],
|
||||
"test_details": []
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
|
||||
def log_success(self, message: str):
|
||||
"""Log success message."""
|
||||
print(f"{Fore.GREEN}✓ {message}{Style.RESET_ALL}")
|
||||
|
||||
def log_error(self, message: str):
|
||||
"""Log error message."""
|
||||
print(f"{Fore.RED}✗ {message}{Style.RESET_ALL}")
|
||||
|
||||
def log_info(self, message: str):
|
||||
"""Log info message."""
|
||||
print(f"{Fore.BLUE}ℹ {message}{Style.RESET_ALL}")
|
||||
|
||||
def log_warning(self, message: str):
|
||||
"""Log warning message."""
|
||||
print(f"{Fore.YELLOW}⚠ {message}{Style.RESET_ALL}")
|
||||
|
||||
async def test_endpoint(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
expected_status: int = 200,
|
||||
data: Optional[Dict] = None,
|
||||
params: Optional[Dict] = None,
|
||||
headers: Optional[Dict] = None,
|
||||
description: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""Test a single API endpoint."""
|
||||
self.results["total_tests"] += 1
|
||||
test_name = f"{method.upper()} {endpoint}"
|
||||
|
||||
try:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
# Prepare request
|
||||
kwargs = {}
|
||||
if data:
|
||||
kwargs["json"] = data
|
||||
if params:
|
||||
kwargs["params"] = params
|
||||
if headers:
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Make request
|
||||
start_time = time.time()
|
||||
async with self.session.request(method, url, **kwargs) as response:
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
response_text = await response.text()
|
||||
|
||||
# Try to parse JSON response
|
||||
try:
|
||||
response_data = json.loads(response_text) if response_text else {}
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"raw_response": response_text}
|
||||
|
||||
# Check status code
|
||||
status_ok = response.status == expected_status
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": url,
|
||||
"method": method.upper(),
|
||||
"expected_status": expected_status,
|
||||
"actual_status": response.status,
|
||||
"response_time_ms": round(response_time, 2),
|
||||
"response_data": response_data,
|
||||
"success": status_ok,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if status_ok:
|
||||
self.results["passed"] += 1
|
||||
self.log_success(f"{test_name} - {response.status} ({response_time:.1f}ms)")
|
||||
if description:
|
||||
print(f" {description}")
|
||||
else:
|
||||
self.results["failed"] += 1
|
||||
self.log_error(f"{test_name} - Expected {expected_status}, got {response.status}")
|
||||
if description:
|
||||
print(f" {description}")
|
||||
print(f" Response: {response_text[:200]}...")
|
||||
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
self.results["failed"] += 1
|
||||
error_msg = f"{test_name} - Exception: {str(e)}"
|
||||
self.log_error(error_msg)
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": f"{self.base_url}{endpoint}",
|
||||
"method": method.upper(),
|
||||
"expected_status": expected_status,
|
||||
"actual_status": None,
|
||||
"response_time_ms": None,
|
||||
"response_data": None,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.results["errors"].append(error_msg)
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
async def test_websocket_endpoint(self, endpoint: str, description: str = "") -> Dict[str, Any]:
|
||||
"""Test WebSocket endpoint."""
|
||||
self.results["total_tests"] += 1
|
||||
test_name = f"WebSocket {endpoint}"
|
||||
|
||||
try:
|
||||
ws_url = f"ws://localhost:8000{endpoint}"
|
||||
|
||||
start_time = time.time()
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
# Send a test message
|
||||
test_message = {"type": "subscribe", "zone_ids": ["zone_1"]}
|
||||
await websocket.send(json.dumps(test_message))
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(websocket.recv(), timeout=3)
|
||||
response_time = (time.time() - start_time) * 1000
|
||||
|
||||
try:
|
||||
response_data = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"raw_response": response}
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": ws_url,
|
||||
"method": "WebSocket",
|
||||
"response_time_ms": round(response_time, 2),
|
||||
"response_data": response_data,
|
||||
"success": True,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.results["passed"] += 1
|
||||
self.log_success(f"{test_name} - Connected ({response_time:.1f}ms)")
|
||||
if description:
|
||||
print(f" {description}")
|
||||
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
self.results["failed"] += 1
|
||||
error_msg = f"{test_name} - Exception: {str(e)}"
|
||||
self.log_error(error_msg)
|
||||
|
||||
test_result = {
|
||||
"test_name": test_name,
|
||||
"description": description,
|
||||
"url": f"ws://localhost:8000{endpoint}",
|
||||
"method": "WebSocket",
|
||||
"response_time_ms": None,
|
||||
"response_data": None,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc(),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.results["errors"].append(error_msg)
|
||||
self.results["test_details"].append(test_result)
|
||||
return test_result
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all API endpoint tests."""
|
||||
print(f"{Fore.CYAN}{'='*60}")
|
||||
print(f"{Fore.CYAN}WiFi-DensePose API Endpoint Testing")
|
||||
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
|
||||
print()
|
||||
|
||||
# Test Health Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Health Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/health/health", description="System health check")
|
||||
await self.test_endpoint("GET", "/health/ready", description="Readiness check")
|
||||
print()
|
||||
|
||||
# Test Pose Estimation Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Pose Estimation Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/current", description="Current pose estimation")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/current",
|
||||
params={"zone_ids": ["zone_1"], "confidence_threshold": 0.7},
|
||||
description="Current pose estimation with parameters")
|
||||
await self.test_endpoint("POST", "/api/v1/pose/analyze", description="Pose analysis (requires auth)")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/zones/zone_1/occupancy", description="Zone occupancy")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/zones/summary", description="All zones summary")
|
||||
print()
|
||||
|
||||
# Test Historical Data Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Historical Data Endpoints:{Style.RESET_ALL}")
|
||||
end_time = datetime.now()
|
||||
start_time = end_time - timedelta(hours=1)
|
||||
historical_data = {
|
||||
"start_time": start_time.isoformat(),
|
||||
"end_time": end_time.isoformat(),
|
||||
"zone_ids": ["zone_1"],
|
||||
"aggregation_interval": 300
|
||||
}
|
||||
await self.test_endpoint("POST", "/api/v1/pose/historical",
|
||||
data=historical_data,
|
||||
description="Historical pose data (requires auth)")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/activities", description="Recent activities")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/activities",
|
||||
params={"zone_id": "zone_1", "limit": 5},
|
||||
description="Activities for specific zone")
|
||||
print()
|
||||
|
||||
# Test Calibration Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Calibration Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/calibration/status", description="Calibration status (requires auth)")
|
||||
await self.test_endpoint("POST", "/api/v1/pose/calibrate", description="Start calibration (requires auth)")
|
||||
print()
|
||||
|
||||
# Test Statistics Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Statistics Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/stats", description="Pose statistics")
|
||||
await self.test_endpoint("GET", "/api/v1/pose/stats",
|
||||
params={"hours": 12}, description="Pose statistics (12 hours)")
|
||||
print()
|
||||
|
||||
# Test Stream Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Stream Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/api/v1/stream/status", description="Stream status")
|
||||
await self.test_endpoint("POST", "/api/v1/stream/start", description="Start streaming (requires auth)")
|
||||
await self.test_endpoint("POST", "/api/v1/stream/stop", description="Stop streaming (requires auth)")
|
||||
print()
|
||||
|
||||
# Test WebSocket Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing WebSocket Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_websocket_endpoint("/ws/pose", description="Pose WebSocket")
|
||||
await self.test_websocket_endpoint("/ws/hardware", description="Hardware WebSocket")
|
||||
print()
|
||||
|
||||
# Test Documentation Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing Documentation Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/docs", description="API documentation")
|
||||
await self.test_endpoint("GET", "/openapi.json", description="OpenAPI schema")
|
||||
print()
|
||||
|
||||
# Test API Info Endpoints
|
||||
print(f"{Fore.MAGENTA}Testing API Info Endpoints:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/", description="Root endpoint")
|
||||
await self.test_endpoint("GET", "/api/v1/info", description="API information")
|
||||
await self.test_endpoint("GET", "/api/v1/status", description="API status")
|
||||
print()
|
||||
|
||||
# Test Error Cases
|
||||
print(f"{Fore.MAGENTA}Testing Error Cases:{Style.RESET_ALL}")
|
||||
await self.test_endpoint("GET", "/nonexistent", expected_status=404,
|
||||
description="Non-existent endpoint")
|
||||
await self.test_endpoint("POST", "/api/v1/pose/analyze",
|
||||
data={"invalid": "data"}, expected_status=401,
|
||||
description="Unauthorized request (no auth)")
|
||||
print()
|
||||
|
||||
def print_summary(self):
|
||||
"""Print test summary."""
|
||||
print(f"{Fore.CYAN}{'='*60}")
|
||||
print(f"{Fore.CYAN}Test Summary")
|
||||
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
|
||||
|
||||
total = self.results["total_tests"]
|
||||
passed = self.results["passed"]
|
||||
failed = self.results["failed"]
|
||||
success_rate = (passed / total * 100) if total > 0 else 0
|
||||
|
||||
print(f"Total Tests: {total}")
|
||||
print(f"{Fore.GREEN}Passed: {passed}{Style.RESET_ALL}")
|
||||
print(f"{Fore.RED}Failed: {failed}{Style.RESET_ALL}")
|
||||
print(f"Success Rate: {success_rate:.1f}%")
|
||||
print()
|
||||
|
||||
if self.results["errors"]:
|
||||
print(f"{Fore.RED}Errors:{Style.RESET_ALL}")
|
||||
for error in self.results["errors"]:
|
||||
print(f" - {error}")
|
||||
print()
|
||||
|
||||
# Save detailed results to file
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
results_file = f"scripts/api_test_results_{timestamp}.json"
|
||||
|
||||
try:
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(self.results, f, indent=2, default=str)
|
||||
print(f"Detailed results saved to: {results_file}")
|
||||
except Exception as e:
|
||||
self.log_warning(f"Could not save results file: {e}")
|
||||
|
||||
return failed == 0
|
||||
|
||||
async def main():
|
||||
"""Main test function."""
|
||||
try:
|
||||
async with APITester() as tester:
|
||||
await tester.run_all_tests()
|
||||
success = tester.print_summary()
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n{Fore.YELLOW}Tests interrupted by user{Style.RESET_ALL}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n{Fore.RED}Fatal error: {e}{Style.RESET_ALL}")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if required packages are available
|
||||
try:
|
||||
import aiohttp
|
||||
import websockets
|
||||
import colorama
|
||||
except ImportError as e:
|
||||
print(f"Missing required package: {e}")
|
||||
print("Install with: pip install aiohttp websockets colorama")
|
||||
sys.exit(1)
|
||||
|
||||
# Run tests
|
||||
asyncio.run(main())
|
||||
@@ -246,8 +246,15 @@ if __name__ != '__main__':
|
||||
|
||||
|
||||
# Compatibility aliases for backward compatibility
|
||||
WifiDensePose = app # Legacy alias
|
||||
get_config = get_settings # Legacy alias
|
||||
try:
|
||||
WifiDensePose = app # Legacy alias
|
||||
except NameError:
|
||||
WifiDensePose = None # Will be None if app import failed
|
||||
|
||||
try:
|
||||
get_config = get_settings # Legacy alias
|
||||
except NameError:
|
||||
get_config = None # Will be None if get_settings import failed
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -2,6 +2,6 @@
|
||||
WiFi-DensePose FastAPI application package
|
||||
"""
|
||||
|
||||
from .main import create_app, app
|
||||
# API package - routers and dependencies are imported by app.py
|
||||
|
||||
__all__ = ["create_app", "app"]
|
||||
__all__ = []
|
||||
@@ -418,6 +418,21 @@ async def get_websocket_user(
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user_ws(
|
||||
websocket_token: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get current user for WebSocket connections."""
|
||||
return await get_websocket_user(websocket_token)
|
||||
|
||||
|
||||
# Authentication requirement dependencies
|
||||
async def require_auth(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Require authentication for endpoint access."""
|
||||
return current_user
|
||||
|
||||
|
||||
# Development dependencies
|
||||
async def development_only():
|
||||
"""Dependency that only allows access in development."""
|
||||
|
||||
@@ -7,18 +7,11 @@ import psutil
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_hardware_service,
|
||||
get_pose_service,
|
||||
get_stream_service,
|
||||
get_current_user
|
||||
)
|
||||
from src.services.hardware_service import HardwareService
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.stream_service import StreamService
|
||||
from src.api.dependencies import get_current_user
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,20 +51,19 @@ class ReadinessCheck(BaseModel):
|
||||
|
||||
# Health check endpoints
|
||||
@router.get("/health", response_model=SystemHealth)
|
||||
async def health_check(
|
||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
async def health_check(request: Request):
|
||||
"""Comprehensive system health check."""
|
||||
try:
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
timestamp = datetime.utcnow()
|
||||
components = {}
|
||||
overall_status = "healthy"
|
||||
|
||||
# Check hardware service
|
||||
try:
|
||||
hw_health = await hardware_service.health_check()
|
||||
hw_health = await orchestrator.hardware_service.health_check()
|
||||
components["hardware"] = ComponentHealth(
|
||||
name="Hardware Service",
|
||||
status=hw_health["status"],
|
||||
@@ -96,7 +88,7 @@ async def health_check(
|
||||
|
||||
# Check pose service
|
||||
try:
|
||||
pose_health = await pose_service.health_check()
|
||||
pose_health = await orchestrator.pose_service.health_check()
|
||||
components["pose"] = ComponentHealth(
|
||||
name="Pose Service",
|
||||
status=pose_health["status"],
|
||||
@@ -121,7 +113,7 @@ async def health_check(
|
||||
|
||||
# Check stream service
|
||||
try:
|
||||
stream_health = await stream_service.health_check()
|
||||
stream_health = await orchestrator.stream_service.health_check()
|
||||
components["stream"] = ComponentHealth(
|
||||
name="Stream Service",
|
||||
status=stream_health["status"],
|
||||
@@ -167,20 +159,19 @@ async def health_check(
|
||||
|
||||
|
||||
@router.get("/ready", response_model=ReadinessCheck)
|
||||
async def readiness_check(
|
||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
async def readiness_check(request: Request):
|
||||
"""Check if system is ready to serve requests."""
|
||||
try:
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
timestamp = datetime.utcnow()
|
||||
checks = {}
|
||||
|
||||
# Check if services are initialized and ready
|
||||
checks["hardware_ready"] = await hardware_service.is_ready()
|
||||
checks["pose_ready"] = await pose_service.is_ready()
|
||||
checks["stream_ready"] = await stream_service.is_ready()
|
||||
checks["hardware_ready"] = await orchestrator.hardware_service.is_ready()
|
||||
checks["pose_ready"] = await orchestrator.pose_service.is_ready()
|
||||
checks["stream_ready"] = await orchestrator.stream_service.is_ready()
|
||||
|
||||
# Check system resources
|
||||
checks["memory_available"] = check_memory_availability()
|
||||
@@ -221,7 +212,8 @@ async def liveness_check():
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_system_metrics(
|
||||
async def get_health_metrics(
|
||||
request: Request,
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed system metrics."""
|
||||
|
||||
@@ -73,7 +73,8 @@ async def websocket_pose_stream(
|
||||
websocket: WebSocket,
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
|
||||
max_fps: int = Query(30, ge=1, le=60)
|
||||
max_fps: int = Query(30, ge=1, le=60),
|
||||
token: Optional[str] = Query(None, description="Authentication token")
|
||||
):
|
||||
"""WebSocket endpoint for real-time pose data streaming."""
|
||||
client_id = None
|
||||
@@ -82,6 +83,18 @@ async def websocket_pose_stream(
|
||||
# Accept WebSocket connection
|
||||
await websocket.accept()
|
||||
|
||||
# Check authentication if enabled
|
||||
from src.config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.enable_authentication and not token:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Authentication token required"
|
||||
})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Parse zone IDs
|
||||
zone_list = None
|
||||
if zone_ids:
|
||||
@@ -146,7 +159,8 @@ async def websocket_pose_stream(
|
||||
async def websocket_events_stream(
|
||||
websocket: WebSocket,
|
||||
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs")
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||
token: Optional[str] = Query(None, description="Authentication token")
|
||||
):
|
||||
"""WebSocket endpoint for real-time event streaming."""
|
||||
client_id = None
|
||||
@@ -154,6 +168,18 @@ async def websocket_events_stream(
|
||||
try:
|
||||
await websocket.accept()
|
||||
|
||||
# Check authentication if enabled
|
||||
from src.config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.enable_authentication and not token:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Authentication token required"
|
||||
})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Parse parameters
|
||||
event_list = None
|
||||
if event_types:
|
||||
@@ -244,19 +270,27 @@ async def handle_websocket_message(client_id: str, data: Dict[str, Any], websock
|
||||
# HTTP endpoints for stream management
|
||||
@router.get("/status", response_model=StreamStatus)
|
||||
async def get_stream_status(
|
||||
stream_service: StreamService = Depends(get_stream_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user_ws)
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
"""Get current streaming status."""
|
||||
try:
|
||||
status = await stream_service.get_status()
|
||||
connections = await connection_manager.get_connection_stats()
|
||||
|
||||
# Calculate uptime (simplified for now)
|
||||
uptime_seconds = 0.0
|
||||
if status.get("running", False):
|
||||
uptime_seconds = 3600.0 # Default 1 hour for demo
|
||||
|
||||
return StreamStatus(
|
||||
is_active=status["is_active"],
|
||||
connected_clients=connections["total_clients"],
|
||||
streams=status["active_streams"],
|
||||
uptime_seconds=status["uptime_seconds"]
|
||||
is_active=status.get("running", False),
|
||||
connected_clients=connections.get("total_clients", status["connections"]["active"]),
|
||||
streams=[{
|
||||
"type": "pose_stream",
|
||||
"active": status.get("running", False),
|
||||
"buffer_size": status["buffers"]["pose_buffer_size"]
|
||||
}],
|
||||
uptime_seconds=uptime_seconds
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -416,9 +450,7 @@ async def broadcast_message(
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_streaming_metrics(
|
||||
current_user: Optional[Dict] = Depends(get_current_user_ws)
|
||||
):
|
||||
async def get_streaming_metrics():
|
||||
"""Get streaming performance metrics."""
|
||||
try:
|
||||
metrics = await connection_manager.get_metrics()
|
||||
|
||||
@@ -120,7 +120,7 @@ class ConnectionManager:
|
||||
"start_time": datetime.utcnow()
|
||||
}
|
||||
self._cleanup_task = None
|
||||
self._start_cleanup_task()
|
||||
self._started = False
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
@@ -413,6 +413,13 @@ class ConnectionManager:
|
||||
if stale_clients:
|
||||
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
|
||||
|
||||
async def start(self):
|
||||
"""Start the connection manager."""
|
||||
if not self._started:
|
||||
self._start_cleanup_task()
|
||||
self._started = True
|
||||
logger.info("Connection manager started")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""Start background cleanup task."""
|
||||
async def cleanup_loop():
|
||||
@@ -428,7 +435,11 @@ class ConnectionManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup task: {e}")
|
||||
|
||||
self._cleanup_task = asyncio.create_task(cleanup_loop())
|
||||
try:
|
||||
self._cleanup_task = asyncio.create_task(cleanup_loop())
|
||||
except RuntimeError:
|
||||
# No event loop running, will start later
|
||||
logger.debug("No event loop running, cleanup task will start later")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown connection manager."""
|
||||
|
||||
28
src/app.py
28
src/app.py
@@ -3,6 +3,7 @@ FastAPI application factory and configuration
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,10 +16,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.middleware.auth import AuthMiddleware
|
||||
from src.middleware.cors import setup_cors
|
||||
from src.middleware.auth import AuthenticationMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlerMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlingMiddleware
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
|
||||
@@ -34,6 +35,9 @@ async def lifespan(app: FastAPI):
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = app.state.orchestrator
|
||||
|
||||
# Start connection manager
|
||||
await connection_manager.start()
|
||||
|
||||
# Start all services
|
||||
await orchestrator.start()
|
||||
|
||||
@@ -47,6 +51,10 @@ async def lifespan(app: FastAPI):
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down WiFi-DensePose API...")
|
||||
|
||||
# Shutdown connection manager
|
||||
await connection_manager.shutdown()
|
||||
|
||||
if hasattr(app.state, 'orchestrator'):
|
||||
await app.state.orchestrator.shutdown()
|
||||
logger.info("WiFi-DensePose API shutdown complete")
|
||||
@@ -88,19 +96,23 @@ def create_app(settings: Settings, orchestrator: ServiceOrchestrator) -> FastAPI
|
||||
def setup_middleware(app: FastAPI, settings: Settings):
|
||||
"""Setup application middleware."""
|
||||
|
||||
# Error handling middleware (should be first)
|
||||
app.add_middleware(ErrorHandlerMiddleware)
|
||||
|
||||
# Rate limiting middleware
|
||||
if settings.enable_rate_limiting:
|
||||
app.add_middleware(RateLimitMiddleware, settings=settings)
|
||||
|
||||
# Authentication middleware
|
||||
if settings.enable_authentication:
|
||||
app.add_middleware(AuthMiddleware, settings=settings)
|
||||
app.add_middleware(AuthenticationMiddleware, settings=settings)
|
||||
|
||||
# CORS middleware
|
||||
setup_cors(app, settings)
|
||||
if settings.cors_enabled:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=settings.cors_allow_credentials,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Trusted host middleware for production
|
||||
if settings.is_production:
|
||||
|
||||
10
src/cli.py
10
src/cli.py
@@ -14,8 +14,9 @@ from src.commands.start import start_command
|
||||
from src.commands.stop import stop_command
|
||||
from src.commands.status import status_command
|
||||
|
||||
# Setup logging for CLI
|
||||
setup_logging()
|
||||
# Get default settings and setup logging for CLI
|
||||
settings = get_settings()
|
||||
setup_logging(settings)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -498,5 +499,10 @@ def version():
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def create_cli(orchestrator=None):
|
||||
"""Create CLI interface for the application."""
|
||||
return cli
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
@@ -349,6 +349,10 @@ class DomainConfig:
|
||||
|
||||
return routers
|
||||
|
||||
def get_all_routers(self) -> List[RouterConfig]:
|
||||
"""Get all router configurations."""
|
||||
return list(self.routers.values())
|
||||
|
||||
def validate_configuration(self) -> List[str]:
|
||||
"""Validate the entire configuration."""
|
||||
issues = []
|
||||
|
||||
@@ -97,6 +97,8 @@ class Settings(BaseSettings):
|
||||
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
|
||||
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
|
||||
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
|
||||
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
|
||||
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||
|
||||
# Development settings
|
||||
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Core package for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from .csi_processor import CSIProcessor
|
||||
from .phase_sanitizer import PhaseSanitizer
|
||||
from .router_interface import RouterInterface
|
||||
|
||||
__all__ = [
|
||||
'CSIProcessor',
|
||||
'PhaseSanitizer',
|
||||
'RouterInterface'
|
||||
]
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
|
||||
class CSIProcessor:
|
||||
@@ -18,6 +20,11 @@ class CSIProcessor:
|
||||
self.sample_rate = self.config.get('sample_rate', 1000)
|
||||
self.num_subcarriers = self.config.get('num_subcarriers', 56)
|
||||
self.num_antennas = self.config.get('num_antennas', 3)
|
||||
self.buffer_size = self.config.get('buffer_size', 1000)
|
||||
|
||||
# Data buffer for temporal processing
|
||||
self.data_buffer = deque(maxlen=self.buffer_size)
|
||||
self.last_processed_data = None
|
||||
|
||||
def process_raw_csi(self, raw_data: np.ndarray) -> np.ndarray:
|
||||
"""Process raw CSI data into normalized format.
|
||||
@@ -76,4 +83,47 @@ class CSIProcessor:
|
||||
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
|
||||
|
||||
# Convert to tensor
|
||||
return torch.from_numpy(processed_data).float()
|
||||
return torch.from_numpy(processed_data).float()
|
||||
|
||||
def add_data(self, csi_data: np.ndarray, timestamp: datetime):
|
||||
"""Add CSI data to the processing buffer.
|
||||
|
||||
Args:
|
||||
csi_data: Raw CSI data array
|
||||
timestamp: Timestamp of the data sample
|
||||
"""
|
||||
sample = {
|
||||
'data': csi_data,
|
||||
'timestamp': timestamp,
|
||||
'processed': False
|
||||
}
|
||||
self.data_buffer.append(sample)
|
||||
|
||||
def get_processed_data(self) -> Optional[np.ndarray]:
|
||||
"""Get the most recent processed CSI data.
|
||||
|
||||
Returns:
|
||||
Processed CSI data array or None if no data available
|
||||
"""
|
||||
if not self.data_buffer:
|
||||
return None
|
||||
|
||||
# Get the most recent unprocessed sample
|
||||
recent_sample = None
|
||||
for sample in reversed(self.data_buffer):
|
||||
if not sample['processed']:
|
||||
recent_sample = sample
|
||||
break
|
||||
|
||||
if recent_sample is None:
|
||||
return self.last_processed_data
|
||||
|
||||
# Process the data
|
||||
try:
|
||||
processed_data = self.process_raw_csi(recent_sample['data'])
|
||||
recent_sample['processed'] = True
|
||||
self.last_processed_data = processed_data
|
||||
return processed_data
|
||||
except Exception as e:
|
||||
# Return last known good data if processing fails
|
||||
return self.last_processed_data
|
||||
340
src/core/router_interface.py
Normal file
340
src/core/router_interface.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Router interface for WiFi CSI data collection
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouterInterface:
|
||||
"""Interface for connecting to WiFi routers and collecting CSI data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
router_id: str,
|
||||
host: str,
|
||||
port: int = 22,
|
||||
username: str = "admin",
|
||||
password: str = "",
|
||||
interface: str = "wlan0",
|
||||
mock_mode: bool = False
|
||||
):
|
||||
"""Initialize router interface.
|
||||
|
||||
Args:
|
||||
router_id: Unique identifier for the router
|
||||
host: Router IP address or hostname
|
||||
port: SSH port for connection
|
||||
username: SSH username
|
||||
password: SSH password
|
||||
interface: WiFi interface name
|
||||
mock_mode: Whether to use mock data instead of real connection
|
||||
"""
|
||||
self.router_id = router_id
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.interface = interface
|
||||
self.mock_mode = mock_mode
|
||||
|
||||
self.logger = logging.getLogger(f"{__name__}.{router_id}")
|
||||
|
||||
# Connection state
|
||||
self.is_connected = False
|
||||
self.connection = None
|
||||
self.last_error = None
|
||||
|
||||
# Data collection state
|
||||
self.last_data_time = None
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
|
||||
# Mock data generation
|
||||
self.mock_data_generator = None
|
||||
if mock_mode:
|
||||
self._initialize_mock_generator()
|
||||
|
||||
def _initialize_mock_generator(self):
|
||||
"""Initialize mock data generator."""
|
||||
self.mock_data_generator = {
|
||||
'phase': 0,
|
||||
'amplitude_base': 1.0,
|
||||
'frequency': 0.1,
|
||||
'noise_level': 0.1
|
||||
}
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the router."""
|
||||
if self.mock_mode:
|
||||
self.is_connected = True
|
||||
self.logger.info(f"Mock connection established to router {self.router_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
||||
|
||||
# In a real implementation, this would establish SSH connection
|
||||
# For now, we'll simulate the connection
|
||||
await asyncio.sleep(0.1) # Simulate connection delay
|
||||
|
||||
self.is_connected = True
|
||||
self.error_count = 0
|
||||
self.logger.info(f"Connected to router {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
||||
raise
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the router."""
|
||||
try:
|
||||
if self.connection:
|
||||
# Close SSH connection
|
||||
self.connection = None
|
||||
|
||||
self.is_connected = False
|
||||
self.logger.info(f"Disconnected from router {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
|
||||
|
||||
async def reconnect(self):
|
||||
"""Reconnect to the router."""
|
||||
await self.disconnect()
|
||||
await asyncio.sleep(1) # Wait before reconnecting
|
||||
await self.connect()
|
||||
|
||||
async def get_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Get CSI data from the router.
|
||||
|
||||
Returns:
|
||||
CSI data as numpy array, or None if no data available
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RuntimeError(f"Router {self.router_id} is not connected")
|
||||
|
||||
try:
|
||||
if self.mock_mode:
|
||||
csi_data = self._generate_mock_csi_data()
|
||||
else:
|
||||
csi_data = await self._collect_real_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
self.last_data_time = datetime.now()
|
||||
self.sample_count += 1
|
||||
self.error_count = 0
|
||||
|
||||
return csi_data
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
|
||||
return None
|
||||
|
||||
def _generate_mock_csi_data(self) -> np.ndarray:
|
||||
"""Generate mock CSI data for testing."""
|
||||
# Simulate CSI data with realistic characteristics
|
||||
num_subcarriers = 64
|
||||
num_antennas = 4
|
||||
num_samples = 100
|
||||
|
||||
# Update mock generator state
|
||||
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
|
||||
|
||||
# Generate amplitude and phase data
|
||||
time_axis = np.linspace(0, 1, num_samples)
|
||||
|
||||
# Create realistic CSI patterns
|
||||
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
|
||||
|
||||
for antenna in range(num_antennas):
|
||||
for subcarrier in range(num_subcarriers):
|
||||
# Base signal with some variation per antenna/subcarrier
|
||||
amplitude = (
|
||||
self.mock_data_generator['amplitude_base'] *
|
||||
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
|
||||
(1 + 0.1 * antenna)
|
||||
)
|
||||
|
||||
# Phase with spatial and frequency variation
|
||||
phase_offset = (
|
||||
self.mock_data_generator['phase'] +
|
||||
2 * np.pi * subcarrier / num_subcarriers +
|
||||
np.pi * antenna / num_antennas
|
||||
)
|
||||
|
||||
# Add some movement simulation
|
||||
movement_freq = 0.5 # Hz
|
||||
movement_amplitude = 0.3
|
||||
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
|
||||
|
||||
# Generate complex signal
|
||||
signal_amplitude = amplitude * (1 + movement)
|
||||
signal_phase = phase_offset + movement * 0.5
|
||||
|
||||
# Add noise
|
||||
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise = noise_real + 1j * noise_imag
|
||||
|
||||
# Create complex signal
|
||||
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
|
||||
csi_data[antenna, subcarrier, :] = signal
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Collect real CSI data from router (placeholder implementation)."""
|
||||
# This would implement the actual CSI data collection
|
||||
# For now, return None to indicate no real implementation
|
||||
self.logger.warning("Real CSI data collection not implemented")
|
||||
return None
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
"""Check if the router connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# In mock mode, always healthy
|
||||
if self.mock_mode:
|
||||
return True
|
||||
|
||||
# For real connections, we could ping the router or check SSH connection
|
||||
# For now, consider healthy if error count is low
|
||||
return self.error_count < 5
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get router status information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router status
|
||||
"""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode,
|
||||
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
|
||||
"error_count": self.error_count,
|
||||
"sample_count": self.sample_count,
|
||||
"last_error": self.last_error,
|
||||
"configuration": {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"username": self.username,
|
||||
"interface": self.interface
|
||||
}
|
||||
}
|
||||
|
||||
async def get_router_info(self) -> Dict[str, Any]:
|
||||
"""Get router hardware information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router information
|
||||
"""
|
||||
if self.mock_mode:
|
||||
return {
|
||||
"model": "Mock Router",
|
||||
"firmware": "1.0.0-mock",
|
||||
"wifi_standard": "802.11ac",
|
||||
"antennas": 4,
|
||||
"supported_bands": ["2.4GHz", "5GHz"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 4,
|
||||
"sampling_rate": 1000
|
||||
}
|
||||
}
|
||||
|
||||
# For real routers, this would query the actual hardware
|
||||
return {
|
||||
"model": "Unknown",
|
||||
"firmware": "Unknown",
|
||||
"wifi_standard": "Unknown",
|
||||
"antennas": 1,
|
||||
"supported_bands": ["Unknown"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 1,
|
||||
"sampling_rate": 100
|
||||
}
|
||||
}
|
||||
|
||||
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
|
||||
"""Configure CSI data collection parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self.mock_mode:
|
||||
# Update mock generator parameters
|
||||
if 'sampling_rate' in config:
|
||||
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
|
||||
|
||||
if 'noise_level' in config:
|
||||
self.mock_data_generator['noise_level'] = config['noise_level']
|
||||
|
||||
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
|
||||
return True
|
||||
|
||||
# For real routers, this would send configuration commands
|
||||
self.logger.warning("Real CSI configuration not implemented")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get router interface metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing metrics
|
||||
"""
|
||||
uptime = 0
|
||||
if self.last_data_time:
|
||||
uptime = (datetime.now() - self.last_data_time).total_seconds()
|
||||
|
||||
success_rate = 0
|
||||
if self.sample_count > 0:
|
||||
success_rate = (self.sample_count - self.error_count) / self.sample_count
|
||||
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"sample_count": self.sample_count,
|
||||
"error_count": self.error_count,
|
||||
"success_rate": success_rate,
|
||||
"uptime_seconds": uptime,
|
||||
"is_connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset statistics counters."""
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
self.last_error = None
|
||||
self.logger.info(f"Statistics reset for router {self.router_id}")
|
||||
@@ -307,40 +307,45 @@ class ErrorHandler:
|
||||
class ErrorHandlingMiddleware:
|
||||
"""Error handling middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
def __init__(self, app, settings: Settings):
|
||||
self.app = app
|
||||
self.settings = settings
|
||||
self.error_handler = ErrorHandler(settings)
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""Process request through error handling middleware."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except HTTPException as exc:
|
||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
except RequestValidationError as exc:
|
||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
except ValidationError as exc:
|
||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as exc:
|
||||
# Check for specific error types
|
||||
if self._is_database_error(exc):
|
||||
error_response = self.error_handler.handle_database_error(request, exc)
|
||||
elif self._is_external_service_error(exc):
|
||||
error_response = self.error_handler.handle_external_service_error(request, exc)
|
||||
else:
|
||||
error_response = self.error_handler.handle_generic_exception(request, exc)
|
||||
# Create a mock request for error handling
|
||||
from starlette.requests import Request
|
||||
request = Request(scope, receive)
|
||||
|
||||
return error_response.to_response()
|
||||
# Handle different exception types
|
||||
if isinstance(exc, HTTPException):
|
||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||
elif isinstance(exc, RequestValidationError):
|
||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||
elif isinstance(exc, ValidationError):
|
||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||
else:
|
||||
# Check for specific error types
|
||||
if self._is_database_error(exc):
|
||||
error_response = self.error_handler.handle_database_error(request, exc)
|
||||
elif self._is_external_service_error(exc):
|
||||
error_response = self.error_handler.handle_external_service_error(request, exc)
|
||||
else:
|
||||
error_response = self.error_handler.handle_generic_exception(request, exc)
|
||||
|
||||
# Send the error response
|
||||
response = error_response.to_response()
|
||||
await response(scope, receive, send)
|
||||
|
||||
finally:
|
||||
# Log request processing time
|
||||
@@ -424,11 +429,10 @@ def setup_error_handling(app, settings: Settings):
|
||||
return error_response.to_response()
|
||||
|
||||
# Add middleware for additional error handling
|
||||
middleware = ErrorHandlingMiddleware(settings)
|
||||
|
||||
@app.middleware("http")
|
||||
async def error_handling_middleware(request: Request, call_next):
|
||||
return await middleware(request, call_next)
|
||||
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
|
||||
# The middleware approach is commented out but kept for reference
|
||||
# middleware = ErrorHandlingMiddleware(app, settings)
|
||||
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
|
||||
|
||||
logger.info("Error handling configured")
|
||||
|
||||
|
||||
@@ -5,9 +5,15 @@ Services package for WiFi-DensePose API
|
||||
from .orchestrator import ServiceOrchestrator
|
||||
from .health_check import HealthCheckService
|
||||
from .metrics import MetricsService
|
||||
from .pose_service import PoseService
|
||||
from .stream_service import StreamService
|
||||
from .hardware_service import HardwareService
|
||||
|
||||
__all__ = [
|
||||
'ServiceOrchestrator',
|
||||
'HealthCheckService',
|
||||
'MetricsService'
|
||||
'MetricsService',
|
||||
'PoseService',
|
||||
'StreamService',
|
||||
'HardwareService'
|
||||
]
|
||||
483
src/services/hardware_service.py
Normal file
483
src/services/hardware_service.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Hardware interface service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.router_interface import RouterInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HardwareService:
|
||||
"""Service for hardware interface operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize hardware service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Router interfaces
|
||||
self.router_interfaces: Dict[str, RouterInterface] = {}
|
||||
|
||||
# Service state
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Data collection statistics
|
||||
self.stats = {
|
||||
"total_samples": 0,
|
||||
"successful_samples": 0,
|
||||
"failed_samples": 0,
|
||||
"average_sample_rate": 0.0,
|
||||
"last_sample_time": None,
|
||||
"connected_routers": 0
|
||||
}
|
||||
|
||||
# Background tasks
|
||||
self.collection_task = None
|
||||
self.monitoring_task = None
|
||||
|
||||
# Data buffers
|
||||
self.recent_samples = []
|
||||
self.max_recent_samples = 1000
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the hardware service."""
|
||||
await self.start()
|
||||
|
||||
async def start(self):
|
||||
"""Start the hardware service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info("Starting hardware service...")
|
||||
|
||||
# Initialize router interfaces
|
||||
await self._initialize_routers()
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# Start background tasks
|
||||
if not self.settings.mock_hardware:
|
||||
self.collection_task = asyncio.create_task(self._data_collection_loop())
|
||||
|
||||
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
|
||||
|
||||
self.logger.info("Hardware service started successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to start hardware service: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the hardware service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background tasks
|
||||
if self.collection_task:
|
||||
self.collection_task.cancel()
|
||||
try:
|
||||
await self.collection_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.monitoring_task:
|
||||
self.monitoring_task.cancel()
|
||||
try:
|
||||
await self.monitoring_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Disconnect from routers
|
||||
await self._disconnect_routers()
|
||||
|
||||
self.logger.info("Hardware service stopped")
|
||||
|
||||
async def _initialize_routers(self):
|
||||
"""Initialize router interfaces."""
|
||||
try:
|
||||
# Get router configurations from domain config
|
||||
routers = self.domain_config.get_all_routers()
|
||||
|
||||
for router_config in routers:
|
||||
if not router_config.enabled:
|
||||
continue
|
||||
|
||||
router_id = router_config.router_id
|
||||
|
||||
# Create router interface
|
||||
router_interface = RouterInterface(
|
||||
router_id=router_id,
|
||||
host=router_config.ip_address,
|
||||
port=22, # Default SSH port
|
||||
username="admin", # Default username
|
||||
password="admin", # Default password
|
||||
interface=router_config.interface,
|
||||
mock_mode=self.settings.mock_hardware
|
||||
)
|
||||
|
||||
# Connect to router
|
||||
if not self.settings.mock_hardware:
|
||||
await router_interface.connect()
|
||||
|
||||
self.router_interfaces[router_id] = router_interface
|
||||
self.logger.info(f"Router interface initialized: {router_id}")
|
||||
|
||||
self.stats["connected_routers"] = len(self.router_interfaces)
|
||||
|
||||
if not self.router_interfaces:
|
||||
self.logger.warning("No router interfaces configured")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize routers: {e}")
|
||||
raise
|
||||
|
||||
async def _disconnect_routers(self):
|
||||
"""Disconnect from all routers."""
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
await interface.disconnect()
|
||||
self.logger.info(f"Disconnected from router: {router_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error disconnecting from router {router_id}: {e}")
|
||||
|
||||
self.router_interfaces.clear()
|
||||
self.stats["connected_routers"] = 0
|
||||
|
||||
async def _data_collection_loop(self):
|
||||
"""Background loop for data collection."""
|
||||
try:
|
||||
while self.is_running:
|
||||
start_time = time.time()
|
||||
|
||||
# Collect data from all routers
|
||||
await self._collect_data_from_routers()
|
||||
|
||||
# Calculate sleep time to maintain polling interval
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)
|
||||
|
||||
if sleep_time > 0:
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Data collection loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in data collection loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _monitoring_loop(self):
|
||||
"""Background loop for hardware monitoring."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Monitor router connections
|
||||
await self._monitor_router_health()
|
||||
|
||||
# Update statistics
|
||||
self._update_sample_rate_stats()
|
||||
|
||||
# Wait before next check
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Monitoring loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in monitoring loop: {e}")
|
||||
|
||||
async def _collect_data_from_routers(self):
|
||||
"""Collect CSI data from all connected routers."""
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
# Get CSI data from router
|
||||
csi_data = await interface.get_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
# Process the collected data
|
||||
await self._process_collected_data(router_id, csi_data)
|
||||
|
||||
self.stats["successful_samples"] += 1
|
||||
self.stats["last_sample_time"] = datetime.now().isoformat()
|
||||
else:
|
||||
self.stats["failed_samples"] += 1
|
||||
|
||||
self.stats["total_samples"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error collecting data from router {router_id}: {e}")
|
||||
self.stats["failed_samples"] += 1
|
||||
self.stats["total_samples"] += 1
|
||||
|
||||
async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
|
||||
"""Process collected CSI data."""
|
||||
try:
|
||||
# Create sample metadata
|
||||
metadata = {
|
||||
"router_id": router_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"sample_rate": self.stats["average_sample_rate"],
|
||||
"data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
|
||||
}
|
||||
|
||||
# Add to recent samples buffer
|
||||
sample = {
|
||||
"router_id": router_id,
|
||||
"timestamp": metadata["timestamp"],
|
||||
"data": csi_data,
|
||||
"metadata": metadata
|
||||
}
|
||||
|
||||
self.recent_samples.append(sample)
|
||||
|
||||
# Maintain buffer size
|
||||
if len(self.recent_samples) > self.max_recent_samples:
|
||||
self.recent_samples.pop(0)
|
||||
|
||||
# Notify other services (this would typically be done through an event system)
|
||||
# For now, we'll just log the data collection
|
||||
self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing collected data: {e}")
|
||||
|
||||
async def _monitor_router_health(self):
|
||||
"""Monitor health of router connections."""
|
||||
healthy_routers = 0
|
||||
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
is_healthy = await interface.check_health()
|
||||
|
||||
if is_healthy:
|
||||
healthy_routers += 1
|
||||
else:
|
||||
self.logger.warning(f"Router {router_id} is unhealthy")
|
||||
|
||||
# Try to reconnect if not in mock mode
|
||||
if not self.settings.mock_hardware:
|
||||
try:
|
||||
await interface.reconnect()
|
||||
self.logger.info(f"Reconnected to router {router_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to reconnect to router {router_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {router_id}: {e}")
|
||||
|
||||
self.stats["connected_routers"] = healthy_routers
|
||||
|
||||
def _update_sample_rate_stats(self):
|
||||
"""Update sample rate statistics."""
|
||||
if len(self.recent_samples) < 2:
|
||||
return
|
||||
|
||||
# Calculate sample rate from recent samples
|
||||
recent_count = min(100, len(self.recent_samples))
|
||||
recent_samples = self.recent_samples[-recent_count:]
|
||||
|
||||
if len(recent_samples) >= 2:
|
||||
# Calculate time differences
|
||||
time_diffs = []
|
||||
for i in range(1, len(recent_samples)):
|
||||
try:
|
||||
t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
|
||||
t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
|
||||
diff = (t2 - t1).total_seconds()
|
||||
if diff > 0:
|
||||
time_diffs.append(diff)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if time_diffs:
|
||||
avg_interval = sum(time_diffs) / len(time_diffs)
|
||||
self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0
|
||||
|
||||
async def get_router_status(self, router_id: str) -> Dict[str, Any]:
|
||||
"""Get status of a specific router."""
|
||||
if router_id not in self.router_interfaces:
|
||||
raise ValueError(f"Router {router_id} not found")
|
||||
|
||||
interface = self.router_interfaces[router_id]
|
||||
|
||||
try:
|
||||
is_healthy = await interface.check_health()
|
||||
status = await interface.get_status()
|
||||
|
||||
return {
|
||||
"router_id": router_id,
|
||||
"healthy": is_healthy,
|
||||
"connected": status.get("connected", False),
|
||||
"last_data_time": status.get("last_data_time"),
|
||||
"error_count": status.get("error_count", 0),
|
||||
"configuration": status.get("configuration", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"router_id": router_id,
|
||||
"healthy": False,
|
||||
"connected": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_all_router_status(self) -> List[Dict[str, Any]]:
|
||||
"""Get status of all routers."""
|
||||
statuses = []
|
||||
|
||||
for router_id in self.router_interfaces:
|
||||
try:
|
||||
status = await self.get_router_status(router_id)
|
||||
statuses.append(status)
|
||||
except Exception as e:
|
||||
statuses.append({
|
||||
"router_id": router_id,
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return statuses
|
||||
|
||||
async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get recent CSI data samples."""
|
||||
samples = self.recent_samples[-limit:] if limit else self.recent_samples
|
||||
|
||||
if router_id:
|
||||
samples = [s for s in samples if s["router_id"] == router_id]
|
||||
|
||||
# Convert numpy arrays to lists for JSON serialization
|
||||
result = []
|
||||
for sample in samples:
|
||||
sample_copy = sample.copy()
|
||||
if isinstance(sample_copy["data"], np.ndarray):
|
||||
sample_copy["data"] = sample_copy["data"].tolist()
|
||||
result.append(sample_copy)
|
||||
|
||||
return result
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_hardware": self.settings.mock_hardware,
|
||||
"wifi_interface": self.settings.wifi_interface,
|
||||
"polling_interval": self.settings.hardware_polling_interval,
|
||||
"buffer_size": self.settings.csi_buffer_size
|
||||
},
|
||||
"routers": await self.get_all_router_status()
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
total_samples = self.stats["total_samples"]
|
||||
success_rate = self.stats["successful_samples"] / max(1, total_samples)
|
||||
|
||||
return {
|
||||
"hardware_service": {
|
||||
"total_samples": total_samples,
|
||||
"successful_samples": self.stats["successful_samples"],
|
||||
"failed_samples": self.stats["failed_samples"],
|
||||
"success_rate": success_rate,
|
||||
"average_sample_rate": self.stats["average_sample_rate"],
|
||||
"connected_routers": self.stats["connected_routers"],
|
||||
"last_sample_time": self.stats["last_sample_time"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_samples": 0,
|
||||
"successful_samples": 0,
|
||||
"failed_samples": 0,
|
||||
"average_sample_rate": 0.0,
|
||||
"last_sample_time": None,
|
||||
"connected_routers": len(self.router_interfaces)
|
||||
}
|
||||
|
||||
self.recent_samples.clear()
|
||||
self.last_error = None
|
||||
|
||||
self.logger.info("Hardware service reset")
|
||||
|
||||
async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Manually trigger data collection."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Hardware service is not running")
|
||||
|
||||
results = {}
|
||||
|
||||
if router_id:
|
||||
# Collect from specific router
|
||||
if router_id not in self.router_interfaces:
|
||||
raise ValueError(f"Router {router_id} not found")
|
||||
|
||||
interface = self.router_interfaces[router_id]
|
||||
try:
|
||||
csi_data = await interface.get_csi_data()
|
||||
if csi_data is not None:
|
||||
await self._process_collected_data(router_id, csi_data)
|
||||
results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
|
||||
else:
|
||||
results[router_id] = {"success": False, "error": "No data received"}
|
||||
except Exception as e:
|
||||
results[router_id] = {"success": False, "error": str(e)}
|
||||
else:
|
||||
# Collect from all routers
|
||||
await self._collect_data_from_routers()
|
||||
results = {"message": "Manual collection triggered for all routers"}
|
||||
|
||||
return results
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
# Check router health
|
||||
healthy_routers = 0
|
||||
total_routers = len(self.router_interfaces)
|
||||
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
if await interface.check_health():
|
||||
healthy_routers += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Hardware service is running normally",
|
||||
"connected_routers": f"{healthy_routers}/{total_routers}",
|
||||
"metrics": {
|
||||
"total_samples": self.stats["total_samples"],
|
||||
"success_rate": (
|
||||
self.stats["successful_samples"] / max(1, self.stats["total_samples"])
|
||||
),
|
||||
"average_sample_rate": self.stats["average_sample_rate"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
"""Check if service is ready."""
|
||||
return self.is_running and len(self.router_interfaces) > 0
|
||||
706
src/services/pose_service.py
Normal file
706
src/services/pose_service.py
Normal file
@@ -0,0 +1,706 @@
|
||||
"""
|
||||
Pose estimation service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoseService:
|
||||
"""Service for pose estimation operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize pose service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize components
|
||||
self.csi_processor = None
|
||||
self.phase_sanitizer = None
|
||||
self.densepose_model = None
|
||||
self.modality_translator = None
|
||||
|
||||
# Service state
|
||||
self.is_initialized = False
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Processing statistics
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the pose service."""
|
||||
try:
|
||||
self.logger.info("Initializing pose service...")
|
||||
|
||||
# Initialize CSI processor
|
||||
csi_config = {
|
||||
'buffer_size': self.settings.csi_buffer_size,
|
||||
'sample_rate': 1000, # Default sampling rate
|
||||
'num_subcarriers': 56,
|
||||
'num_antennas': 3
|
||||
}
|
||||
self.csi_processor = CSIProcessor(config=csi_config)
|
||||
|
||||
# Initialize phase sanitizer
|
||||
self.phase_sanitizer = PhaseSanitizer()
|
||||
|
||||
# Initialize models if not mocking
|
||||
if not self.settings.mock_pose_data:
|
||||
await self._initialize_models()
|
||||
else:
|
||||
self.logger.info("Using mock pose data for development")
|
||||
|
||||
self.is_initialized = True
|
||||
self.logger.info("Pose service initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to initialize pose service: {e}")
|
||||
raise
|
||||
|
||||
async def _initialize_models(self):
|
||||
"""Initialize neural network models."""
|
||||
try:
|
||||
# Initialize DensePose model
|
||||
if self.settings.pose_model_path:
|
||||
self.densepose_model = DensePoseHead()
|
||||
# Load model weights if path is provided
|
||||
# model_state = torch.load(self.settings.pose_model_path)
|
||||
# self.densepose_model.load_state_dict(model_state)
|
||||
self.logger.info("DensePose model loaded")
|
||||
else:
|
||||
self.logger.warning("No pose model path provided, using default model")
|
||||
self.densepose_model = DensePoseHead()
|
||||
|
||||
# Initialize modality translation
|
||||
config = {
|
||||
'input_channels': 64, # CSI data channels
|
||||
'hidden_channels': [128, 256, 512],
|
||||
'output_channels': 256, # Visual feature channels
|
||||
'use_attention': True
|
||||
}
|
||||
self.modality_translator = ModalityTranslationNetwork(config)
|
||||
|
||||
# Set models to evaluation mode
|
||||
self.densepose_model.eval()
|
||||
self.modality_translator.eval()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize models: {e}")
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""Start the pose service."""
|
||||
if not self.is_initialized:
|
||||
await self.initialize()
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Pose service started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the pose service."""
|
||||
self.is_running = False
|
||||
self.logger.info("Pose service stopped")
|
||||
|
||||
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process CSI data and estimate poses."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Pose service is not running")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# Process CSI data
|
||||
processed_csi = await self._process_csi(csi_data, metadata)
|
||||
|
||||
# Estimate poses
|
||||
poses = await self._estimate_poses(processed_csi, metadata)
|
||||
|
||||
# Update statistics
|
||||
processing_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
self._update_stats(poses, processing_time)
|
||||
|
||||
return {
|
||||
"timestamp": start_time.isoformat(),
|
||||
"poses": poses,
|
||||
"metadata": metadata,
|
||||
"processing_time_ms": processing_time,
|
||||
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.stats["failed_detections"] += 1
|
||||
self.logger.error(f"Error processing CSI data: {e}")
|
||||
raise
|
||||
|
||||
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
|
||||
"""Process raw CSI data."""
|
||||
# Add CSI data to processor
|
||||
self.csi_processor.add_data(csi_data, metadata.get("timestamp", datetime.now()))
|
||||
|
||||
# Get processed data
|
||||
processed_data = self.csi_processor.get_processed_data()
|
||||
|
||||
# Apply phase sanitization
|
||||
if processed_data is not None:
|
||||
sanitized_data = self.phase_sanitizer.sanitize(processed_data)
|
||||
return sanitized_data
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Estimate poses from processed CSI data."""
|
||||
if self.settings.mock_pose_data:
|
||||
return self._generate_mock_poses()
|
||||
|
||||
try:
|
||||
# Convert CSI data to tensor
|
||||
csi_tensor = torch.from_numpy(csi_data).float()
|
||||
|
||||
# Add batch dimension if needed
|
||||
if len(csi_tensor.shape) == 2:
|
||||
csi_tensor = csi_tensor.unsqueeze(0)
|
||||
|
||||
# Translate modality (CSI to visual-like features)
|
||||
with torch.no_grad():
|
||||
visual_features = self.modality_translator(csi_tensor)
|
||||
|
||||
# Estimate poses using DensePose
|
||||
pose_outputs = self.densepose_model(visual_features)
|
||||
|
||||
# Convert outputs to pose detections
|
||||
poses = self._parse_pose_outputs(pose_outputs)
|
||||
|
||||
# Filter by confidence threshold
|
||||
filtered_poses = [
|
||||
pose for pose in poses
|
||||
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
|
||||
]
|
||||
|
||||
# Limit number of persons
|
||||
if len(filtered_poses) > self.settings.pose_max_persons:
|
||||
filtered_poses = sorted(
|
||||
filtered_poses,
|
||||
key=lambda x: x.get("confidence", 0.0),
|
||||
reverse=True
|
||||
)[:self.settings.pose_max_persons]
|
||||
|
||||
return filtered_poses
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in pose estimation: {e}")
|
||||
return []
|
||||
|
||||
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Parse neural network outputs into pose detections."""
|
||||
poses = []
|
||||
|
||||
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||
# For now, generate mock poses based on the output shape
|
||||
batch_size = outputs.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract pose information (mock implementation)
|
||||
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||
"""Generate mock pose data for development."""
|
||||
import random
|
||||
|
||||
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
|
||||
poses = []
|
||||
|
||||
for i in range(num_persons):
|
||||
confidence = random.uniform(0.3, 0.95)
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_keypoints(self) -> List[Dict[str, Any]]:
|
||||
"""Generate keypoints for a person."""
|
||||
import random
|
||||
|
||||
keypoint_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
keypoints = []
|
||||
for name in keypoint_names:
|
||||
keypoints.append({
|
||||
"name": name,
|
||||
"x": random.uniform(0.1, 0.9),
|
||||
"y": random.uniform(0.1, 0.9),
|
||||
"confidence": random.uniform(0.5, 0.95)
|
||||
})
|
||||
|
||||
return keypoints
|
||||
|
||||
def _generate_bounding_box(self) -> Dict[str, float]:
|
||||
"""Generate bounding box for a person."""
|
||||
import random
|
||||
|
||||
x = random.uniform(0.1, 0.6)
|
||||
y = random.uniform(0.1, 0.6)
|
||||
width = random.uniform(0.2, 0.4)
|
||||
height = random.uniform(0.3, 0.5)
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height
|
||||
}
|
||||
|
||||
def _classify_activity(self, features: torch.Tensor) -> str:
|
||||
"""Classify activity from features."""
|
||||
# Simple mock classification
|
||||
import random
|
||||
activities = ["standing", "sitting", "walking", "lying", "unknown"]
|
||||
return random.choice(activities)
|
||||
|
||||
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
|
||||
"""Update processing statistics."""
|
||||
self.stats["total_processed"] += 1
|
||||
|
||||
if poses:
|
||||
self.stats["successful_detections"] += 1
|
||||
confidences = [pose.get("confidence", 0.0) for pose in poses]
|
||||
avg_confidence = sum(confidences) / len(confidences)
|
||||
|
||||
# Update running average
|
||||
total = self.stats["successful_detections"]
|
||||
current_avg = self.stats["average_confidence"]
|
||||
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
|
||||
else:
|
||||
self.stats["failed_detections"] += 1
|
||||
|
||||
# Update processing time (running average)
|
||||
total = self.stats["total_processed"]
|
||||
current_avg = self.stats["processing_time_ms"]
|
||||
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"initialized": self.is_initialized,
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_data": self.settings.mock_pose_data,
|
||||
"confidence_threshold": self.settings.pose_confidence_threshold,
|
||||
"max_persons": self.settings.pose_max_persons,
|
||||
"batch_size": self.settings.pose_processing_batch_size
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
return {
|
||||
"pose_service": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"successful_detections": self.stats["successful_detections"],
|
||||
"failed_detections": self.stats["failed_detections"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_confidence": self.stats["average_confidence"],
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
self.last_error = None
|
||||
self.logger.info("Pose service reset")
|
||||
|
||||
# API endpoint methods
|
||||
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Estimate poses with API parameters."""
|
||||
try:
|
||||
# Generate mock CSI data for estimation
|
||||
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
|
||||
metadata = {
|
||||
"timestamp": datetime.now(),
|
||||
"zone_ids": zone_ids or ["zone_1"],
|
||||
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
|
||||
"max_persons": max_persons or self.settings.pose_max_persons
|
||||
}
|
||||
|
||||
# Process the data
|
||||
result = await self.process_csi_data(mock_csi, metadata)
|
||||
|
||||
# Format for API response
|
||||
persons = []
|
||||
for i, pose in enumerate(result["poses"]):
|
||||
person = {
|
||||
"person_id": str(pose["person_id"]),
|
||||
"confidence": pose["confidence"],
|
||||
"bounding_box": pose["bounding_box"],
|
||||
"zone_id": zone_ids[0] if zone_ids else "zone_1",
|
||||
"activity": pose["activity"],
|
||||
"timestamp": datetime.fromisoformat(pose["timestamp"])
|
||||
}
|
||||
|
||||
if include_keypoints:
|
||||
person["keypoints"] = pose["keypoints"]
|
||||
|
||||
if include_segmentation:
|
||||
person["segmentation"] = {"mask": "mock_segmentation_data"}
|
||||
|
||||
persons.append(person)
|
||||
|
||||
# Zone summary
|
||||
zone_summary = {}
|
||||
for zone_id in (zone_ids or ["zone_1"]):
|
||||
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(),
|
||||
"frame_id": f"frame_{int(datetime.now().timestamp())}",
|
||||
"persons": persons,
|
||||
"zone_summary": zone_summary,
|
||||
"processing_time_ms": result["processing_time_ms"],
|
||||
"metadata": {"mock_data": self.settings.mock_pose_data}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in estimate_poses: {e}")
|
||||
raise
|
||||
|
||||
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Analyze pose data with custom parameters."""
|
||||
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
|
||||
include_keypoints, include_segmentation)
|
||||
|
||||
async def get_zone_occupancy(self, zone_id: str):
|
||||
"""Get current occupancy for a specific zone."""
|
||||
try:
|
||||
# Mock occupancy data
|
||||
import random
|
||||
count = random.randint(0, 5)
|
||||
persons = []
|
||||
|
||||
for i in range(count):
|
||||
persons.append({
|
||||
"person_id": f"person_{i}",
|
||||
"confidence": random.uniform(0.7, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"max_occupancy": 10,
|
||||
"persons": persons,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zone occupancy: {e}")
|
||||
return None
|
||||
|
||||
async def get_zones_summary(self):
|
||||
"""Get occupancy summary for all zones."""
|
||||
try:
|
||||
import random
|
||||
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
|
||||
zone_data = {}
|
||||
total_persons = 0
|
||||
active_zones = 0
|
||||
|
||||
for zone_id in zones:
|
||||
count = random.randint(0, 3)
|
||||
zone_data[zone_id] = {
|
||||
"occupancy": count,
|
||||
"max_occupancy": 10,
|
||||
"status": "active" if count > 0 else "inactive"
|
||||
}
|
||||
total_persons += count
|
||||
if count > 0:
|
||||
active_zones += 1
|
||||
|
||||
return {
|
||||
"total_persons": total_persons,
|
||||
"zones": zone_data,
|
||||
"active_zones": active_zones
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zones summary: {e}")
|
||||
raise
|
||||
|
||||
async def get_historical_data(self, start_time, end_time, zone_ids=None,
|
||||
aggregation_interval=300, include_raw_data=False):
|
||||
"""Get historical pose estimation data."""
|
||||
try:
|
||||
# Mock historical data
|
||||
import random
|
||||
from datetime import timedelta
|
||||
|
||||
current_time = start_time
|
||||
aggregated_data = []
|
||||
raw_data = [] if include_raw_data else None
|
||||
|
||||
while current_time < end_time:
|
||||
# Generate aggregated data point
|
||||
data_point = {
|
||||
"timestamp": current_time,
|
||||
"total_persons": random.randint(0, 8),
|
||||
"zones": {}
|
||||
}
|
||||
|
||||
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
|
||||
data_point["zones"][zone_id] = {
|
||||
"occupancy": random.randint(0, 3),
|
||||
"avg_confidence": random.uniform(0.7, 0.95)
|
||||
}
|
||||
|
||||
aggregated_data.append(data_point)
|
||||
|
||||
# Generate raw data if requested
|
||||
if include_raw_data:
|
||||
for _ in range(random.randint(0, 5)):
|
||||
raw_data.append({
|
||||
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
|
||||
"person_id": f"person_{random.randint(1, 10)}",
|
||||
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
|
||||
"confidence": random.uniform(0.5, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
current_time += timedelta(seconds=aggregation_interval)
|
||||
|
||||
return {
|
||||
"aggregated_data": aggregated_data,
|
||||
"raw_data": raw_data,
|
||||
"total_records": len(aggregated_data)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting historical data: {e}")
|
||||
raise
|
||||
|
||||
async def get_recent_activities(self, zone_id=None, limit=10):
|
||||
"""Get recently detected activities."""
|
||||
try:
|
||||
import random
|
||||
activities = []
|
||||
|
||||
for i in range(limit):
|
||||
activity = {
|
||||
"activity_id": f"activity_{i}",
|
||||
"person_id": f"person_{random.randint(1, 5)}",
|
||||
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"confidence": random.uniform(0.6, 0.95),
|
||||
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
|
||||
"duration_seconds": random.randint(10, 300)
|
||||
}
|
||||
activities.append(activity)
|
||||
|
||||
return activities
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting recent activities: {e}")
|
||||
raise
|
||||
|
||||
async def is_calibrating(self):
|
||||
"""Check if calibration is in progress."""
|
||||
return False # Mock implementation
|
||||
|
||||
async def start_calibration(self):
|
||||
"""Start calibration process."""
|
||||
import uuid
|
||||
calibration_id = str(uuid.uuid4())
|
||||
self.logger.info(f"Started calibration: {calibration_id}")
|
||||
return calibration_id
|
||||
|
||||
async def run_calibration(self, calibration_id):
|
||||
"""Run calibration process."""
|
||||
self.logger.info(f"Running calibration: {calibration_id}")
|
||||
# Mock calibration process
|
||||
await asyncio.sleep(5)
|
||||
self.logger.info(f"Calibration completed: {calibration_id}")
|
||||
|
||||
async def get_calibration_status(self):
|
||||
"""Get current calibration status."""
|
||||
return {
|
||||
"is_calibrating": False,
|
||||
"calibration_id": None,
|
||||
"progress_percent": 100,
|
||||
"current_step": "completed",
|
||||
"estimated_remaining_minutes": 0,
|
||||
"last_calibration": datetime.now() - timedelta(hours=1)
|
||||
}
|
||||
|
||||
async def get_statistics(self, start_time, end_time):
|
||||
"""Get pose estimation statistics."""
|
||||
try:
|
||||
import random
|
||||
|
||||
# Mock statistics
|
||||
total_detections = random.randint(100, 1000)
|
||||
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
|
||||
|
||||
return {
|
||||
"total_detections": total_detections,
|
||||
"successful_detections": successful_detections,
|
||||
"failed_detections": total_detections - successful_detections,
|
||||
"success_rate": successful_detections / total_detections,
|
||||
"average_confidence": random.uniform(0.75, 0.90),
|
||||
"average_processing_time_ms": random.uniform(50, 200),
|
||||
"unique_persons": random.randint(5, 20),
|
||||
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity_distribution": {
|
||||
"standing": random.uniform(0.3, 0.5),
|
||||
"sitting": random.uniform(0.2, 0.4),
|
||||
"walking": random.uniform(0.1, 0.3),
|
||||
"lying": random.uniform(0.0, 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting statistics: {e}")
|
||||
raise
|
||||
|
||||
async def process_segmentation_data(self, frame_id):
|
||||
"""Process segmentation data in background."""
|
||||
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
|
||||
# Mock background processing
|
||||
await asyncio.sleep(2)
|
||||
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
|
||||
|
||||
# WebSocket streaming methods
|
||||
async def get_current_pose_data(self):
|
||||
"""Get current pose data for streaming."""
|
||||
try:
|
||||
# Generate current pose data
|
||||
result = await self.estimate_poses()
|
||||
|
||||
# Format data by zones for WebSocket streaming
|
||||
zone_data = {}
|
||||
|
||||
# Group persons by zone
|
||||
for person in result["persons"]:
|
||||
zone_id = person.get("zone_id", "zone_1")
|
||||
|
||||
if zone_id not in zone_data:
|
||||
zone_data[zone_id] = {
|
||||
"pose": {
|
||||
"persons": [],
|
||||
"count": 0
|
||||
},
|
||||
"confidence": 0.0,
|
||||
"activity": None,
|
||||
"metadata": {
|
||||
"frame_id": result["frame_id"],
|
||||
"processing_time_ms": result["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
zone_data[zone_id]["pose"]["persons"].append(person)
|
||||
zone_data[zone_id]["pose"]["count"] += 1
|
||||
|
||||
# Update zone confidence (average)
|
||||
current_confidence = zone_data[zone_id]["confidence"]
|
||||
person_confidence = person.get("confidence", 0.0)
|
||||
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
|
||||
|
||||
# Set activity if not already set
|
||||
if not zone_data[zone_id]["activity"] and person.get("activity"):
|
||||
zone_data[zone_id]["activity"] = person["activity"]
|
||||
|
||||
return zone_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting current pose data: {e}")
|
||||
# Return empty zone data on error
|
||||
return {}
|
||||
|
||||
# Health check methods
|
||||
async def health_check(self):
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Service is running normally",
|
||||
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
|
||||
"metrics": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self):
|
||||
"""Check if service is ready."""
|
||||
return self.is_initialized and self.is_running
|
||||
397
src/services/stream_service.py
Normal file
397
src/services/stream_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Real-time streaming service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamService:
|
||||
"""Service for real-time data streaming."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize stream service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# WebSocket connections
|
||||
self.connections: Set[WebSocket] = set()
|
||||
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
|
||||
|
||||
# Stream buffers
|
||||
self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||
self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||
|
||||
# Service state
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Streaming statistics
|
||||
self.stats = {
|
||||
"active_connections": 0,
|
||||
"total_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"data_points_streamed": 0,
|
||||
"average_latency_ms": 0.0
|
||||
}
|
||||
|
||||
# Background tasks
|
||||
self.streaming_task = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the stream service."""
|
||||
self.logger.info("Stream service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start the stream service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Stream service started")
|
||||
|
||||
# Start background streaming task
|
||||
if self.settings.enable_real_time_processing:
|
||||
self.streaming_task = asyncio.create_task(self._streaming_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the stream service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background task
|
||||
if self.streaming_task:
|
||||
self.streaming_task.cancel()
|
||||
try:
|
||||
await self.streaming_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close all connections
|
||||
await self._close_all_connections()
|
||||
|
||||
self.logger.info("Stream service stopped")
|
||||
|
||||
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
|
||||
"""Add a new WebSocket connection."""
|
||||
try:
|
||||
await websocket.accept()
|
||||
self.connections.add(websocket)
|
||||
self.connection_metadata[websocket] = metadata or {}
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
self.stats["total_connections"] += 1
|
||||
|
||||
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
|
||||
|
||||
# Send initial data if available
|
||||
await self._send_initial_data(websocket)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error adding WebSocket connection: {e}")
|
||||
raise
|
||||
|
||||
async def remove_connection(self, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
self.connections.remove(websocket)
|
||||
self.connection_metadata.pop(websocket, None)
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
|
||||
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error removing WebSocket connection: {e}")
|
||||
|
||||
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
|
||||
"""Broadcast pose data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Add to buffer
|
||||
self.pose_buffer.append({
|
||||
"type": "pose_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "pose_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
|
||||
"""Broadcast CSI data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Convert numpy array to list for JSON serialization
|
||||
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
|
||||
|
||||
# Add to buffer
|
||||
self.csi_buffer.append({
|
||||
"type": "csi_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "csi_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
async def broadcast_system_status(self, status_data: Dict[str, Any]):
|
||||
"""Broadcast system status to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
await self._broadcast_message({
|
||||
"type": "system_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status_data
|
||||
})
|
||||
|
||||
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
|
||||
"""Send message to a specific connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def _broadcast_message(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients."""
|
||||
if not self.connections:
|
||||
return
|
||||
|
||||
disconnected = set()
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to send message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
disconnected.add(websocket)
|
||||
|
||||
# Remove disconnected clients
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
if message.get("type") in ["pose_update", "csi_update"]:
|
||||
self.stats["data_points_streamed"] += 1
|
||||
|
||||
async def _send_initial_data(self, websocket: WebSocket):
|
||||
"""Send initial data to a new connection."""
|
||||
try:
|
||||
# Send recent pose data
|
||||
if self.pose_buffer:
|
||||
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_poses",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_poses
|
||||
})
|
||||
|
||||
# Send recent CSI data
|
||||
if self.csi_buffer:
|
||||
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_csi",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_csi
|
||||
})
|
||||
|
||||
# Send service status
|
||||
status = await self.get_status()
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "service_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending initial data: {e}")
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""Background streaming loop for periodic updates."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Send periodic heartbeat
|
||||
if self.connections:
|
||||
await self._broadcast_message({
|
||||
"type": "heartbeat",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"active_connections": len(self.connections)
|
||||
})
|
||||
|
||||
# Wait for next iteration
|
||||
await asyncio.sleep(self.settings.websocket_ping_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Streaming loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in streaming loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _close_all_connections(self):
|
||||
"""Close all WebSocket connections."""
|
||||
disconnected = []
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.close()
|
||||
disconnected.append(websocket)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing connection: {e}")
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clear all connections
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"connections": {
|
||||
"active": len(self.connections),
|
||||
"total": self.stats["total_connections"]
|
||||
},
|
||||
"buffers": {
|
||||
"pose_buffer_size": len(self.pose_buffer),
|
||||
"csi_buffer_size": len(self.csi_buffer),
|
||||
"max_buffer_size": self.settings.stream_buffer_size
|
||||
},
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"stream_fps": self.settings.stream_fps,
|
||||
"buffer_size": self.settings.stream_buffer_size,
|
||||
"ping_interval": self.settings.websocket_ping_interval,
|
||||
"timeout": self.settings.websocket_timeout
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
total_messages = self.stats["messages_sent"] + self.stats["messages_failed"]
|
||||
success_rate = self.stats["messages_sent"] / max(1, total_messages)
|
||||
|
||||
return {
|
||||
"stream_service": {
|
||||
"active_connections": self.stats["active_connections"],
|
||||
"total_connections": self.stats["total_connections"],
|
||||
"messages_sent": self.stats["messages_sent"],
|
||||
"messages_failed": self.stats["messages_failed"],
|
||||
"message_success_rate": success_rate,
|
||||
"data_points_streamed": self.stats["data_points_streamed"],
|
||||
"average_latency_ms": self.stats["average_latency_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def get_connection_info(self) -> List[Dict[str, Any]]:
|
||||
"""Get information about active connections."""
|
||||
connections_info = []
|
||||
|
||||
for websocket in self.connections:
|
||||
metadata = self.connection_metadata.get(websocket, {})
|
||||
|
||||
connection_info = {
|
||||
"id": id(websocket),
|
||||
"connected_at": metadata.get("connected_at", "unknown"),
|
||||
"user_agent": metadata.get("user_agent", "unknown"),
|
||||
"ip_address": metadata.get("ip_address", "unknown"),
|
||||
"subscription_types": metadata.get("subscription_types", [])
|
||||
}
|
||||
|
||||
connections_info.append(connection_info)
|
||||
|
||||
return connections_info
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
# Clear buffers
|
||||
self.pose_buffer.clear()
|
||||
self.csi_buffer.clear()
|
||||
|
||||
# Reset statistics
|
||||
self.stats = {
|
||||
"active_connections": len(self.connections),
|
||||
"total_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"data_points_streamed": 0,
|
||||
"average_latency_ms": 0.0
|
||||
}
|
||||
|
||||
self.last_error = None
|
||||
self.logger.info("Stream service reset")
|
||||
|
||||
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get data from buffers."""
|
||||
if buffer_type == "pose":
|
||||
return list(self.pose_buffer)[-limit:]
|
||||
elif buffer_type == "csi":
|
||||
return list(self.csi_buffer)[-limit:]
|
||||
else:
|
||||
return []
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if stream service is active."""
|
||||
return self.is_running
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Stream service is running normally",
|
||||
"active_connections": len(self.connections),
|
||||
"metrics": {
|
||||
"messages_sent": self.stats["messages_sent"],
|
||||
"messages_failed": self.stats["messages_failed"],
|
||||
"data_points_streamed": self.stats["data_points_streamed"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
"""Check if service is ready."""
|
||||
return self.is_running
|
||||
198
test_application.py
Normal file
198
test_application.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify WiFi-DensePose API functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import websockets
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
WS_URL = "ws://localhost:8000"
|
||||
|
||||
async def test_health_endpoints():
|
||||
"""Test health check endpoints."""
|
||||
print("🔍 Testing health endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test basic health
|
||||
async with session.get(f"{BASE_URL}/health/health") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Health check: {data['status']}")
|
||||
else:
|
||||
print(f"❌ Health check failed: {response.status}")
|
||||
|
||||
# Test readiness
|
||||
async with session.get(f"{BASE_URL}/health/ready") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
status = "ready" if data['ready'] else "not ready"
|
||||
print(f"✅ Readiness check: {status}")
|
||||
else:
|
||||
print(f"❌ Readiness check failed: {response.status}")
|
||||
|
||||
# Test liveness
|
||||
async with session.get(f"{BASE_URL}/health/live") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Liveness check: {data['status']}")
|
||||
else:
|
||||
print(f"❌ Liveness check failed: {response.status}")
|
||||
|
||||
async def test_api_endpoints():
|
||||
"""Test main API endpoints."""
|
||||
print("\n🔍 Testing API endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test root endpoint
|
||||
async with session.get(f"{BASE_URL}/") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Root endpoint: {data['name']} v{data['version']}")
|
||||
else:
|
||||
print(f"❌ Root endpoint failed: {response.status}")
|
||||
|
||||
# Test API info
|
||||
async with session.get(f"{BASE_URL}/api/v1/info") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ API info: {len(data['services'])} services configured")
|
||||
else:
|
||||
print(f"❌ API info failed: {response.status}")
|
||||
|
||||
# Test API status
|
||||
async with session.get(f"{BASE_URL}/api/v1/status") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ API status: {data['api']['status']}")
|
||||
else:
|
||||
print(f"❌ API status failed: {response.status}")
|
||||
|
||||
async def test_pose_endpoints():
|
||||
"""Test pose estimation endpoints."""
|
||||
print("\n🔍 Testing pose endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test current pose data
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/current") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Current pose data: {len(data.get('poses', []))} poses detected")
|
||||
else:
|
||||
print(f"❌ Current pose data failed: {response.status}")
|
||||
|
||||
# Test zones summary
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/zones/summary") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
zones = data.get('zones', {})
|
||||
print(f"✅ Zones summary: {len(zones)} zones")
|
||||
for zone_id, zone_data in list(zones.items())[:3]: # Show first 3 zones
|
||||
print(f" - {zone_id}: {zone_data.get('occupancy', 0)} people")
|
||||
else:
|
||||
print(f"❌ Zones summary failed: {response.status}")
|
||||
|
||||
# Test pose stats
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/stats") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Pose stats: {data.get('total_detections', 0)} total detections")
|
||||
else:
|
||||
print(f"❌ Pose stats failed: {response.status}")
|
||||
|
||||
async def test_stream_endpoints():
|
||||
"""Test streaming endpoints."""
|
||||
print("\n🔍 Testing stream endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test stream status
|
||||
async with session.get(f"{BASE_URL}/api/v1/stream/status") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Stream status: {'Active' if data['is_active'] else 'Inactive'}")
|
||||
print(f" - Connected clients: {data['connected_clients']}")
|
||||
else:
|
||||
print(f"❌ Stream status failed: {response.status}")
|
||||
|
||||
# Test stream metrics
|
||||
async with session.get(f"{BASE_URL}/api/v1/stream/metrics") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Stream metrics available")
|
||||
else:
|
||||
print(f"❌ Stream metrics failed: {response.status}")
|
||||
|
||||
async def test_websocket_connection():
|
||||
"""Test WebSocket connection."""
|
||||
print("\n🔍 Testing WebSocket connection...")
|
||||
|
||||
try:
|
||||
uri = f"{WS_URL}/api/v1/stream/pose"
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("✅ WebSocket connected successfully")
|
||||
|
||||
# Wait for connection confirmation
|
||||
message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||
data = json.loads(message)
|
||||
|
||||
if data.get("type") == "connection_established":
|
||||
print(f"✅ Connection established with client ID: {data.get('client_id')}")
|
||||
|
||||
# Send a ping
|
||||
await websocket.send(json.dumps({"type": "ping"}))
|
||||
|
||||
# Wait for pong
|
||||
pong_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||
pong_data = json.loads(pong_message)
|
||||
|
||||
if pong_data.get("type") == "pong":
|
||||
print("✅ WebSocket ping/pong successful")
|
||||
else:
|
||||
print(f"❌ Unexpected pong response: {pong_data}")
|
||||
else:
|
||||
print(f"❌ Unexpected connection message: {data}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print("❌ WebSocket connection timeout")
|
||||
except Exception as e:
|
||||
print(f"❌ WebSocket connection failed: {e}")
|
||||
|
||||
async def test_calibration_endpoints():
|
||||
"""Test calibration endpoints."""
|
||||
print("\n🔍 Testing calibration endpoints...")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Test calibration status
|
||||
async with session.get(f"{BASE_URL}/api/v1/pose/calibration/status") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
print(f"✅ Calibration status: {data.get('status', 'unknown')}")
|
||||
else:
|
||||
print(f"❌ Calibration status failed: {response.status}")
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 Starting WiFi-DensePose API Tests")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
await test_health_endpoints()
|
||||
await test_api_endpoints()
|
||||
await test_pose_endpoints()
|
||||
await test_stream_endpoints()
|
||||
await test_websocket_connection()
|
||||
await test_calibration_endpoints()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("✅ All tests completed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test suite failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user