feat: Implement hardware, pose, and stream services for WiFi-DensePose API

- Added HardwareService for managing router interfaces, data collection, and monitoring.
- Introduced PoseService for processing CSI data and estimating poses using neural networks.
- Created StreamService for real-time data streaming via WebSocket connections.
- Implemented initialization, start, stop, and status retrieval methods for each service.
- Added data processing, error handling, and statistics tracking across services.
- Integrated mock data generation for development and testing purposes.
This commit is contained in:
rUv
2025-06-07 12:47:54 +00:00
parent c378b705ca
commit 90f03bac7d
26 changed files with 9846 additions and 105 deletions

18
=3.0.0
View File

@@ -1,18 +0,0 @@
Collecting paramiko
Downloading paramiko-3.5.1-py3-none-any.whl.metadata (4.6 kB)
Collecting bcrypt>=3.2 (from paramiko)
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (10 kB)
Collecting cryptography>=3.3 (from paramiko)
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl.metadata (5.7 kB)
Collecting pynacl>=1.5 (from paramiko)
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl.metadata (8.6 kB)
Requirement already satisfied: cffi>=1.14 in /home/codespace/.local/lib/python3.12/site-packages (from cryptography>=3.3->paramiko) (1.17.1)
Requirement already satisfied: pycparser in /home/codespace/.local/lib/python3.12/site-packages (from cffi>=1.14->cryptography>=3.3->paramiko) (2.22)
Downloading paramiko-3.5.1-py3-none-any.whl (227 kB)
Downloading bcrypt-4.3.0-cp39-abi3-manylinux_2_28_x86_64.whl (284 kB)
Downloading cryptography-45.0.3-cp311-abi3-manylinux_2_28_x86_64.whl (4.5 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 45.0 MB/s eta 0:00:00
Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 856.7/856.7 kB 37.4 MB/s eta 0:00:00
Installing collected packages: bcrypt, pynacl, cryptography, paramiko
Successfully installed bcrypt-4.3.0 cryptography-45.0.3 paramiko-3.5.1 pynacl-1.5.0

183
example.env Normal file
View File

@@ -0,0 +1,183 @@
# WiFi-DensePose API Environment Configuration Template
# Copy this file to .env and modify the values according to your setup
# =============================================================================
# APPLICATION SETTINGS
# =============================================================================
# Application metadata
APP_NAME=WiFi-DensePose API
VERSION=1.0.0
ENVIRONMENT=development # Options: development, staging, production
DEBUG=true
# =============================================================================
# SERVER SETTINGS
# =============================================================================
# Server configuration
HOST=0.0.0.0
PORT=8000
RELOAD=true # Auto-reload on code changes (development only)
WORKERS=1 # Number of worker processes
# =============================================================================
# SECURITY SETTINGS
# =============================================================================
# IMPORTANT: Change these values for production!
SECRET_KEY=your-secret-key-here-change-for-production
JWT_ALGORITHM=HS256
JWT_EXPIRE_HOURS=24
# Allowed hosts (restrict in production)
ALLOWED_HOSTS=* # Use specific domains in production: example.com,api.example.com
# CORS settings (restrict in production)
CORS_ORIGINS=* # Use specific origins in production: https://example.com,https://app.example.com
# =============================================================================
# DATABASE SETTINGS
# =============================================================================
# Database connection (optional - defaults to SQLite in development)
# DATABASE_URL=postgresql://user:password@localhost:5432/wifi_densepose
# DATABASE_POOL_SIZE=10
# DATABASE_MAX_OVERFLOW=20
# =============================================================================
# REDIS SETTINGS (Optional - for caching and rate limiting)
# =============================================================================
# Redis connection (optional - defaults to localhost in development)
# REDIS_URL=redis://localhost:6379/0
# REDIS_PASSWORD=your-redis-password
# REDIS_DB=0
# =============================================================================
# HARDWARE SETTINGS
# =============================================================================
# WiFi interface configuration
WIFI_INTERFACE=wlan0
CSI_BUFFER_SIZE=1000
HARDWARE_POLLING_INTERVAL=0.1
# Hardware mock settings (for development/testing)
MOCK_HARDWARE=true
MOCK_POSE_DATA=true
# =============================================================================
# POSE ESTIMATION SETTINGS
# =============================================================================
# Model configuration
# POSE_MODEL_PATH=/path/to/your/pose/model.pth
POSE_CONFIDENCE_THRESHOLD=0.5
POSE_PROCESSING_BATCH_SIZE=32
POSE_MAX_PERSONS=10
# =============================================================================
# STREAMING SETTINGS
# =============================================================================
# Real-time streaming configuration
STREAM_FPS=30
STREAM_BUFFER_SIZE=100
WEBSOCKET_PING_INTERVAL=60
WEBSOCKET_TIMEOUT=300
# =============================================================================
# FEATURE FLAGS
# =============================================================================
# Enable/disable features
ENABLE_AUTHENTICATION=false # Set to true for production
ENABLE_RATE_LIMITING=false # Set to true for production
ENABLE_WEBSOCKETS=true
ENABLE_REAL_TIME_PROCESSING=true
ENABLE_HISTORICAL_DATA=true
# Development features
ENABLE_TEST_ENDPOINTS=true # Set to false for production
# =============================================================================
# RATE LIMITING SETTINGS
# =============================================================================
# Rate limiting configuration
RATE_LIMIT_REQUESTS=100
RATE_LIMIT_AUTHENTICATED_REQUESTS=1000
RATE_LIMIT_WINDOW=3600 # Window in seconds
# =============================================================================
# LOGGING SETTINGS
# =============================================================================
# Logging configuration
LOG_LEVEL=INFO # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
LOG_FORMAT=%(asctime)s - %(name)s - %(levelname)s - %(message)s
# LOG_FILE=/path/to/logfile.log # Optional: specify log file path
LOG_MAX_SIZE=10485760 # 10MB
LOG_BACKUP_COUNT=5
# =============================================================================
# STORAGE SETTINGS
# =============================================================================
# Storage directories
DATA_STORAGE_PATH=./data
MODEL_STORAGE_PATH=./models
TEMP_STORAGE_PATH=./temp
MAX_STORAGE_SIZE_GB=100
# =============================================================================
# MONITORING SETTINGS
# =============================================================================
# Monitoring and metrics
METRICS_ENABLED=true
HEALTH_CHECK_INTERVAL=30
PERFORMANCE_MONITORING=true
# =============================================================================
# API SETTINGS
# =============================================================================
# API configuration
API_PREFIX=/api/v1
DOCS_URL=/docs # Set to null to disable in production
REDOC_URL=/redoc # Set to null to disable in production
OPENAPI_URL=/openapi.json # Set to null to disable in production
# =============================================================================
# PRODUCTION SETTINGS
# =============================================================================
# For production deployment, ensure you:
# 1. Set ENVIRONMENT=production
# 2. Set DEBUG=false
# 3. Use a strong SECRET_KEY
# 4. Configure proper DATABASE_URL
# 5. Restrict ALLOWED_HOSTS and CORS_ORIGINS
# 6. Enable ENABLE_AUTHENTICATION=true
# 7. Enable ENABLE_RATE_LIMITING=true
# 8. Set ENABLE_TEST_ENDPOINTS=false
# 9. Disable API documentation URLs (set to null)
# 10. Configure proper logging with LOG_FILE
# Example production settings:
# ENVIRONMENT=production
# DEBUG=false
# SECRET_KEY=your-very-secure-secret-key-here
# DATABASE_URL=postgresql://user:password@db-host:5432/wifi_densepose
# REDIS_URL=redis://redis-host:6379/0
# ALLOWED_HOSTS=yourdomain.com,api.yourdomain.com
# CORS_ORIGINS=https://yourdomain.com,https://app.yourdomain.com
# ENABLE_AUTHENTICATION=true
# ENABLE_RATE_LIMITING=true
# ENABLE_TEST_ENDPOINTS=false
# DOCS_URL=null
# REDOC_URL=null
# OPENAPI_URL=null
# LOG_FILE=/var/log/wifi-densepose/app.log

View File

@@ -17,6 +17,9 @@ fastapi>=0.95.0
uvicorn>=0.20.0
websockets>=10.4
pydantic>=1.10.0
python-jose[cryptography]>=3.3.0
python-multipart>=0.0.6
passlib[bcrypt]>=1.7.4
# Hardware interface dependencies
asyncio-mqtt>=0.11.0

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

376
scripts/test_api_endpoints.py Executable file
View File

@@ -0,0 +1,376 @@
#!/usr/bin/env python3
"""
API Endpoint Testing Script
Tests all WiFi-DensePose API endpoints and provides debugging information.
"""
import asyncio
import json
import sys
import time
import traceback
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
import aiohttp
import websockets
from colorama import Fore, Style, init
# Initialize colorama for colored output
init(autoreset=True)
class APITester:
"""Comprehensive API endpoint tester."""
def __init__(self, base_url: str = "http://localhost:8000"):
self.base_url = base_url
self.session = None
self.results = {
"total_tests": 0,
"passed": 0,
"failed": 0,
"errors": [],
"test_details": []
}
async def __aenter__(self):
"""Async context manager entry."""
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self.session:
await self.session.close()
def log_success(self, message: str):
"""Log success message."""
print(f"{Fore.GREEN}{message}{Style.RESET_ALL}")
def log_error(self, message: str):
"""Log error message."""
print(f"{Fore.RED}{message}{Style.RESET_ALL}")
def log_info(self, message: str):
"""Log info message."""
print(f"{Fore.BLUE} {message}{Style.RESET_ALL}")
def log_warning(self, message: str):
"""Log warning message."""
print(f"{Fore.YELLOW}{message}{Style.RESET_ALL}")
async def test_endpoint(
self,
method: str,
endpoint: str,
expected_status: int = 200,
data: Optional[Dict] = None,
params: Optional[Dict] = None,
headers: Optional[Dict] = None,
description: str = ""
) -> Dict[str, Any]:
"""Test a single API endpoint."""
self.results["total_tests"] += 1
test_name = f"{method.upper()} {endpoint}"
try:
url = f"{self.base_url}{endpoint}"
# Prepare request
kwargs = {}
if data:
kwargs["json"] = data
if params:
kwargs["params"] = params
if headers:
kwargs["headers"] = headers
# Make request
start_time = time.time()
async with self.session.request(method, url, **kwargs) as response:
response_time = (time.time() - start_time) * 1000
response_text = await response.text()
# Try to parse JSON response
try:
response_data = json.loads(response_text) if response_text else {}
except json.JSONDecodeError:
response_data = {"raw_response": response_text}
# Check status code
status_ok = response.status == expected_status
test_result = {
"test_name": test_name,
"description": description,
"url": url,
"method": method.upper(),
"expected_status": expected_status,
"actual_status": response.status,
"response_time_ms": round(response_time, 2),
"response_data": response_data,
"success": status_ok,
"timestamp": datetime.now().isoformat()
}
if status_ok:
self.results["passed"] += 1
self.log_success(f"{test_name} - {response.status} ({response_time:.1f}ms)")
if description:
print(f" {description}")
else:
self.results["failed"] += 1
self.log_error(f"{test_name} - Expected {expected_status}, got {response.status}")
if description:
print(f" {description}")
print(f" Response: {response_text[:200]}...")
self.results["test_details"].append(test_result)
return test_result
except Exception as e:
self.results["failed"] += 1
error_msg = f"{test_name} - Exception: {str(e)}"
self.log_error(error_msg)
test_result = {
"test_name": test_name,
"description": description,
"url": f"{self.base_url}{endpoint}",
"method": method.upper(),
"expected_status": expected_status,
"actual_status": None,
"response_time_ms": None,
"response_data": None,
"success": False,
"error": str(e),
"traceback": traceback.format_exc(),
"timestamp": datetime.now().isoformat()
}
self.results["errors"].append(error_msg)
self.results["test_details"].append(test_result)
return test_result
async def test_websocket_endpoint(self, endpoint: str, description: str = "") -> Dict[str, Any]:
"""Test WebSocket endpoint."""
self.results["total_tests"] += 1
test_name = f"WebSocket {endpoint}"
try:
ws_url = f"ws://localhost:8000{endpoint}"
start_time = time.time()
async with websockets.connect(ws_url) as websocket:
# Send a test message
test_message = {"type": "subscribe", "zone_ids": ["zone_1"]}
await websocket.send(json.dumps(test_message))
# Wait for response
response = await asyncio.wait_for(websocket.recv(), timeout=3)
response_time = (time.time() - start_time) * 1000
try:
response_data = json.loads(response)
except json.JSONDecodeError:
response_data = {"raw_response": response}
test_result = {
"test_name": test_name,
"description": description,
"url": ws_url,
"method": "WebSocket",
"response_time_ms": round(response_time, 2),
"response_data": response_data,
"success": True,
"timestamp": datetime.now().isoformat()
}
self.results["passed"] += 1
self.log_success(f"{test_name} - Connected ({response_time:.1f}ms)")
if description:
print(f" {description}")
self.results["test_details"].append(test_result)
return test_result
except Exception as e:
self.results["failed"] += 1
error_msg = f"{test_name} - Exception: {str(e)}"
self.log_error(error_msg)
test_result = {
"test_name": test_name,
"description": description,
"url": f"ws://localhost:8000{endpoint}",
"method": "WebSocket",
"response_time_ms": None,
"response_data": None,
"success": False,
"error": str(e),
"traceback": traceback.format_exc(),
"timestamp": datetime.now().isoformat()
}
self.results["errors"].append(error_msg)
self.results["test_details"].append(test_result)
return test_result
async def run_all_tests(self):
"""Run all API endpoint tests."""
print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}WiFi-DensePose API Endpoint Testing")
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
print()
# Test Health Endpoints
print(f"{Fore.MAGENTA}Testing Health Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/health/health", description="System health check")
await self.test_endpoint("GET", "/health/ready", description="Readiness check")
print()
# Test Pose Estimation Endpoints
print(f"{Fore.MAGENTA}Testing Pose Estimation Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/api/v1/pose/current", description="Current pose estimation")
await self.test_endpoint("GET", "/api/v1/pose/current",
params={"zone_ids": ["zone_1"], "confidence_threshold": 0.7},
description="Current pose estimation with parameters")
await self.test_endpoint("POST", "/api/v1/pose/analyze", description="Pose analysis (requires auth)")
await self.test_endpoint("GET", "/api/v1/pose/zones/zone_1/occupancy", description="Zone occupancy")
await self.test_endpoint("GET", "/api/v1/pose/zones/summary", description="All zones summary")
print()
# Test Historical Data Endpoints
print(f"{Fore.MAGENTA}Testing Historical Data Endpoints:{Style.RESET_ALL}")
end_time = datetime.now()
start_time = end_time - timedelta(hours=1)
historical_data = {
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat(),
"zone_ids": ["zone_1"],
"aggregation_interval": 300
}
await self.test_endpoint("POST", "/api/v1/pose/historical",
data=historical_data,
description="Historical pose data (requires auth)")
await self.test_endpoint("GET", "/api/v1/pose/activities", description="Recent activities")
await self.test_endpoint("GET", "/api/v1/pose/activities",
params={"zone_id": "zone_1", "limit": 5},
description="Activities for specific zone")
print()
# Test Calibration Endpoints
print(f"{Fore.MAGENTA}Testing Calibration Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/api/v1/pose/calibration/status", description="Calibration status (requires auth)")
await self.test_endpoint("POST", "/api/v1/pose/calibrate", description="Start calibration (requires auth)")
print()
# Test Statistics Endpoints
print(f"{Fore.MAGENTA}Testing Statistics Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/api/v1/pose/stats", description="Pose statistics")
await self.test_endpoint("GET", "/api/v1/pose/stats",
params={"hours": 12}, description="Pose statistics (12 hours)")
print()
# Test Stream Endpoints
print(f"{Fore.MAGENTA}Testing Stream Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/api/v1/stream/status", description="Stream status")
await self.test_endpoint("POST", "/api/v1/stream/start", description="Start streaming (requires auth)")
await self.test_endpoint("POST", "/api/v1/stream/stop", description="Stop streaming (requires auth)")
print()
# Test WebSocket Endpoints
print(f"{Fore.MAGENTA}Testing WebSocket Endpoints:{Style.RESET_ALL}")
await self.test_websocket_endpoint("/ws/pose", description="Pose WebSocket")
await self.test_websocket_endpoint("/ws/hardware", description="Hardware WebSocket")
print()
# Test Documentation Endpoints
print(f"{Fore.MAGENTA}Testing Documentation Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/docs", description="API documentation")
await self.test_endpoint("GET", "/openapi.json", description="OpenAPI schema")
print()
# Test API Info Endpoints
print(f"{Fore.MAGENTA}Testing API Info Endpoints:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/", description="Root endpoint")
await self.test_endpoint("GET", "/api/v1/info", description="API information")
await self.test_endpoint("GET", "/api/v1/status", description="API status")
print()
# Test Error Cases
print(f"{Fore.MAGENTA}Testing Error Cases:{Style.RESET_ALL}")
await self.test_endpoint("GET", "/nonexistent", expected_status=404,
description="Non-existent endpoint")
await self.test_endpoint("POST", "/api/v1/pose/analyze",
data={"invalid": "data"}, expected_status=401,
description="Unauthorized request (no auth)")
print()
def print_summary(self):
"""Print test summary."""
print(f"{Fore.CYAN}{'='*60}")
print(f"{Fore.CYAN}Test Summary")
print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
total = self.results["total_tests"]
passed = self.results["passed"]
failed = self.results["failed"]
success_rate = (passed / total * 100) if total > 0 else 0
print(f"Total Tests: {total}")
print(f"{Fore.GREEN}Passed: {passed}{Style.RESET_ALL}")
print(f"{Fore.RED}Failed: {failed}{Style.RESET_ALL}")
print(f"Success Rate: {success_rate:.1f}%")
print()
if self.results["errors"]:
print(f"{Fore.RED}Errors:{Style.RESET_ALL}")
for error in self.results["errors"]:
print(f" - {error}")
print()
# Save detailed results to file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f"scripts/api_test_results_{timestamp}.json"
try:
with open(results_file, 'w') as f:
json.dump(self.results, f, indent=2, default=str)
print(f"Detailed results saved to: {results_file}")
except Exception as e:
self.log_warning(f"Could not save results file: {e}")
return failed == 0
async def main():
"""Main test function."""
try:
async with APITester() as tester:
await tester.run_all_tests()
success = tester.print_summary()
# Exit with appropriate code
sys.exit(0 if success else 1)
except KeyboardInterrupt:
print(f"\n{Fore.YELLOW}Tests interrupted by user{Style.RESET_ALL}")
sys.exit(1)
except Exception as e:
print(f"\n{Fore.RED}Fatal error: {e}{Style.RESET_ALL}")
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
# Check if required packages are available
try:
import aiohttp
import websockets
import colorama
except ImportError as e:
print(f"Missing required package: {e}")
print("Install with: pip install aiohttp websockets colorama")
sys.exit(1)
# Run tests
asyncio.run(main())

View File

@@ -246,8 +246,15 @@ if __name__ != '__main__':
# Compatibility aliases for backward compatibility
WifiDensePose = app # Legacy alias
get_config = get_settings # Legacy alias
try:
WifiDensePose = app # Legacy alias
except NameError:
WifiDensePose = None # Will be None if app import failed
try:
get_config = get_settings # Legacy alias
except NameError:
get_config = None # Will be None if get_settings import failed
def main():

View File

@@ -2,6 +2,6 @@
WiFi-DensePose FastAPI application package
"""
from .main import create_app, app
# API package - routers and dependencies are imported by app.py
__all__ = ["create_app", "app"]
__all__ = []

View File

@@ -418,6 +418,21 @@ async def get_websocket_user(
return None
async def get_current_user_ws(
websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get current user for WebSocket connections."""
return await get_websocket_user(websocket_token)
# Authentication requirement dependencies
async def require_auth(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Require authentication for endpoint access."""
return current_user
# Development dependencies
async def development_only():
"""Dependency that only allows access in development."""

View File

@@ -7,18 +7,11 @@ import psutil
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from src.api.dependencies import (
get_hardware_service,
get_pose_service,
get_stream_service,
get_current_user
)
from src.services.hardware_service import HardwareService
from src.services.pose_service import PoseService
from src.services.stream_service import StreamService
from src.api.dependencies import get_current_user
from src.services.orchestrator import ServiceOrchestrator
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
@@ -58,20 +51,19 @@ class ReadinessCheck(BaseModel):
# Health check endpoints
@router.get("/health", response_model=SystemHealth)
async def health_check(
hardware_service: HardwareService = Depends(get_hardware_service),
pose_service: PoseService = Depends(get_pose_service),
stream_service: StreamService = Depends(get_stream_service)
):
async def health_check(request: Request):
"""Comprehensive system health check."""
try:
# Get orchestrator from app state
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
timestamp = datetime.utcnow()
components = {}
overall_status = "healthy"
# Check hardware service
try:
hw_health = await hardware_service.health_check()
hw_health = await orchestrator.hardware_service.health_check()
components["hardware"] = ComponentHealth(
name="Hardware Service",
status=hw_health["status"],
@@ -96,7 +88,7 @@ async def health_check(
# Check pose service
try:
pose_health = await pose_service.health_check()
pose_health = await orchestrator.pose_service.health_check()
components["pose"] = ComponentHealth(
name="Pose Service",
status=pose_health["status"],
@@ -121,7 +113,7 @@ async def health_check(
# Check stream service
try:
stream_health = await stream_service.health_check()
stream_health = await orchestrator.stream_service.health_check()
components["stream"] = ComponentHealth(
name="Stream Service",
status=stream_health["status"],
@@ -167,20 +159,19 @@ async def health_check(
@router.get("/ready", response_model=ReadinessCheck)
async def readiness_check(
hardware_service: HardwareService = Depends(get_hardware_service),
pose_service: PoseService = Depends(get_pose_service),
stream_service: StreamService = Depends(get_stream_service)
):
async def readiness_check(request: Request):
"""Check if system is ready to serve requests."""
try:
# Get orchestrator from app state
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
timestamp = datetime.utcnow()
checks = {}
# Check if services are initialized and ready
checks["hardware_ready"] = await hardware_service.is_ready()
checks["pose_ready"] = await pose_service.is_ready()
checks["stream_ready"] = await stream_service.is_ready()
checks["hardware_ready"] = await orchestrator.hardware_service.is_ready()
checks["pose_ready"] = await orchestrator.pose_service.is_ready()
checks["stream_ready"] = await orchestrator.stream_service.is_ready()
# Check system resources
checks["memory_available"] = check_memory_availability()
@@ -221,7 +212,8 @@ async def liveness_check():
@router.get("/metrics")
async def get_system_metrics(
async def get_health_metrics(
request: Request,
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get detailed system metrics."""

View File

@@ -73,7 +73,8 @@ async def websocket_pose_stream(
websocket: WebSocket,
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
max_fps: int = Query(30, ge=1, le=60)
max_fps: int = Query(30, ge=1, le=60),
token: Optional[str] = Query(None, description="Authentication token")
):
"""WebSocket endpoint for real-time pose data streaming."""
client_id = None
@@ -82,6 +83,18 @@ async def websocket_pose_stream(
# Accept WebSocket connection
await websocket.accept()
# Check authentication if enabled
from src.config.settings import get_settings
settings = get_settings()
if settings.enable_authentication and not token:
await websocket.send_json({
"type": "error",
"message": "Authentication token required"
})
await websocket.close(code=1008)
return
# Parse zone IDs
zone_list = None
if zone_ids:
@@ -146,7 +159,8 @@ async def websocket_pose_stream(
async def websocket_events_stream(
websocket: WebSocket,
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs")
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
token: Optional[str] = Query(None, description="Authentication token")
):
"""WebSocket endpoint for real-time event streaming."""
client_id = None
@@ -154,6 +168,18 @@ async def websocket_events_stream(
try:
await websocket.accept()
# Check authentication if enabled
from src.config.settings import get_settings
settings = get_settings()
if settings.enable_authentication and not token:
await websocket.send_json({
"type": "error",
"message": "Authentication token required"
})
await websocket.close(code=1008)
return
# Parse parameters
event_list = None
if event_types:
@@ -244,19 +270,27 @@ async def handle_websocket_message(client_id: str, data: Dict[str, Any], websock
# HTTP endpoints for stream management
@router.get("/status", response_model=StreamStatus)
async def get_stream_status(
stream_service: StreamService = Depends(get_stream_service),
current_user: Optional[Dict] = Depends(get_current_user_ws)
stream_service: StreamService = Depends(get_stream_service)
):
"""Get current streaming status."""
try:
status = await stream_service.get_status()
connections = await connection_manager.get_connection_stats()
# Calculate uptime (simplified for now)
uptime_seconds = 0.0
if status.get("running", False):
uptime_seconds = 3600.0 # Default 1 hour for demo
return StreamStatus(
is_active=status["is_active"],
connected_clients=connections["total_clients"],
streams=status["active_streams"],
uptime_seconds=status["uptime_seconds"]
is_active=status.get("running", False),
connected_clients=connections.get("total_clients", status["connections"]["active"]),
streams=[{
"type": "pose_stream",
"active": status.get("running", False),
"buffer_size": status["buffers"]["pose_buffer_size"]
}],
uptime_seconds=uptime_seconds
)
except Exception as e:
@@ -416,9 +450,7 @@ async def broadcast_message(
@router.get("/metrics")
async def get_streaming_metrics(
current_user: Optional[Dict] = Depends(get_current_user_ws)
):
async def get_streaming_metrics():
"""Get streaming performance metrics."""
try:
metrics = await connection_manager.get_metrics()

View File

@@ -120,7 +120,7 @@ class ConnectionManager:
"start_time": datetime.utcnow()
}
self._cleanup_task = None
self._start_cleanup_task()
self._started = False
async def connect(
self,
@@ -413,6 +413,13 @@ class ConnectionManager:
if stale_clients:
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
async def start(self):
"""Start the connection manager."""
if not self._started:
self._start_cleanup_task()
self._started = True
logger.info("Connection manager started")
def _start_cleanup_task(self):
"""Start background cleanup task."""
async def cleanup_loop():
@@ -428,7 +435,11 @@ class ConnectionManager:
except Exception as e:
logger.error(f"Error in cleanup task: {e}")
try:
self._cleanup_task = asyncio.create_task(cleanup_loop())
except RuntimeError:
# No event loop running, will start later
logger.debug("No event loop running, cleanup task will start later")
async def shutdown(self):
"""Shutdown connection manager."""

View File

@@ -3,6 +3,7 @@ FastAPI application factory and configuration
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional
@@ -15,10 +16,10 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from src.config.settings import Settings
from src.services.orchestrator import ServiceOrchestrator
from src.middleware.auth import AuthMiddleware
from src.middleware.cors import setup_cors
from src.middleware.auth import AuthenticationMiddleware
from fastapi.middleware.cors import CORSMiddleware
from src.middleware.rate_limit import RateLimitMiddleware
from src.middleware.error_handler import ErrorHandlerMiddleware
from src.middleware.error_handler import ErrorHandlingMiddleware
from src.api.routers import pose, stream, health
from src.api.websocket.connection_manager import connection_manager
@@ -34,6 +35,9 @@ async def lifespan(app: FastAPI):
# Get orchestrator from app state
orchestrator: ServiceOrchestrator = app.state.orchestrator
# Start connection manager
await connection_manager.start()
# Start all services
await orchestrator.start()
@@ -47,6 +51,10 @@ async def lifespan(app: FastAPI):
finally:
# Cleanup on shutdown
logger.info("Shutting down WiFi-DensePose API...")
# Shutdown connection manager
await connection_manager.shutdown()
if hasattr(app.state, 'orchestrator'):
await app.state.orchestrator.shutdown()
logger.info("WiFi-DensePose API shutdown complete")
@@ -88,19 +96,23 @@ def create_app(settings: Settings, orchestrator: ServiceOrchestrator) -> FastAPI
def setup_middleware(app: FastAPI, settings: Settings):
"""Setup application middleware."""
# Error handling middleware (should be first)
app.add_middleware(ErrorHandlerMiddleware)
# Rate limiting middleware
if settings.enable_rate_limiting:
app.add_middleware(RateLimitMiddleware, settings=settings)
# Authentication middleware
if settings.enable_authentication:
app.add_middleware(AuthMiddleware, settings=settings)
app.add_middleware(AuthenticationMiddleware, settings=settings)
# CORS middleware
setup_cors(app, settings)
if settings.cors_enabled:
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=settings.cors_allow_credentials,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
allow_headers=["*"],
)
# Trusted host middleware for production
if settings.is_production:

View File

@@ -14,8 +14,9 @@ from src.commands.start import start_command
from src.commands.stop import stop_command
from src.commands.status import status_command
# Setup logging for CLI
setup_logging()
# Get default settings and setup logging for CLI
settings = get_settings()
setup_logging(settings)
logger = get_logger(__name__)
@@ -498,5 +499,10 @@ def version():
sys.exit(1)
def create_cli(orchestrator=None):
"""Create CLI interface for the application."""
return cli
if __name__ == '__main__':
cli()

View File

@@ -349,6 +349,10 @@ class DomainConfig:
return routers
def get_all_routers(self) -> List[RouterConfig]:
"""Get all router configurations."""
return list(self.routers.values())
def validate_configuration(self) -> List[str]:
"""Validate the entire configuration."""
issues = []

View File

@@ -97,6 +97,8 @@ class Settings(BaseSettings):
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
# Development settings
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")

View File

@@ -0,0 +1,13 @@
"""
Core package for WiFi-DensePose API
"""
from .csi_processor import CSIProcessor
from .phase_sanitizer import PhaseSanitizer
from .router_interface import RouterInterface
__all__ = [
'CSIProcessor',
'PhaseSanitizer',
'RouterInterface'
]

View File

@@ -2,7 +2,9 @@
import numpy as np
import torch
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List
from datetime import datetime
from collections import deque
class CSIProcessor:
@@ -18,6 +20,11 @@ class CSIProcessor:
self.sample_rate = self.config.get('sample_rate', 1000)
self.num_subcarriers = self.config.get('num_subcarriers', 56)
self.num_antennas = self.config.get('num_antennas', 3)
self.buffer_size = self.config.get('buffer_size', 1000)
# Data buffer for temporal processing
self.data_buffer = deque(maxlen=self.buffer_size)
self.last_processed_data = None
def process_raw_csi(self, raw_data: np.ndarray) -> np.ndarray:
"""Process raw CSI data into normalized format.
@@ -77,3 +84,46 @@ class CSIProcessor:
# Convert to tensor
return torch.from_numpy(processed_data).float()
def add_data(self, csi_data: np.ndarray, timestamp: datetime):
"""Add CSI data to the processing buffer.
Args:
csi_data: Raw CSI data array
timestamp: Timestamp of the data sample
"""
sample = {
'data': csi_data,
'timestamp': timestamp,
'processed': False
}
self.data_buffer.append(sample)
def get_processed_data(self) -> Optional[np.ndarray]:
"""Get the most recent processed CSI data.
Returns:
Processed CSI data array or None if no data available
"""
if not self.data_buffer:
return None
# Get the most recent unprocessed sample
recent_sample = None
for sample in reversed(self.data_buffer):
if not sample['processed']:
recent_sample = sample
break
if recent_sample is None:
return self.last_processed_data
# Process the data
try:
processed_data = self.process_raw_csi(recent_sample['data'])
recent_sample['processed'] = True
self.last_processed_data = processed_data
return processed_data
except Exception as e:
# Return last known good data if processing fails
return self.last_processed_data

View File

@@ -0,0 +1,340 @@
"""
Router interface for WiFi CSI data collection
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime
import numpy as np
logger = logging.getLogger(__name__)
class RouterInterface:
"""Interface for connecting to WiFi routers and collecting CSI data."""
def __init__(
self,
router_id: str,
host: str,
port: int = 22,
username: str = "admin",
password: str = "",
interface: str = "wlan0",
mock_mode: bool = False
):
"""Initialize router interface.
Args:
router_id: Unique identifier for the router
host: Router IP address or hostname
port: SSH port for connection
username: SSH username
password: SSH password
interface: WiFi interface name
mock_mode: Whether to use mock data instead of real connection
"""
self.router_id = router_id
self.host = host
self.port = port
self.username = username
self.password = password
self.interface = interface
self.mock_mode = mock_mode
self.logger = logging.getLogger(f"{__name__}.{router_id}")
# Connection state
self.is_connected = False
self.connection = None
self.last_error = None
# Data collection state
self.last_data_time = None
self.error_count = 0
self.sample_count = 0
# Mock data generation
self.mock_data_generator = None
if mock_mode:
self._initialize_mock_generator()
def _initialize_mock_generator(self):
"""Initialize mock data generator."""
self.mock_data_generator = {
'phase': 0,
'amplitude_base': 1.0,
'frequency': 0.1,
'noise_level': 0.1
}
async def connect(self):
"""Connect to the router."""
if self.mock_mode:
self.is_connected = True
self.logger.info(f"Mock connection established to router {self.router_id}")
return
try:
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
# In a real implementation, this would establish SSH connection
# For now, we'll simulate the connection
await asyncio.sleep(0.1) # Simulate connection delay
self.is_connected = True
self.error_count = 0
self.logger.info(f"Connected to router {self.router_id}")
except Exception as e:
self.last_error = str(e)
self.error_count += 1
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
raise
async def disconnect(self):
"""Disconnect from the router."""
try:
if self.connection:
# Close SSH connection
self.connection = None
self.is_connected = False
self.logger.info(f"Disconnected from router {self.router_id}")
except Exception as e:
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
async def reconnect(self):
"""Reconnect to the router."""
await self.disconnect()
await asyncio.sleep(1) # Wait before reconnecting
await self.connect()
async def get_csi_data(self) -> Optional[np.ndarray]:
"""Get CSI data from the router.
Returns:
CSI data as numpy array, or None if no data available
"""
if not self.is_connected:
raise RuntimeError(f"Router {self.router_id} is not connected")
try:
if self.mock_mode:
csi_data = self._generate_mock_csi_data()
else:
csi_data = await self._collect_real_csi_data()
if csi_data is not None:
self.last_data_time = datetime.now()
self.sample_count += 1
self.error_count = 0
return csi_data
except Exception as e:
self.last_error = str(e)
self.error_count += 1
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
return None
def _generate_mock_csi_data(self) -> np.ndarray:
"""Generate mock CSI data for testing."""
# Simulate CSI data with realistic characteristics
num_subcarriers = 64
num_antennas = 4
num_samples = 100
# Update mock generator state
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
# Generate amplitude and phase data
time_axis = np.linspace(0, 1, num_samples)
# Create realistic CSI patterns
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
for antenna in range(num_antennas):
for subcarrier in range(num_subcarriers):
# Base signal with some variation per antenna/subcarrier
amplitude = (
self.mock_data_generator['amplitude_base'] *
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
(1 + 0.1 * antenna)
)
# Phase with spatial and frequency variation
phase_offset = (
self.mock_data_generator['phase'] +
2 * np.pi * subcarrier / num_subcarriers +
np.pi * antenna / num_antennas
)
# Add some movement simulation
movement_freq = 0.5 # Hz
movement_amplitude = 0.3
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
# Generate complex signal
signal_amplitude = amplitude * (1 + movement)
signal_phase = phase_offset + movement * 0.5
# Add noise
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
noise = noise_real + 1j * noise_imag
# Create complex signal
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
csi_data[antenna, subcarrier, :] = signal
return csi_data
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
"""Collect real CSI data from router (placeholder implementation)."""
# This would implement the actual CSI data collection
# For now, return None to indicate no real implementation
self.logger.warning("Real CSI data collection not implemented")
return None
async def check_health(self) -> bool:
"""Check if the router connection is healthy.
Returns:
True if healthy, False otherwise
"""
if not self.is_connected:
return False
try:
# In mock mode, always healthy
if self.mock_mode:
return True
# For real connections, we could ping the router or check SSH connection
# For now, consider healthy if error count is low
return self.error_count < 5
except Exception as e:
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
return False
async def get_status(self) -> Dict[str, Any]:
"""Get router status information.
Returns:
Dictionary containing router status
"""
return {
"router_id": self.router_id,
"connected": self.is_connected,
"mock_mode": self.mock_mode,
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
"error_count": self.error_count,
"sample_count": self.sample_count,
"last_error": self.last_error,
"configuration": {
"host": self.host,
"port": self.port,
"username": self.username,
"interface": self.interface
}
}
async def get_router_info(self) -> Dict[str, Any]:
"""Get router hardware information.
Returns:
Dictionary containing router information
"""
if self.mock_mode:
return {
"model": "Mock Router",
"firmware": "1.0.0-mock",
"wifi_standard": "802.11ac",
"antennas": 4,
"supported_bands": ["2.4GHz", "5GHz"],
"csi_capabilities": {
"max_subcarriers": 64,
"max_antennas": 4,
"sampling_rate": 1000
}
}
# For real routers, this would query the actual hardware
return {
"model": "Unknown",
"firmware": "Unknown",
"wifi_standard": "Unknown",
"antennas": 1,
"supported_bands": ["Unknown"],
"csi_capabilities": {
"max_subcarriers": 64,
"max_antennas": 1,
"sampling_rate": 100
}
}
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
"""Configure CSI data collection parameters.
Args:
config: Configuration dictionary
Returns:
True if configuration successful, False otherwise
"""
try:
if self.mock_mode:
# Update mock generator parameters
if 'sampling_rate' in config:
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
if 'noise_level' in config:
self.mock_data_generator['noise_level'] = config['noise_level']
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
return True
# For real routers, this would send configuration commands
self.logger.warning("Real CSI configuration not implemented")
return False
except Exception as e:
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
return False
def get_metrics(self) -> Dict[str, Any]:
"""Get router interface metrics.
Returns:
Dictionary containing metrics
"""
uptime = 0
if self.last_data_time:
uptime = (datetime.now() - self.last_data_time).total_seconds()
success_rate = 0
if self.sample_count > 0:
success_rate = (self.sample_count - self.error_count) / self.sample_count
return {
"router_id": self.router_id,
"sample_count": self.sample_count,
"error_count": self.error_count,
"success_rate": success_rate,
"uptime_seconds": uptime,
"is_connected": self.is_connected,
"mock_mode": self.mock_mode
}
def reset_stats(self):
"""Reset statistics counters."""
self.error_count = 0
self.sample_count = 0
self.last_error = None
self.logger.info(f"Statistics reset for router {self.router_id}")

View File

@@ -307,31 +307,34 @@ class ErrorHandler:
class ErrorHandlingMiddleware:
"""Error handling middleware for FastAPI."""
def __init__(self, settings: Settings):
def __init__(self, app, settings: Settings):
self.app = app
self.settings = settings
self.error_handler = ErrorHandler(settings)
async def __call__(self, request: Request, call_next: Callable) -> Response:
async def __call__(self, scope, receive, send):
"""Process request through error handling middleware."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time()
try:
response = await call_next(request)
return response
except HTTPException as exc:
error_response = self.error_handler.handle_http_exception(request, exc)
return error_response.to_response()
except RequestValidationError as exc:
error_response = self.error_handler.handle_validation_error(request, exc)
return error_response.to_response()
except ValidationError as exc:
error_response = self.error_handler.handle_pydantic_error(request, exc)
return error_response.to_response()
await self.app(scope, receive, send)
except Exception as exc:
# Create a mock request for error handling
from starlette.requests import Request
request = Request(scope, receive)
# Handle different exception types
if isinstance(exc, HTTPException):
error_response = self.error_handler.handle_http_exception(request, exc)
elif isinstance(exc, RequestValidationError):
error_response = self.error_handler.handle_validation_error(request, exc)
elif isinstance(exc, ValidationError):
error_response = self.error_handler.handle_pydantic_error(request, exc)
else:
# Check for specific error types
if self._is_database_error(exc):
error_response = self.error_handler.handle_database_error(request, exc)
@@ -340,7 +343,9 @@ class ErrorHandlingMiddleware:
else:
error_response = self.error_handler.handle_generic_exception(request, exc)
return error_response.to_response()
# Send the error response
response = error_response.to_response()
await response(scope, receive, send)
finally:
# Log request processing time
@@ -424,11 +429,10 @@ def setup_error_handling(app, settings: Settings):
return error_response.to_response()
# Add middleware for additional error handling
middleware = ErrorHandlingMiddleware(settings)
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
return await middleware(request, call_next)
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
# The middleware approach is commented out but kept for reference
# middleware = ErrorHandlingMiddleware(app, settings)
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
logger.info("Error handling configured")

View File

@@ -5,9 +5,15 @@ Services package for WiFi-DensePose API
from .orchestrator import ServiceOrchestrator
from .health_check import HealthCheckService
from .metrics import MetricsService
from .pose_service import PoseService
from .stream_service import StreamService
from .hardware_service import HardwareService
__all__ = [
'ServiceOrchestrator',
'HealthCheckService',
'MetricsService'
'MetricsService',
'PoseService',
'StreamService',
'HardwareService'
]

View File

@@ -0,0 +1,483 @@
"""
Hardware interface service for WiFi-DensePose API
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import numpy as np
from src.config.settings import Settings
from src.config.domains import DomainConfig
from src.core.router_interface import RouterInterface
logger = logging.getLogger(__name__)
class HardwareService:
"""Service for hardware interface operations."""
def __init__(self, settings: Settings, domain_config: DomainConfig):
"""Initialize hardware service."""
self.settings = settings
self.domain_config = domain_config
self.logger = logging.getLogger(__name__)
# Router interfaces
self.router_interfaces: Dict[str, RouterInterface] = {}
# Service state
self.is_running = False
self.last_error = None
# Data collection statistics
self.stats = {
"total_samples": 0,
"successful_samples": 0,
"failed_samples": 0,
"average_sample_rate": 0.0,
"last_sample_time": None,
"connected_routers": 0
}
# Background tasks
self.collection_task = None
self.monitoring_task = None
# Data buffers
self.recent_samples = []
self.max_recent_samples = 1000
async def initialize(self):
"""Initialize the hardware service."""
await self.start()
async def start(self):
"""Start the hardware service."""
if self.is_running:
return
try:
self.logger.info("Starting hardware service...")
# Initialize router interfaces
await self._initialize_routers()
self.is_running = True
# Start background tasks
if not self.settings.mock_hardware:
self.collection_task = asyncio.create_task(self._data_collection_loop())
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
self.logger.info("Hardware service started successfully")
except Exception as e:
self.last_error = str(e)
self.logger.error(f"Failed to start hardware service: {e}")
raise
async def stop(self):
"""Stop the hardware service."""
self.is_running = False
# Cancel background tasks
if self.collection_task:
self.collection_task.cancel()
try:
await self.collection_task
except asyncio.CancelledError:
pass
if self.monitoring_task:
self.monitoring_task.cancel()
try:
await self.monitoring_task
except asyncio.CancelledError:
pass
# Disconnect from routers
await self._disconnect_routers()
self.logger.info("Hardware service stopped")
async def _initialize_routers(self):
"""Initialize router interfaces."""
try:
# Get router configurations from domain config
routers = self.domain_config.get_all_routers()
for router_config in routers:
if not router_config.enabled:
continue
router_id = router_config.router_id
# Create router interface
router_interface = RouterInterface(
router_id=router_id,
host=router_config.ip_address,
port=22, # Default SSH port
username="admin", # Default username
password="admin", # Default password
interface=router_config.interface,
mock_mode=self.settings.mock_hardware
)
# Connect to router
if not self.settings.mock_hardware:
await router_interface.connect()
self.router_interfaces[router_id] = router_interface
self.logger.info(f"Router interface initialized: {router_id}")
self.stats["connected_routers"] = len(self.router_interfaces)
if not self.router_interfaces:
self.logger.warning("No router interfaces configured")
except Exception as e:
self.logger.error(f"Failed to initialize routers: {e}")
raise
async def _disconnect_routers(self):
"""Disconnect from all routers."""
for router_id, interface in self.router_interfaces.items():
try:
await interface.disconnect()
self.logger.info(f"Disconnected from router: {router_id}")
except Exception as e:
self.logger.error(f"Error disconnecting from router {router_id}: {e}")
self.router_interfaces.clear()
self.stats["connected_routers"] = 0
async def _data_collection_loop(self):
"""Background loop for data collection."""
try:
while self.is_running:
start_time = time.time()
# Collect data from all routers
await self._collect_data_from_routers()
# Calculate sleep time to maintain polling interval
elapsed = time.time() - start_time
sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)
if sleep_time > 0:
await asyncio.sleep(sleep_time)
except asyncio.CancelledError:
self.logger.info("Data collection loop cancelled")
except Exception as e:
self.logger.error(f"Error in data collection loop: {e}")
self.last_error = str(e)
async def _monitoring_loop(self):
"""Background loop for hardware monitoring."""
try:
while self.is_running:
# Monitor router connections
await self._monitor_router_health()
# Update statistics
self._update_sample_rate_stats()
# Wait before next check
await asyncio.sleep(30) # Check every 30 seconds
except asyncio.CancelledError:
self.logger.info("Monitoring loop cancelled")
except Exception as e:
self.logger.error(f"Error in monitoring loop: {e}")
async def _collect_data_from_routers(self):
"""Collect CSI data from all connected routers."""
for router_id, interface in self.router_interfaces.items():
try:
# Get CSI data from router
csi_data = await interface.get_csi_data()
if csi_data is not None:
# Process the collected data
await self._process_collected_data(router_id, csi_data)
self.stats["successful_samples"] += 1
self.stats["last_sample_time"] = datetime.now().isoformat()
else:
self.stats["failed_samples"] += 1
self.stats["total_samples"] += 1
except Exception as e:
self.logger.error(f"Error collecting data from router {router_id}: {e}")
self.stats["failed_samples"] += 1
self.stats["total_samples"] += 1
async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
"""Process collected CSI data."""
try:
# Create sample metadata
metadata = {
"router_id": router_id,
"timestamp": datetime.now().isoformat(),
"sample_rate": self.stats["average_sample_rate"],
"data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
}
# Add to recent samples buffer
sample = {
"router_id": router_id,
"timestamp": metadata["timestamp"],
"data": csi_data,
"metadata": metadata
}
self.recent_samples.append(sample)
# Maintain buffer size
if len(self.recent_samples) > self.max_recent_samples:
self.recent_samples.pop(0)
# Notify other services (this would typically be done through an event system)
# For now, we'll just log the data collection
self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")
except Exception as e:
self.logger.error(f"Error processing collected data: {e}")
async def _monitor_router_health(self):
"""Monitor health of router connections."""
healthy_routers = 0
for router_id, interface in self.router_interfaces.items():
try:
is_healthy = await interface.check_health()
if is_healthy:
healthy_routers += 1
else:
self.logger.warning(f"Router {router_id} is unhealthy")
# Try to reconnect if not in mock mode
if not self.settings.mock_hardware:
try:
await interface.reconnect()
self.logger.info(f"Reconnected to router {router_id}")
except Exception as e:
self.logger.error(f"Failed to reconnect to router {router_id}: {e}")
except Exception as e:
self.logger.error(f"Error checking health of router {router_id}: {e}")
self.stats["connected_routers"] = healthy_routers
def _update_sample_rate_stats(self):
"""Update sample rate statistics."""
if len(self.recent_samples) < 2:
return
# Calculate sample rate from recent samples
recent_count = min(100, len(self.recent_samples))
recent_samples = self.recent_samples[-recent_count:]
if len(recent_samples) >= 2:
# Calculate time differences
time_diffs = []
for i in range(1, len(recent_samples)):
try:
t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
diff = (t2 - t1).total_seconds()
if diff > 0:
time_diffs.append(diff)
except Exception:
continue
if time_diffs:
avg_interval = sum(time_diffs) / len(time_diffs)
self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0
async def get_router_status(self, router_id: str) -> Dict[str, Any]:
"""Get status of a specific router."""
if router_id not in self.router_interfaces:
raise ValueError(f"Router {router_id} not found")
interface = self.router_interfaces[router_id]
try:
is_healthy = await interface.check_health()
status = await interface.get_status()
return {
"router_id": router_id,
"healthy": is_healthy,
"connected": status.get("connected", False),
"last_data_time": status.get("last_data_time"),
"error_count": status.get("error_count", 0),
"configuration": status.get("configuration", {})
}
except Exception as e:
return {
"router_id": router_id,
"healthy": False,
"connected": False,
"error": str(e)
}
async def get_all_router_status(self) -> List[Dict[str, Any]]:
"""Get status of all routers."""
statuses = []
for router_id in self.router_interfaces:
try:
status = await self.get_router_status(router_id)
statuses.append(status)
except Exception as e:
statuses.append({
"router_id": router_id,
"healthy": False,
"error": str(e)
})
return statuses
async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
"""Get recent CSI data samples."""
samples = self.recent_samples[-limit:] if limit else self.recent_samples
if router_id:
samples = [s for s in samples if s["router_id"] == router_id]
# Convert numpy arrays to lists for JSON serialization
result = []
for sample in samples:
sample_copy = sample.copy()
if isinstance(sample_copy["data"], np.ndarray):
sample_copy["data"] = sample_copy["data"].tolist()
result.append(sample_copy)
return result
async def get_status(self) -> Dict[str, Any]:
"""Get service status."""
return {
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
"running": self.is_running,
"last_error": self.last_error,
"statistics": self.stats.copy(),
"configuration": {
"mock_hardware": self.settings.mock_hardware,
"wifi_interface": self.settings.wifi_interface,
"polling_interval": self.settings.hardware_polling_interval,
"buffer_size": self.settings.csi_buffer_size
},
"routers": await self.get_all_router_status()
}
async def get_metrics(self) -> Dict[str, Any]:
"""Get service metrics."""
total_samples = self.stats["total_samples"]
success_rate = self.stats["successful_samples"] / max(1, total_samples)
return {
"hardware_service": {
"total_samples": total_samples,
"successful_samples": self.stats["successful_samples"],
"failed_samples": self.stats["failed_samples"],
"success_rate": success_rate,
"average_sample_rate": self.stats["average_sample_rate"],
"connected_routers": self.stats["connected_routers"],
"last_sample_time": self.stats["last_sample_time"]
}
}
async def reset(self):
"""Reset service state."""
self.stats = {
"total_samples": 0,
"successful_samples": 0,
"failed_samples": 0,
"average_sample_rate": 0.0,
"last_sample_time": None,
"connected_routers": len(self.router_interfaces)
}
self.recent_samples.clear()
self.last_error = None
self.logger.info("Hardware service reset")
async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
"""Manually trigger data collection."""
if not self.is_running:
raise RuntimeError("Hardware service is not running")
results = {}
if router_id:
# Collect from specific router
if router_id not in self.router_interfaces:
raise ValueError(f"Router {router_id} not found")
interface = self.router_interfaces[router_id]
try:
csi_data = await interface.get_csi_data()
if csi_data is not None:
await self._process_collected_data(router_id, csi_data)
results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
else:
results[router_id] = {"success": False, "error": "No data received"}
except Exception as e:
results[router_id] = {"success": False, "error": str(e)}
else:
# Collect from all routers
await self._collect_data_from_routers()
results = {"message": "Manual collection triggered for all routers"}
return results
async def health_check(self) -> Dict[str, Any]:
"""Perform health check."""
try:
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
# Check router health
healthy_routers = 0
total_routers = len(self.router_interfaces)
for router_id, interface in self.router_interfaces.items():
try:
if await interface.check_health():
healthy_routers += 1
except Exception:
pass
return {
"status": status,
"message": self.last_error if self.last_error else "Hardware service is running normally",
"connected_routers": f"{healthy_routers}/{total_routers}",
"metrics": {
"total_samples": self.stats["total_samples"],
"success_rate": (
self.stats["successful_samples"] / max(1, self.stats["total_samples"])
),
"average_sample_rate": self.stats["average_sample_rate"]
}
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}"
}
async def is_ready(self) -> bool:
"""Check if service is ready."""
return self.is_running and len(self.router_interfaces) > 0

View File

@@ -0,0 +1,706 @@
"""
Pose estimation service for WiFi-DensePose API
"""
import logging
import asyncio
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import numpy as np
import torch
from src.config.settings import Settings
from src.config.domains import DomainConfig
from src.core.csi_processor import CSIProcessor
from src.core.phase_sanitizer import PhaseSanitizer
from src.models.densepose_head import DensePoseHead
from src.models.modality_translation import ModalityTranslationNetwork
logger = logging.getLogger(__name__)
class PoseService:
"""Service for pose estimation operations."""
def __init__(self, settings: Settings, domain_config: DomainConfig):
"""Initialize pose service."""
self.settings = settings
self.domain_config = domain_config
self.logger = logging.getLogger(__name__)
# Initialize components
self.csi_processor = None
self.phase_sanitizer = None
self.densepose_model = None
self.modality_translator = None
# Service state
self.is_initialized = False
self.is_running = False
self.last_error = None
# Processing statistics
self.stats = {
"total_processed": 0,
"successful_detections": 0,
"failed_detections": 0,
"average_confidence": 0.0,
"processing_time_ms": 0.0
}
async def initialize(self):
"""Initialize the pose service."""
try:
self.logger.info("Initializing pose service...")
# Initialize CSI processor
csi_config = {
'buffer_size': self.settings.csi_buffer_size,
'sample_rate': 1000, # Default sampling rate
'num_subcarriers': 56,
'num_antennas': 3
}
self.csi_processor = CSIProcessor(config=csi_config)
# Initialize phase sanitizer
self.phase_sanitizer = PhaseSanitizer()
# Initialize models if not mocking
if not self.settings.mock_pose_data:
await self._initialize_models()
else:
self.logger.info("Using mock pose data for development")
self.is_initialized = True
self.logger.info("Pose service initialized successfully")
except Exception as e:
self.last_error = str(e)
self.logger.error(f"Failed to initialize pose service: {e}")
raise
async def _initialize_models(self):
"""Initialize neural network models."""
try:
# Initialize DensePose model
if self.settings.pose_model_path:
self.densepose_model = DensePoseHead()
# Load model weights if path is provided
# model_state = torch.load(self.settings.pose_model_path)
# self.densepose_model.load_state_dict(model_state)
self.logger.info("DensePose model loaded")
else:
self.logger.warning("No pose model path provided, using default model")
self.densepose_model = DensePoseHead()
# Initialize modality translation
config = {
'input_channels': 64, # CSI data channels
'hidden_channels': [128, 256, 512],
'output_channels': 256, # Visual feature channels
'use_attention': True
}
self.modality_translator = ModalityTranslationNetwork(config)
# Set models to evaluation mode
self.densepose_model.eval()
self.modality_translator.eval()
except Exception as e:
self.logger.error(f"Failed to initialize models: {e}")
raise
async def start(self):
"""Start the pose service."""
if not self.is_initialized:
await self.initialize()
self.is_running = True
self.logger.info("Pose service started")
async def stop(self):
"""Stop the pose service."""
self.is_running = False
self.logger.info("Pose service stopped")
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Process CSI data and estimate poses."""
if not self.is_running:
raise RuntimeError("Pose service is not running")
start_time = datetime.now()
try:
# Process CSI data
processed_csi = await self._process_csi(csi_data, metadata)
# Estimate poses
poses = await self._estimate_poses(processed_csi, metadata)
# Update statistics
processing_time = (datetime.now() - start_time).total_seconds() * 1000
self._update_stats(poses, processing_time)
return {
"timestamp": start_time.isoformat(),
"poses": poses,
"metadata": metadata,
"processing_time_ms": processing_time,
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
}
except Exception as e:
self.last_error = str(e)
self.stats["failed_detections"] += 1
self.logger.error(f"Error processing CSI data: {e}")
raise
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
"""Process raw CSI data."""
# Add CSI data to processor
self.csi_processor.add_data(csi_data, metadata.get("timestamp", datetime.now()))
# Get processed data
processed_data = self.csi_processor.get_processed_data()
# Apply phase sanitization
if processed_data is not None:
sanitized_data = self.phase_sanitizer.sanitize(processed_data)
return sanitized_data
return csi_data
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Estimate poses from processed CSI data."""
if self.settings.mock_pose_data:
return self._generate_mock_poses()
try:
# Convert CSI data to tensor
csi_tensor = torch.from_numpy(csi_data).float()
# Add batch dimension if needed
if len(csi_tensor.shape) == 2:
csi_tensor = csi_tensor.unsqueeze(0)
# Translate modality (CSI to visual-like features)
with torch.no_grad():
visual_features = self.modality_translator(csi_tensor)
# Estimate poses using DensePose
pose_outputs = self.densepose_model(visual_features)
# Convert outputs to pose detections
poses = self._parse_pose_outputs(pose_outputs)
# Filter by confidence threshold
filtered_poses = [
pose for pose in poses
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
]
# Limit number of persons
if len(filtered_poses) > self.settings.pose_max_persons:
filtered_poses = sorted(
filtered_poses,
key=lambda x: x.get("confidence", 0.0),
reverse=True
)[:self.settings.pose_max_persons]
return filtered_poses
except Exception as e:
self.logger.error(f"Error in pose estimation: {e}")
return []
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
"""Parse neural network outputs into pose detections."""
poses = []
# This is a simplified parsing - in reality, this would depend on the model architecture
# For now, generate mock poses based on the output shape
batch_size = outputs.shape[0]
for i in range(batch_size):
# Extract pose information (mock implementation)
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
pose = {
"person_id": i,
"confidence": confidence,
"keypoints": self._generate_keypoints(),
"bounding_box": self._generate_bounding_box(),
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
"timestamp": datetime.now().isoformat()
}
poses.append(pose)
return poses
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
"""Generate mock pose data for development."""
import random
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
poses = []
for i in range(num_persons):
confidence = random.uniform(0.3, 0.95)
pose = {
"person_id": i,
"confidence": confidence,
"keypoints": self._generate_keypoints(),
"bounding_box": self._generate_bounding_box(),
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
"timestamp": datetime.now().isoformat()
}
poses.append(pose)
return poses
def _generate_keypoints(self) -> List[Dict[str, Any]]:
"""Generate keypoints for a person."""
import random
keypoint_names = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"
]
keypoints = []
for name in keypoint_names:
keypoints.append({
"name": name,
"x": random.uniform(0.1, 0.9),
"y": random.uniform(0.1, 0.9),
"confidence": random.uniform(0.5, 0.95)
})
return keypoints
def _generate_bounding_box(self) -> Dict[str, float]:
"""Generate bounding box for a person."""
import random
x = random.uniform(0.1, 0.6)
y = random.uniform(0.1, 0.6)
width = random.uniform(0.2, 0.4)
height = random.uniform(0.3, 0.5)
return {
"x": x,
"y": y,
"width": width,
"height": height
}
def _classify_activity(self, features: torch.Tensor) -> str:
"""Classify activity from features."""
# Simple mock classification
import random
activities = ["standing", "sitting", "walking", "lying", "unknown"]
return random.choice(activities)
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
"""Update processing statistics."""
self.stats["total_processed"] += 1
if poses:
self.stats["successful_detections"] += 1
confidences = [pose.get("confidence", 0.0) for pose in poses]
avg_confidence = sum(confidences) / len(confidences)
# Update running average
total = self.stats["successful_detections"]
current_avg = self.stats["average_confidence"]
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
else:
self.stats["failed_detections"] += 1
# Update processing time (running average)
total = self.stats["total_processed"]
current_avg = self.stats["processing_time_ms"]
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
async def get_status(self) -> Dict[str, Any]:
"""Get service status."""
return {
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
"initialized": self.is_initialized,
"running": self.is_running,
"last_error": self.last_error,
"statistics": self.stats.copy(),
"configuration": {
"mock_data": self.settings.mock_pose_data,
"confidence_threshold": self.settings.pose_confidence_threshold,
"max_persons": self.settings.pose_max_persons,
"batch_size": self.settings.pose_processing_batch_size
}
}
async def get_metrics(self) -> Dict[str, Any]:
"""Get service metrics."""
return {
"pose_service": {
"total_processed": self.stats["total_processed"],
"successful_detections": self.stats["successful_detections"],
"failed_detections": self.stats["failed_detections"],
"success_rate": (
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
),
"average_confidence": self.stats["average_confidence"],
"average_processing_time_ms": self.stats["processing_time_ms"]
}
}
async def reset(self):
"""Reset service state."""
self.stats = {
"total_processed": 0,
"successful_detections": 0,
"failed_detections": 0,
"average_confidence": 0.0,
"processing_time_ms": 0.0
}
self.last_error = None
self.logger.info("Pose service reset")
# API endpoint methods
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
include_keypoints=True, include_segmentation=False):
"""Estimate poses with API parameters."""
try:
# Generate mock CSI data for estimation
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
metadata = {
"timestamp": datetime.now(),
"zone_ids": zone_ids or ["zone_1"],
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
"max_persons": max_persons or self.settings.pose_max_persons
}
# Process the data
result = await self.process_csi_data(mock_csi, metadata)
# Format for API response
persons = []
for i, pose in enumerate(result["poses"]):
person = {
"person_id": str(pose["person_id"]),
"confidence": pose["confidence"],
"bounding_box": pose["bounding_box"],
"zone_id": zone_ids[0] if zone_ids else "zone_1",
"activity": pose["activity"],
"timestamp": datetime.fromisoformat(pose["timestamp"])
}
if include_keypoints:
person["keypoints"] = pose["keypoints"]
if include_segmentation:
person["segmentation"] = {"mask": "mock_segmentation_data"}
persons.append(person)
# Zone summary
zone_summary = {}
for zone_id in (zone_ids or ["zone_1"]):
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
return {
"timestamp": datetime.now(),
"frame_id": f"frame_{int(datetime.now().timestamp())}",
"persons": persons,
"zone_summary": zone_summary,
"processing_time_ms": result["processing_time_ms"],
"metadata": {"mock_data": self.settings.mock_pose_data}
}
except Exception as e:
self.logger.error(f"Error in estimate_poses: {e}")
raise
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
include_keypoints=True, include_segmentation=False):
"""Analyze pose data with custom parameters."""
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
include_keypoints, include_segmentation)
async def get_zone_occupancy(self, zone_id: str):
"""Get current occupancy for a specific zone."""
try:
# Mock occupancy data
import random
count = random.randint(0, 5)
persons = []
for i in range(count):
persons.append({
"person_id": f"person_{i}",
"confidence": random.uniform(0.7, 0.95),
"activity": random.choice(["standing", "sitting", "walking"])
})
return {
"count": count,
"max_occupancy": 10,
"persons": persons,
"timestamp": datetime.now()
}
except Exception as e:
self.logger.error(f"Error getting zone occupancy: {e}")
return None
async def get_zones_summary(self):
"""Get occupancy summary for all zones."""
try:
import random
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
zone_data = {}
total_persons = 0
active_zones = 0
for zone_id in zones:
count = random.randint(0, 3)
zone_data[zone_id] = {
"occupancy": count,
"max_occupancy": 10,
"status": "active" if count > 0 else "inactive"
}
total_persons += count
if count > 0:
active_zones += 1
return {
"total_persons": total_persons,
"zones": zone_data,
"active_zones": active_zones
}
except Exception as e:
self.logger.error(f"Error getting zones summary: {e}")
raise
async def get_historical_data(self, start_time, end_time, zone_ids=None,
aggregation_interval=300, include_raw_data=False):
"""Get historical pose estimation data."""
try:
# Mock historical data
import random
from datetime import timedelta
current_time = start_time
aggregated_data = []
raw_data = [] if include_raw_data else None
while current_time < end_time:
# Generate aggregated data point
data_point = {
"timestamp": current_time,
"total_persons": random.randint(0, 8),
"zones": {}
}
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
data_point["zones"][zone_id] = {
"occupancy": random.randint(0, 3),
"avg_confidence": random.uniform(0.7, 0.95)
}
aggregated_data.append(data_point)
# Generate raw data if requested
if include_raw_data:
for _ in range(random.randint(0, 5)):
raw_data.append({
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
"person_id": f"person_{random.randint(1, 10)}",
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
"confidence": random.uniform(0.5, 0.95),
"activity": random.choice(["standing", "sitting", "walking"])
})
current_time += timedelta(seconds=aggregation_interval)
return {
"aggregated_data": aggregated_data,
"raw_data": raw_data,
"total_records": len(aggregated_data)
}
except Exception as e:
self.logger.error(f"Error getting historical data: {e}")
raise
async def get_recent_activities(self, zone_id=None, limit=10):
"""Get recently detected activities."""
try:
import random
activities = []
for i in range(limit):
activity = {
"activity_id": f"activity_{i}",
"person_id": f"person_{random.randint(1, 5)}",
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
"confidence": random.uniform(0.6, 0.95),
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
"duration_seconds": random.randint(10, 300)
}
activities.append(activity)
return activities
except Exception as e:
self.logger.error(f"Error getting recent activities: {e}")
raise
async def is_calibrating(self):
"""Check if calibration is in progress."""
return False # Mock implementation
async def start_calibration(self):
"""Start calibration process."""
import uuid
calibration_id = str(uuid.uuid4())
self.logger.info(f"Started calibration: {calibration_id}")
return calibration_id
async def run_calibration(self, calibration_id):
"""Run calibration process."""
self.logger.info(f"Running calibration: {calibration_id}")
# Mock calibration process
await asyncio.sleep(5)
self.logger.info(f"Calibration completed: {calibration_id}")
async def get_calibration_status(self):
"""Get current calibration status."""
return {
"is_calibrating": False,
"calibration_id": None,
"progress_percent": 100,
"current_step": "completed",
"estimated_remaining_minutes": 0,
"last_calibration": datetime.now() - timedelta(hours=1)
}
async def get_statistics(self, start_time, end_time):
"""Get pose estimation statistics."""
try:
import random
# Mock statistics
total_detections = random.randint(100, 1000)
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
return {
"total_detections": total_detections,
"successful_detections": successful_detections,
"failed_detections": total_detections - successful_detections,
"success_rate": successful_detections / total_detections,
"average_confidence": random.uniform(0.75, 0.90),
"average_processing_time_ms": random.uniform(50, 200),
"unique_persons": random.randint(5, 20),
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
"activity_distribution": {
"standing": random.uniform(0.3, 0.5),
"sitting": random.uniform(0.2, 0.4),
"walking": random.uniform(0.1, 0.3),
"lying": random.uniform(0.0, 0.1)
}
}
except Exception as e:
self.logger.error(f"Error getting statistics: {e}")
raise
async def process_segmentation_data(self, frame_id):
"""Process segmentation data in background."""
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
# Mock background processing
await asyncio.sleep(2)
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
# WebSocket streaming methods
async def get_current_pose_data(self):
"""Get current pose data for streaming."""
try:
# Generate current pose data
result = await self.estimate_poses()
# Format data by zones for WebSocket streaming
zone_data = {}
# Group persons by zone
for person in result["persons"]:
zone_id = person.get("zone_id", "zone_1")
if zone_id not in zone_data:
zone_data[zone_id] = {
"pose": {
"persons": [],
"count": 0
},
"confidence": 0.0,
"activity": None,
"metadata": {
"frame_id": result["frame_id"],
"processing_time_ms": result["processing_time_ms"]
}
}
zone_data[zone_id]["pose"]["persons"].append(person)
zone_data[zone_id]["pose"]["count"] += 1
# Update zone confidence (average)
current_confidence = zone_data[zone_id]["confidence"]
person_confidence = person.get("confidence", 0.0)
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
# Set activity if not already set
if not zone_data[zone_id]["activity"] and person.get("activity"):
zone_data[zone_id]["activity"] = person["activity"]
return zone_data
except Exception as e:
self.logger.error(f"Error getting current pose data: {e}")
# Return empty zone data on error
return {}
# Health check methods
async def health_check(self):
"""Perform health check."""
try:
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
return {
"status": status,
"message": self.last_error if self.last_error else "Service is running normally",
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
"metrics": {
"total_processed": self.stats["total_processed"],
"success_rate": (
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
),
"average_processing_time_ms": self.stats["processing_time_ms"]
}
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}"
}
async def is_ready(self):
"""Check if service is ready."""
return self.is_initialized and self.is_running

View File

@@ -0,0 +1,397 @@
"""
Real-time streaming service for WiFi-DensePose API
"""
import logging
import asyncio
import json
from typing import Dict, List, Optional, Any, Set
from datetime import datetime
from collections import deque
import numpy as np
from fastapi import WebSocket
from src.config.settings import Settings
from src.config.domains import DomainConfig
logger = logging.getLogger(__name__)
class StreamService:
"""Service for real-time data streaming."""
def __init__(self, settings: Settings, domain_config: DomainConfig):
"""Initialize stream service."""
self.settings = settings
self.domain_config = domain_config
self.logger = logging.getLogger(__name__)
# WebSocket connections
self.connections: Set[WebSocket] = set()
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
# Stream buffers
self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)
# Service state
self.is_running = False
self.last_error = None
# Streaming statistics
self.stats = {
"active_connections": 0,
"total_connections": 0,
"messages_sent": 0,
"messages_failed": 0,
"data_points_streamed": 0,
"average_latency_ms": 0.0
}
# Background tasks
self.streaming_task = None
async def initialize(self):
"""Initialize the stream service."""
self.logger.info("Stream service initialized")
async def start(self):
"""Start the stream service."""
if self.is_running:
return
self.is_running = True
self.logger.info("Stream service started")
# Start background streaming task
if self.settings.enable_real_time_processing:
self.streaming_task = asyncio.create_task(self._streaming_loop())
async def stop(self):
"""Stop the stream service."""
self.is_running = False
# Cancel background task
if self.streaming_task:
self.streaming_task.cancel()
try:
await self.streaming_task
except asyncio.CancelledError:
pass
# Close all connections
await self._close_all_connections()
self.logger.info("Stream service stopped")
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
"""Add a new WebSocket connection."""
try:
await websocket.accept()
self.connections.add(websocket)
self.connection_metadata[websocket] = metadata or {}
self.stats["active_connections"] = len(self.connections)
self.stats["total_connections"] += 1
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
# Send initial data if available
await self._send_initial_data(websocket)
except Exception as e:
self.logger.error(f"Error adding WebSocket connection: {e}")
raise
async def remove_connection(self, websocket: WebSocket):
"""Remove a WebSocket connection."""
try:
if websocket in self.connections:
self.connections.remove(websocket)
self.connection_metadata.pop(websocket, None)
self.stats["active_connections"] = len(self.connections)
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
except Exception as e:
self.logger.error(f"Error removing WebSocket connection: {e}")
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
"""Broadcast pose data to all connected clients."""
if not self.is_running:
return
# Add to buffer
self.pose_buffer.append({
"type": "pose_data",
"timestamp": datetime.now().isoformat(),
"data": pose_data
})
# Broadcast to all connections
await self._broadcast_message({
"type": "pose_update",
"timestamp": datetime.now().isoformat(),
"data": pose_data
})
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
"""Broadcast CSI data to all connected clients."""
if not self.is_running:
return
# Convert numpy array to list for JSON serialization
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
# Add to buffer
self.csi_buffer.append({
"type": "csi_data",
"timestamp": datetime.now().isoformat(),
"data": csi_list,
"metadata": metadata
})
# Broadcast to all connections
await self._broadcast_message({
"type": "csi_update",
"timestamp": datetime.now().isoformat(),
"data": csi_list,
"metadata": metadata
})
async def broadcast_system_status(self, status_data: Dict[str, Any]):
"""Broadcast system status to all connected clients."""
if not self.is_running:
return
await self._broadcast_message({
"type": "system_status",
"timestamp": datetime.now().isoformat(),
"data": status_data
})
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
"""Send message to a specific connection."""
try:
if websocket in self.connections:
await websocket.send_text(json.dumps(message))
self.stats["messages_sent"] += 1
except Exception as e:
self.logger.error(f"Error sending message to connection: {e}")
self.stats["messages_failed"] += 1
await self.remove_connection(websocket)
async def _broadcast_message(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients."""
if not self.connections:
return
disconnected = set()
for websocket in self.connections.copy():
try:
await websocket.send_text(json.dumps(message))
self.stats["messages_sent"] += 1
except Exception as e:
self.logger.warning(f"Failed to send message to connection: {e}")
self.stats["messages_failed"] += 1
disconnected.add(websocket)
# Remove disconnected clients
for websocket in disconnected:
await self.remove_connection(websocket)
if message.get("type") in ["pose_update", "csi_update"]:
self.stats["data_points_streamed"] += 1
async def _send_initial_data(self, websocket: WebSocket):
"""Send initial data to a new connection."""
try:
# Send recent pose data
if self.pose_buffer:
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
await self.send_to_connection(websocket, {
"type": "initial_poses",
"timestamp": datetime.now().isoformat(),
"data": recent_poses
})
# Send recent CSI data
if self.csi_buffer:
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
await self.send_to_connection(websocket, {
"type": "initial_csi",
"timestamp": datetime.now().isoformat(),
"data": recent_csi
})
# Send service status
status = await self.get_status()
await self.send_to_connection(websocket, {
"type": "service_status",
"timestamp": datetime.now().isoformat(),
"data": status
})
except Exception as e:
self.logger.error(f"Error sending initial data: {e}")
async def _streaming_loop(self):
"""Background streaming loop for periodic updates."""
try:
while self.is_running:
# Send periodic heartbeat
if self.connections:
await self._broadcast_message({
"type": "heartbeat",
"timestamp": datetime.now().isoformat(),
"active_connections": len(self.connections)
})
# Wait for next iteration
await asyncio.sleep(self.settings.websocket_ping_interval)
except asyncio.CancelledError:
self.logger.info("Streaming loop cancelled")
except Exception as e:
self.logger.error(f"Error in streaming loop: {e}")
self.last_error = str(e)
async def _close_all_connections(self):
"""Close all WebSocket connections."""
disconnected = []
for websocket in self.connections.copy():
try:
await websocket.close()
disconnected.append(websocket)
except Exception as e:
self.logger.warning(f"Error closing connection: {e}")
disconnected.append(websocket)
# Clear all connections
for websocket in disconnected:
await self.remove_connection(websocket)
async def get_status(self) -> Dict[str, Any]:
"""Get service status."""
return {
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
"running": self.is_running,
"last_error": self.last_error,
"connections": {
"active": len(self.connections),
"total": self.stats["total_connections"]
},
"buffers": {
"pose_buffer_size": len(self.pose_buffer),
"csi_buffer_size": len(self.csi_buffer),
"max_buffer_size": self.settings.stream_buffer_size
},
"statistics": self.stats.copy(),
"configuration": {
"stream_fps": self.settings.stream_fps,
"buffer_size": self.settings.stream_buffer_size,
"ping_interval": self.settings.websocket_ping_interval,
"timeout": self.settings.websocket_timeout
}
}
async def get_metrics(self) -> Dict[str, Any]:
"""Get service metrics."""
total_messages = self.stats["messages_sent"] + self.stats["messages_failed"]
success_rate = self.stats["messages_sent"] / max(1, total_messages)
return {
"stream_service": {
"active_connections": self.stats["active_connections"],
"total_connections": self.stats["total_connections"],
"messages_sent": self.stats["messages_sent"],
"messages_failed": self.stats["messages_failed"],
"message_success_rate": success_rate,
"data_points_streamed": self.stats["data_points_streamed"],
"average_latency_ms": self.stats["average_latency_ms"]
}
}
async def get_connection_info(self) -> List[Dict[str, Any]]:
"""Get information about active connections."""
connections_info = []
for websocket in self.connections:
metadata = self.connection_metadata.get(websocket, {})
connection_info = {
"id": id(websocket),
"connected_at": metadata.get("connected_at", "unknown"),
"user_agent": metadata.get("user_agent", "unknown"),
"ip_address": metadata.get("ip_address", "unknown"),
"subscription_types": metadata.get("subscription_types", [])
}
connections_info.append(connection_info)
return connections_info
async def reset(self):
"""Reset service state."""
# Clear buffers
self.pose_buffer.clear()
self.csi_buffer.clear()
# Reset statistics
self.stats = {
"active_connections": len(self.connections),
"total_connections": 0,
"messages_sent": 0,
"messages_failed": 0,
"data_points_streamed": 0,
"average_latency_ms": 0.0
}
self.last_error = None
self.logger.info("Stream service reset")
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
"""Get data from buffers."""
if buffer_type == "pose":
return list(self.pose_buffer)[-limit:]
elif buffer_type == "csi":
return list(self.csi_buffer)[-limit:]
else:
return []
@property
def is_active(self) -> bool:
"""Check if stream service is active."""
return self.is_running
async def health_check(self) -> Dict[str, Any]:
"""Perform health check."""
try:
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
return {
"status": status,
"message": self.last_error if self.last_error else "Stream service is running normally",
"active_connections": len(self.connections),
"metrics": {
"messages_sent": self.stats["messages_sent"],
"messages_failed": self.stats["messages_failed"],
"data_points_streamed": self.stats["data_points_streamed"]
}
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}"
}
async def is_ready(self) -> bool:
"""Check if service is ready."""
return self.is_running

198
test_application.py Normal file
View File

@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""
Test script to verify WiFi-DensePose API functionality
"""
import asyncio
import aiohttp
import json
import websockets
import sys
from typing import Dict, Any
BASE_URL = "http://localhost:8000"
WS_URL = "ws://localhost:8000"
async def test_health_endpoints():
"""Test health check endpoints."""
print("🔍 Testing health endpoints...")
async with aiohttp.ClientSession() as session:
# Test basic health
async with session.get(f"{BASE_URL}/health/health") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Health check: {data['status']}")
else:
print(f"❌ Health check failed: {response.status}")
# Test readiness
async with session.get(f"{BASE_URL}/health/ready") as response:
if response.status == 200:
data = await response.json()
status = "ready" if data['ready'] else "not ready"
print(f"✅ Readiness check: {status}")
else:
print(f"❌ Readiness check failed: {response.status}")
# Test liveness
async with session.get(f"{BASE_URL}/health/live") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Liveness check: {data['status']}")
else:
print(f"❌ Liveness check failed: {response.status}")
async def test_api_endpoints():
"""Test main API endpoints."""
print("\n🔍 Testing API endpoints...")
async with aiohttp.ClientSession() as session:
# Test root endpoint
async with session.get(f"{BASE_URL}/") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Root endpoint: {data['name']} v{data['version']}")
else:
print(f"❌ Root endpoint failed: {response.status}")
# Test API info
async with session.get(f"{BASE_URL}/api/v1/info") as response:
if response.status == 200:
data = await response.json()
print(f"✅ API info: {len(data['services'])} services configured")
else:
print(f"❌ API info failed: {response.status}")
# Test API status
async with session.get(f"{BASE_URL}/api/v1/status") as response:
if response.status == 200:
data = await response.json()
print(f"✅ API status: {data['api']['status']}")
else:
print(f"❌ API status failed: {response.status}")
async def test_pose_endpoints():
"""Test pose estimation endpoints."""
print("\n🔍 Testing pose endpoints...")
async with aiohttp.ClientSession() as session:
# Test current pose data
async with session.get(f"{BASE_URL}/api/v1/pose/current") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Current pose data: {len(data.get('poses', []))} poses detected")
else:
print(f"❌ Current pose data failed: {response.status}")
# Test zones summary
async with session.get(f"{BASE_URL}/api/v1/pose/zones/summary") as response:
if response.status == 200:
data = await response.json()
zones = data.get('zones', {})
print(f"✅ Zones summary: {len(zones)} zones")
for zone_id, zone_data in list(zones.items())[:3]: # Show first 3 zones
print(f" - {zone_id}: {zone_data.get('occupancy', 0)} people")
else:
print(f"❌ Zones summary failed: {response.status}")
# Test pose stats
async with session.get(f"{BASE_URL}/api/v1/pose/stats") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Pose stats: {data.get('total_detections', 0)} total detections")
else:
print(f"❌ Pose stats failed: {response.status}")
async def test_stream_endpoints():
"""Test streaming endpoints."""
print("\n🔍 Testing stream endpoints...")
async with aiohttp.ClientSession() as session:
# Test stream status
async with session.get(f"{BASE_URL}/api/v1/stream/status") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Stream status: {'Active' if data['is_active'] else 'Inactive'}")
print(f" - Connected clients: {data['connected_clients']}")
else:
print(f"❌ Stream status failed: {response.status}")
# Test stream metrics
async with session.get(f"{BASE_URL}/api/v1/stream/metrics") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Stream metrics available")
else:
print(f"❌ Stream metrics failed: {response.status}")
async def test_websocket_connection():
"""Test WebSocket connection."""
print("\n🔍 Testing WebSocket connection...")
try:
uri = f"{WS_URL}/api/v1/stream/pose"
async with websockets.connect(uri) as websocket:
print("✅ WebSocket connected successfully")
# Wait for connection confirmation
message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
data = json.loads(message)
if data.get("type") == "connection_established":
print(f"✅ Connection established with client ID: {data.get('client_id')}")
# Send a ping
await websocket.send(json.dumps({"type": "ping"}))
# Wait for pong
pong_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
pong_data = json.loads(pong_message)
if pong_data.get("type") == "pong":
print("✅ WebSocket ping/pong successful")
else:
print(f"❌ Unexpected pong response: {pong_data}")
else:
print(f"❌ Unexpected connection message: {data}")
except asyncio.TimeoutError:
print("❌ WebSocket connection timeout")
except Exception as e:
print(f"❌ WebSocket connection failed: {e}")
async def test_calibration_endpoints():
"""Test calibration endpoints."""
print("\n🔍 Testing calibration endpoints...")
async with aiohttp.ClientSession() as session:
# Test calibration status
async with session.get(f"{BASE_URL}/api/v1/pose/calibration/status") as response:
if response.status == 200:
data = await response.json()
print(f"✅ Calibration status: {data.get('status', 'unknown')}")
else:
print(f"❌ Calibration status failed: {response.status}")
async def main():
"""Run all tests."""
print("🚀 Starting WiFi-DensePose API Tests")
print("=" * 50)
try:
await test_health_endpoints()
await test_api_endpoints()
await test_pose_endpoints()
await test_stream_endpoints()
await test_websocket_connection()
await test_calibration_endpoints()
print("\n" + "=" * 50)
print("✅ All tests completed!")
except Exception as e:
print(f"\n❌ Test suite failed: {e}")
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())