feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
266
v1/src/__init__.py
Normal file
266
v1/src/__init__.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
WiFi-DensePose API Package
|
||||
==========================
|
||||
|
||||
A comprehensive system for WiFi-based human pose estimation using CSI data
|
||||
and DensePose neural networks.
|
||||
|
||||
This package provides:
|
||||
- Real-time CSI data collection from WiFi routers
|
||||
- Advanced signal processing and phase sanitization
|
||||
- DensePose neural network integration for pose estimation
|
||||
- RESTful API for data access and control
|
||||
- Background task management for data processing
|
||||
- Comprehensive monitoring and logging
|
||||
|
||||
Example usage:
|
||||
>>> from src.app import app
|
||||
>>> from src.config.settings import get_settings
|
||||
>>>
|
||||
>>> settings = get_settings()
|
||||
>>> # Run with: uvicorn src.app:app --host 0.0.0.0 --port 8000
|
||||
|
||||
For CLI usage:
|
||||
$ wifi-densepose start --host 0.0.0.0 --port 8000
|
||||
$ wifi-densepose status
|
||||
$ wifi-densepose stop
|
||||
|
||||
Author: WiFi-DensePose Team
|
||||
License: MIT
|
||||
"""
|
||||
|
||||
__version__ = "1.1.0"
|
||||
__author__ = "WiFi-DensePose Team"
|
||||
__email__ = "team@wifi-densepose.com"
|
||||
__license__ = "MIT"
|
||||
__copyright__ = "Copyright 2024 WiFi-DensePose Team"
|
||||
|
||||
# Package metadata
|
||||
__title__ = "wifi-densepose"
|
||||
__description__ = "WiFi-based human pose estimation using CSI data and DensePose neural networks"
|
||||
__url__ = "https://github.com/wifi-densepose/wifi-densepose"
|
||||
__download_url__ = "https://github.com/wifi-densepose/wifi-densepose/archive/main.zip"
|
||||
|
||||
# Version info tuple
|
||||
__version_info__ = tuple(int(x) for x in __version__.split('.'))
|
||||
|
||||
# Import key components for easy access
|
||||
try:
|
||||
from src.app import app
|
||||
from src.config.settings import get_settings, Settings
|
||||
from src.logger import setup_logging, get_logger
|
||||
|
||||
# Core components
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.core.pose_estimator import PoseEstimator
|
||||
from src.core.router_interface import RouterInterface
|
||||
|
||||
# Services
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.services.health_check import HealthCheckService
|
||||
from src.services.metrics import MetricsService
|
||||
|
||||
# Database
|
||||
from src.database.connection import get_database_manager
|
||||
from src.database.models import (
|
||||
Device, Session, CSIData, PoseDetection,
|
||||
SystemMetric, AuditLog
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core app
|
||||
'app',
|
||||
'get_settings',
|
||||
'Settings',
|
||||
'setup_logging',
|
||||
'get_logger',
|
||||
|
||||
# Core processing
|
||||
'CSIProcessor',
|
||||
'PhaseSanitizer',
|
||||
'PoseEstimator',
|
||||
'RouterInterface',
|
||||
|
||||
# Services
|
||||
'ServiceOrchestrator',
|
||||
'HealthCheckService',
|
||||
'MetricsService',
|
||||
|
||||
# Database
|
||||
'get_database_manager',
|
||||
'Device',
|
||||
'Session',
|
||||
'CSIData',
|
||||
'PoseDetection',
|
||||
'SystemMetric',
|
||||
'AuditLog',
|
||||
|
||||
# Metadata
|
||||
'__version__',
|
||||
'__version_info__',
|
||||
'__author__',
|
||||
'__email__',
|
||||
'__license__',
|
||||
'__copyright__',
|
||||
]
|
||||
|
||||
except ImportError as e:
|
||||
# Handle import errors gracefully during package installation
|
||||
import warnings
|
||||
warnings.warn(
|
||||
f"Some components could not be imported: {e}. "
|
||||
"This is normal during package installation.",
|
||||
ImportWarning
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'__version__',
|
||||
'__version_info__',
|
||||
'__author__',
|
||||
'__email__',
|
||||
'__license__',
|
||||
'__copyright__',
|
||||
]
|
||||
|
||||
|
||||
def get_version():
|
||||
"""Get the package version."""
|
||||
return __version__
|
||||
|
||||
|
||||
def get_version_info():
|
||||
"""Get the package version as a tuple."""
|
||||
return __version_info__
|
||||
|
||||
|
||||
def get_package_info():
|
||||
"""Get comprehensive package information."""
|
||||
return {
|
||||
'name': __title__,
|
||||
'version': __version__,
|
||||
'version_info': __version_info__,
|
||||
'description': __description__,
|
||||
'author': __author__,
|
||||
'author_email': __email__,
|
||||
'license': __license__,
|
||||
'copyright': __copyright__,
|
||||
'url': __url__,
|
||||
'download_url': __download_url__,
|
||||
}
|
||||
|
||||
|
||||
def check_dependencies():
|
||||
"""Check if all required dependencies are available."""
|
||||
missing_deps = []
|
||||
optional_deps = []
|
||||
|
||||
# Core dependencies
|
||||
required_modules = [
|
||||
('fastapi', 'FastAPI'),
|
||||
('uvicorn', 'Uvicorn'),
|
||||
('pydantic', 'Pydantic'),
|
||||
('sqlalchemy', 'SQLAlchemy'),
|
||||
('numpy', 'NumPy'),
|
||||
('torch', 'PyTorch'),
|
||||
('cv2', 'OpenCV'),
|
||||
('scipy', 'SciPy'),
|
||||
('pandas', 'Pandas'),
|
||||
('redis', 'Redis'),
|
||||
('psutil', 'psutil'),
|
||||
('click', 'Click'),
|
||||
]
|
||||
|
||||
for module_name, display_name in required_modules:
|
||||
try:
|
||||
__import__(module_name)
|
||||
except ImportError:
|
||||
missing_deps.append(display_name)
|
||||
|
||||
# Optional dependencies
|
||||
optional_modules = [
|
||||
('scapy', 'Scapy (for network packet capture)'),
|
||||
('paramiko', 'Paramiko (for SSH connections)'),
|
||||
('serial', 'PySerial (for serial communication)'),
|
||||
('matplotlib', 'Matplotlib (for plotting)'),
|
||||
('prometheus_client', 'Prometheus Client (for metrics)'),
|
||||
]
|
||||
|
||||
for module_name, display_name in optional_modules:
|
||||
try:
|
||||
__import__(module_name)
|
||||
except ImportError:
|
||||
optional_deps.append(display_name)
|
||||
|
||||
return {
|
||||
'missing_required': missing_deps,
|
||||
'missing_optional': optional_deps,
|
||||
'all_required_available': len(missing_deps) == 0,
|
||||
}
|
||||
|
||||
|
||||
def print_system_info():
|
||||
"""Print system and package information."""
|
||||
import sys
|
||||
import platform
|
||||
|
||||
info = get_package_info()
|
||||
deps = check_dependencies()
|
||||
|
||||
print(f"WiFi-DensePose v{info['version']}")
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Platform: {platform.platform()}")
|
||||
print(f"Architecture: {platform.architecture()[0]}")
|
||||
print()
|
||||
|
||||
if deps['all_required_available']:
|
||||
print("✅ All required dependencies are available")
|
||||
else:
|
||||
print("❌ Missing required dependencies:")
|
||||
for dep in deps['missing_required']:
|
||||
print(f" - {dep}")
|
||||
|
||||
if deps['missing_optional']:
|
||||
print("\n⚠️ Missing optional dependencies:")
|
||||
for dep in deps['missing_optional']:
|
||||
print(f" - {dep}")
|
||||
|
||||
print(f"\nFor more information, visit: {info['url']}")
|
||||
|
||||
|
||||
# Package-level configuration
|
||||
import logging
|
||||
|
||||
# Set up basic logging configuration
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
# Suppress some noisy third-party loggers
|
||||
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
||||
logging.getLogger('requests').setLevel(logging.WARNING)
|
||||
logging.getLogger('asyncio').setLevel(logging.WARNING)
|
||||
|
||||
# Package initialization message
|
||||
if __name__ != '__main__':
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.debug(f"WiFi-DensePose package v{__version__} initialized")
|
||||
|
||||
|
||||
# Compatibility aliases for backward compatibility
|
||||
try:
|
||||
WifiDensePose = app # Legacy alias
|
||||
except NameError:
|
||||
WifiDensePose = None # Will be None if app import failed
|
||||
|
||||
try:
|
||||
get_config = get_settings # Legacy alias
|
||||
except NameError:
|
||||
get_config = None # Will be None if get_settings import failed
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the package when run as a module."""
|
||||
print_system_info()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
7
v1/src/api/__init__.py
Normal file
7
v1/src/api/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
WiFi-DensePose FastAPI application package
|
||||
"""
|
||||
|
||||
# API package - routers and dependencies are imported by app.py
|
||||
|
||||
__all__ = []
|
||||
447
v1/src/api/dependencies.py
Normal file
447
v1/src/api/dependencies.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
Dependency injection for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.config.domains import get_domain_config
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.stream_service import StreamService
|
||||
from src.services.hardware_service import HardwareService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Security scheme for JWT authentication
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# Service dependencies
|
||||
@lru_cache()
|
||||
def get_pose_service() -> PoseService:
|
||||
"""Get pose service instance."""
|
||||
settings = get_settings()
|
||||
domain_config = get_domain_config()
|
||||
|
||||
return PoseService(
|
||||
settings=settings,
|
||||
domain_config=domain_config
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_stream_service() -> StreamService:
|
||||
"""Get stream service instance."""
|
||||
settings = get_settings()
|
||||
domain_config = get_domain_config()
|
||||
|
||||
return StreamService(
|
||||
settings=settings,
|
||||
domain_config=domain_config
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_hardware_service() -> HardwareService:
|
||||
"""Get hardware service instance."""
|
||||
settings = get_settings()
|
||||
domain_config = get_domain_config()
|
||||
|
||||
return HardwareService(
|
||||
settings=settings,
|
||||
domain_config=domain_config
|
||||
)
|
||||
|
||||
|
||||
# Authentication dependencies
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get current authenticated user."""
|
||||
settings = get_settings()
|
||||
|
||||
# Skip authentication if disabled
|
||||
if not settings.enable_authentication:
|
||||
return None
|
||||
|
||||
# Check if user is already set by middleware
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return request.state.user
|
||||
|
||||
# No credentials provided
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
# This would normally validate the JWT token
|
||||
# For now, return a mock user for development
|
||||
if settings.is_development:
|
||||
return {
|
||||
"id": "dev-user",
|
||||
"username": "developer",
|
||||
"email": "dev@example.com",
|
||||
"is_admin": True,
|
||||
"permissions": ["read", "write", "admin"]
|
||||
}
|
||||
|
||||
# In production, implement proper JWT validation
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication not implemented",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current active user (required authentication)."""
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not current_user.get("is_active", True):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_admin_user(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current admin user (admin privileges required)."""
|
||||
if not current_user.get("is_admin", False):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin privileges required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
# Permission dependencies
|
||||
def require_permission(permission: str):
|
||||
"""Dependency factory for permission checking."""
|
||||
|
||||
async def check_permission(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Check if user has required permission."""
|
||||
user_permissions = current_user.get("permissions", [])
|
||||
|
||||
# Admin users have all permissions
|
||||
if current_user.get("is_admin", False):
|
||||
return current_user
|
||||
|
||||
# Check specific permission
|
||||
if permission not in user_permissions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Permission '{permission}' required"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
return check_permission
|
||||
|
||||
|
||||
# Zone access dependencies
|
||||
async def validate_zone_access(
|
||||
zone_id: str,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""Validate user access to a specific zone."""
|
||||
domain_config = get_domain_config()
|
||||
|
||||
# Check if zone exists
|
||||
zone = domain_config.get_zone(zone_id)
|
||||
if not zone:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Zone '{zone_id}' not found"
|
||||
)
|
||||
|
||||
# Check if zone is enabled
|
||||
if not zone.enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Zone '{zone_id}' is disabled"
|
||||
)
|
||||
|
||||
# If authentication is enabled, check user access
|
||||
if current_user:
|
||||
# Admin users have access to all zones
|
||||
if current_user.get("is_admin", False):
|
||||
return zone_id
|
||||
|
||||
# Check user's zone permissions
|
||||
user_zones = current_user.get("zones", [])
|
||||
if user_zones and zone_id not in user_zones:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Access denied to zone '{zone_id}'"
|
||||
)
|
||||
|
||||
return zone_id
|
||||
|
||||
|
||||
# Router access dependencies
|
||||
async def validate_router_access(
|
||||
router_id: str,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
) -> str:
|
||||
"""Validate user access to a specific router."""
|
||||
domain_config = get_domain_config()
|
||||
|
||||
# Check if router exists
|
||||
router = domain_config.get_router(router_id)
|
||||
if not router:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Router '{router_id}' not found"
|
||||
)
|
||||
|
||||
# Check if router is enabled
|
||||
if not router.enabled:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Router '{router_id}' is disabled"
|
||||
)
|
||||
|
||||
# If authentication is enabled, check user access
|
||||
if current_user:
|
||||
# Admin users have access to all routers
|
||||
if current_user.get("is_admin", False):
|
||||
return router_id
|
||||
|
||||
# Check user's router permissions
|
||||
user_routers = current_user.get("routers", [])
|
||||
if user_routers and router_id not in user_routers:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Access denied to router '{router_id}'"
|
||||
)
|
||||
|
||||
return router_id
|
||||
|
||||
|
||||
# Service health dependencies
|
||||
async def check_service_health(
|
||||
request: Request,
|
||||
service_name: str
|
||||
) -> bool:
|
||||
"""Check if a service is healthy."""
|
||||
try:
|
||||
if service_name == "pose":
|
||||
service = getattr(request.app.state, 'pose_service', None)
|
||||
elif service_name == "stream":
|
||||
service = getattr(request.app.state, 'stream_service', None)
|
||||
elif service_name == "hardware":
|
||||
service = getattr(request.app.state, 'hardware_service', None)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown service: {service_name}"
|
||||
)
|
||||
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Service '{service_name}' not available"
|
||||
)
|
||||
|
||||
# Check service health
|
||||
status_info = await service.get_status()
|
||||
if status_info.get("status") != "healthy":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Service '{service_name}' is unhealthy: {status_info.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking service health for {service_name}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Service '{service_name}' health check failed"
|
||||
)
|
||||
|
||||
|
||||
# Rate limiting dependencies
|
||||
async def check_rate_limit(
|
||||
request: Request,
|
||||
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
|
||||
) -> bool:
|
||||
"""Check rate limiting status."""
|
||||
settings = get_settings()
|
||||
|
||||
# Skip if rate limiting is disabled
|
||||
if not settings.enable_rate_limiting:
|
||||
return True
|
||||
|
||||
# Rate limiting is handled by middleware
|
||||
# This dependency can be used for additional checks
|
||||
return True
|
||||
|
||||
|
||||
# Configuration dependencies
|
||||
def get_zone_config(zone_id: str = Depends(validate_zone_access)):
|
||||
"""Get zone configuration."""
|
||||
domain_config = get_domain_config()
|
||||
return domain_config.get_zone(zone_id)
|
||||
|
||||
|
||||
def get_router_config(router_id: str = Depends(validate_router_access)):
|
||||
"""Get router configuration."""
|
||||
domain_config = get_domain_config()
|
||||
return domain_config.get_router(router_id)
|
||||
|
||||
|
||||
# Pagination dependencies
|
||||
class PaginationParams:
|
||||
"""Pagination parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
page: int = 1,
|
||||
size: int = 20,
|
||||
max_size: int = 100
|
||||
):
|
||||
if page < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Page must be >= 1"
|
||||
)
|
||||
|
||||
if size < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Size must be >= 1"
|
||||
)
|
||||
|
||||
if size > max_size:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Size must be <= {max_size}"
|
||||
)
|
||||
|
||||
self.page = page
|
||||
self.size = size
|
||||
self.offset = (page - 1) * size
|
||||
self.limit = size
|
||||
|
||||
|
||||
def get_pagination_params(
|
||||
page: int = 1,
|
||||
size: int = 20
|
||||
) -> PaginationParams:
|
||||
"""Get pagination parameters."""
|
||||
return PaginationParams(page=page, size=size)
|
||||
|
||||
|
||||
# Query filter dependencies
|
||||
class QueryFilters:
|
||||
"""Common query filters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
min_confidence: Optional[float] = None,
|
||||
activity: Optional[str] = None
|
||||
):
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.min_confidence = min_confidence
|
||||
self.activity = activity
|
||||
|
||||
# Validate confidence
|
||||
if min_confidence is not None:
|
||||
if not 0.0 <= min_confidence <= 1.0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="min_confidence must be between 0.0 and 1.0"
|
||||
)
|
||||
|
||||
|
||||
def get_query_filters(
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
min_confidence: Optional[float] = None,
|
||||
activity: Optional[str] = None
|
||||
) -> QueryFilters:
|
||||
"""Get query filters."""
|
||||
return QueryFilters(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
min_confidence=min_confidence,
|
||||
activity=activity
|
||||
)
|
||||
|
||||
|
||||
# WebSocket dependencies
|
||||
async def get_websocket_user(
|
||||
websocket_token: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get user from WebSocket token."""
|
||||
settings = get_settings()
|
||||
|
||||
# Skip authentication if disabled
|
||||
if not settings.enable_authentication:
|
||||
return None
|
||||
|
||||
# For development, return mock user
|
||||
if settings.is_development:
|
||||
return {
|
||||
"id": "ws-user",
|
||||
"username": "websocket_user",
|
||||
"is_admin": False,
|
||||
"permissions": ["read"]
|
||||
}
|
||||
|
||||
# In production, implement proper token validation
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user_ws(
|
||||
websocket_token: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get current user for WebSocket connections."""
|
||||
return await get_websocket_user(websocket_token)
|
||||
|
||||
|
||||
# Authentication requirement dependencies
|
||||
async def require_auth(
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Require authentication for endpoint access."""
|
||||
return current_user
|
||||
|
||||
|
||||
# Development dependencies
|
||||
async def development_only():
|
||||
"""Dependency that only allows access in development."""
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.is_development:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Endpoint not available in production"
|
||||
)
|
||||
|
||||
return True
|
||||
421
v1/src/api/main.py
Normal file
421
v1/src/api/main.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
FastAPI application for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.config.domains import get_domain_config
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.api.websocket.pose_stream import PoseStreamHandler
|
||||
|
||||
# Configure logging
|
||||
settings = get_settings()
|
||||
logging.config.dictConfig(settings.get_logging_config())
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting WiFi-DensePose API...")
|
||||
|
||||
try:
|
||||
# Initialize services
|
||||
await initialize_services(app)
|
||||
|
||||
# Start background tasks
|
||||
await start_background_tasks(app)
|
||||
|
||||
logger.info("WiFi-DensePose API started successfully")
|
||||
|
||||
yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start application: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down WiFi-DensePose API...")
|
||||
await cleanup_services(app)
|
||||
logger.info("WiFi-DensePose API shutdown complete")
|
||||
|
||||
|
||||
async def initialize_services(app: FastAPI):
|
||||
"""Initialize application services."""
|
||||
try:
|
||||
# Initialize hardware service
|
||||
hardware_service = get_hardware_service()
|
||||
await hardware_service.initialize()
|
||||
|
||||
# Initialize pose service
|
||||
pose_service = get_pose_service()
|
||||
await pose_service.initialize()
|
||||
|
||||
# Initialize stream service
|
||||
stream_service = get_stream_service()
|
||||
await stream_service.initialize()
|
||||
|
||||
# Initialize pose stream handler
|
||||
pose_stream_handler = PoseStreamHandler(
|
||||
connection_manager=connection_manager,
|
||||
pose_service=pose_service,
|
||||
stream_service=stream_service
|
||||
)
|
||||
|
||||
# Store in app state for access in routes
|
||||
app.state.hardware_service = hardware_service
|
||||
app.state.pose_service = pose_service
|
||||
app.state.stream_service = stream_service
|
||||
app.state.pose_stream_handler = pose_stream_handler
|
||||
|
||||
logger.info("Services initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize services: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def start_background_tasks(app: FastAPI):
|
||||
"""Start background tasks."""
|
||||
try:
|
||||
# Start pose service
|
||||
pose_service = app.state.pose_service
|
||||
await pose_service.start()
|
||||
logger.info("Pose service started")
|
||||
|
||||
# Start pose streaming if enabled
|
||||
if settings.enable_real_time_processing:
|
||||
pose_stream_handler = app.state.pose_stream_handler
|
||||
await pose_stream_handler.start_streaming()
|
||||
|
||||
logger.info("Background tasks started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start background tasks: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def cleanup_services(app: FastAPI):
|
||||
"""Cleanup services on shutdown."""
|
||||
try:
|
||||
# Stop pose streaming
|
||||
if hasattr(app.state, 'pose_stream_handler'):
|
||||
await app.state.pose_stream_handler.shutdown()
|
||||
|
||||
# Shutdown connection manager
|
||||
await connection_manager.shutdown()
|
||||
|
||||
# Cleanup services
|
||||
if hasattr(app.state, 'stream_service'):
|
||||
await app.state.stream_service.shutdown()
|
||||
|
||||
if hasattr(app.state, 'pose_service'):
|
||||
await app.state.pose_service.stop()
|
||||
|
||||
if hasattr(app.state, 'hardware_service'):
|
||||
await app.state.hardware_service.shutdown()
|
||||
|
||||
logger.info("Services cleaned up successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.version,
|
||||
description="WiFi-based human pose estimation and activity recognition API",
|
||||
docs_url=settings.docs_url if not settings.is_production else None,
|
||||
redoc_url=settings.redoc_url if not settings.is_production else None,
|
||||
openapi_url=settings.openapi_url if not settings.is_production else None,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add middleware
|
||||
if settings.enable_rate_limiting:
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
if settings.enable_authentication:
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# Add CORS middleware
|
||||
cors_config = settings.get_cors_config()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
**cors_config
|
||||
)
|
||||
|
||||
# Add trusted host middleware for production
|
||||
if settings.is_production:
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=settings.allowed_hosts
|
||||
)
|
||||
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
"""Handle HTTP exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": exc.status_code,
|
||||
"message": exc.detail,
|
||||
"type": "http_error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle request validation errors."""
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"error": {
|
||||
"code": 422,
|
||||
"message": "Validation error",
|
||||
"type": "validation_error",
|
||||
"details": exc.errors()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle general exceptions."""
|
||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": 500,
|
||||
"message": "Internal server error",
|
||||
"type": "internal_error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Middleware for request logging
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
"""Log all requests."""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate processing time
|
||||
process_time = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Status: {response.status_code} - "
|
||||
f"Time: {process_time:.3f}s"
|
||||
)
|
||||
|
||||
# Add processing time header
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(
|
||||
health.router,
|
||||
prefix="/health",
|
||||
tags=["Health"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
pose.router,
|
||||
prefix=f"{settings.api_prefix}/pose",
|
||||
tags=["Pose Estimation"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
stream.router,
|
||||
prefix=f"{settings.api_prefix}/stream",
|
||||
tags=["Streaming"]
|
||||
)
|
||||
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"docs_url": settings.docs_url,
|
||||
"api_prefix": settings.api_prefix,
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"real_time_processing": settings.enable_real_time_processing
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# API information endpoint
|
||||
@app.get(f"{settings.api_prefix}/info")
|
||||
async def api_info():
|
||||
"""Get detailed API information."""
|
||||
domain_config = get_domain_config()
|
||||
|
||||
return {
|
||||
"api": {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"prefix": settings.api_prefix
|
||||
},
|
||||
"configuration": {
|
||||
"zones": len(domain_config.zones),
|
||||
"routers": len(domain_config.routers),
|
||||
"pose_models": len(domain_config.pose_models)
|
||||
},
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"real_time_processing": settings.enable_real_time_processing,
|
||||
"historical_data": settings.enable_historical_data
|
||||
},
|
||||
"limits": {
|
||||
"rate_limit_requests": settings.rate_limit_requests,
|
||||
"rate_limit_window": settings.rate_limit_window,
|
||||
"max_websocket_connections": domain_config.streaming.max_connections
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Status endpoint
|
||||
@app.get(f"{settings.api_prefix}/status")
|
||||
async def api_status(request: Request):
|
||||
"""Get current API status."""
|
||||
try:
|
||||
# Get services from app state
|
||||
hardware_service = getattr(request.app.state, 'hardware_service', None)
|
||||
pose_service = getattr(request.app.state, 'pose_service', None)
|
||||
stream_service = getattr(request.app.state, 'stream_service', None)
|
||||
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
|
||||
|
||||
# Get service statuses
|
||||
status = {
|
||||
"api": {
|
||||
"status": "healthy",
|
||||
"uptime": "unknown",
|
||||
"version": settings.version
|
||||
},
|
||||
"services": {
|
||||
"hardware": await hardware_service.get_status() if hardware_service else {"status": "unavailable"},
|
||||
"pose": await pose_service.get_status() if pose_service else {"status": "unavailable"},
|
||||
"stream": await stream_service.get_status() if stream_service else {"status": "unavailable"}
|
||||
},
|
||||
"streaming": pose_stream_handler.get_stream_status() if pose_stream_handler else {"is_streaming": False},
|
||||
"connections": await connection_manager.get_connection_stats()
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API status: {e}")
|
||||
return {
|
||||
"api": {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Metrics endpoint (if enabled)
|
||||
if settings.metrics_enabled:
|
||||
@app.get(f"{settings.api_prefix}/metrics")
|
||||
async def api_metrics(request: Request):
|
||||
"""Get API metrics."""
|
||||
try:
|
||||
# Get services from app state
|
||||
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
|
||||
|
||||
metrics = {
|
||||
"connections": await connection_manager.get_metrics(),
|
||||
"streaming": await pose_stream_handler.get_performance_metrics() if pose_stream_handler else {}
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting metrics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Development endpoints (only in development)
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
domain_config = get_domain_config()
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"domain_config": domain_config.to_dict()
|
||||
}
|
||||
|
||||
@app.post(f"{settings.api_prefix}/dev/reset")
|
||||
async def dev_reset(request: Request):
|
||||
"""Reset services (development only)."""
|
||||
try:
|
||||
# Reset services
|
||||
hardware_service = getattr(request.app.state, 'hardware_service', None)
|
||||
pose_service = getattr(request.app.state, 'pose_service', None)
|
||||
|
||||
if hardware_service:
|
||||
await hardware_service.reset()
|
||||
|
||||
if pose_service:
|
||||
await pose_service.reset()
|
||||
|
||||
return {"message": "Services reset successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting services: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"src.api.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.reload,
|
||||
workers=settings.workers if not settings.reload else 1,
|
||||
log_level=settings.log_level.lower()
|
||||
)
|
||||
8
v1/src/api/middleware/__init__.py
Normal file
8
v1/src/api/middleware/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
FastAPI middleware package
|
||||
"""
|
||||
|
||||
from .auth import AuthMiddleware
|
||||
from .rate_limit import RateLimitMiddleware
|
||||
|
||||
__all__ = ["AuthMiddleware", "RateLimitMiddleware"]
|
||||
322
v1/src/api/middleware/auth.py
Normal file
322
v1/src/api/middleware/auth.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
JWT Authentication middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT Authentication middleware."""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self.settings = get_settings()
|
||||
|
||||
# Paths that don't require authentication
|
||||
self.public_paths = {
|
||||
"/",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/health",
|
||||
"/ready",
|
||||
"/live",
|
||||
"/version",
|
||||
"/metrics"
|
||||
}
|
||||
|
||||
# Paths that require authentication
|
||||
self.protected_paths = {
|
||||
"/api/v1/pose/analyze",
|
||||
"/api/v1/pose/calibrate",
|
||||
"/api/v1/pose/historical",
|
||||
"/api/v1/stream/start",
|
||||
"/api/v1/stream/stop",
|
||||
"/api/v1/stream/clients",
|
||||
"/api/v1/stream/broadcast"
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process request through authentication middleware."""
|
||||
|
||||
# Skip authentication for public paths
|
||||
if self._is_public_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract and validate token
|
||||
token = self._extract_token(request)
|
||||
|
||||
if token:
|
||||
try:
|
||||
# Verify token and add user info to request state
|
||||
user_data = await self._verify_token(token)
|
||||
request.state.user = user_data
|
||||
request.state.authenticated = True
|
||||
|
||||
logger.debug(f"Authenticated user: {user_data.get('id')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Token validation failed: {e}")
|
||||
|
||||
# For protected paths, return 401
|
||||
if self._is_protected_path(request.url.path):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": {
|
||||
"code": 401,
|
||||
"message": "Invalid or expired token",
|
||||
"type": "authentication_error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# For other paths, continue without authentication
|
||||
request.state.user = None
|
||||
request.state.authenticated = False
|
||||
else:
|
||||
# No token provided
|
||||
if self._is_protected_path(request.url.path):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": {
|
||||
"code": 401,
|
||||
"message": "Authentication required",
|
||||
"type": "authentication_error"
|
||||
}
|
||||
},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
request.state.user = None
|
||||
request.state.authenticated = False
|
||||
|
||||
# Continue with request processing
|
||||
response = await call_next(request)
|
||||
|
||||
# Add authentication headers to response
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
response.headers["X-User-ID"] = request.state.user.get("id", "")
|
||||
response.headers["X-Authenticated"] = "true"
|
||||
else:
|
||||
response.headers["X-Authenticated"] = "false"
|
||||
|
||||
return response
|
||||
|
||||
def _is_public_path(self, path: str) -> bool:
|
||||
"""Check if path is public (doesn't require authentication)."""
|
||||
# Exact match
|
||||
if path in self.public_paths:
|
||||
return True
|
||||
|
||||
# Pattern matching for public paths
|
||||
public_patterns = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/pose/current", # Allow anonymous access to current pose data
|
||||
"/api/v1/pose/zones/", # Allow anonymous access to zone data
|
||||
"/api/v1/pose/activities", # Allow anonymous access to activities
|
||||
"/api/v1/pose/stats", # Allow anonymous access to stats
|
||||
"/api/v1/stream/status" # Allow anonymous access to stream status
|
||||
]
|
||||
|
||||
for pattern in public_patterns:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_protected_path(self, path: str) -> bool:
|
||||
"""Check if path requires authentication."""
|
||||
# Exact match
|
||||
if path in self.protected_paths:
|
||||
return True
|
||||
|
||||
# Pattern matching for protected paths
|
||||
protected_patterns = [
|
||||
"/api/v1/pose/analyze",
|
||||
"/api/v1/pose/calibrate",
|
||||
"/api/v1/pose/historical",
|
||||
"/api/v1/stream/start",
|
||||
"/api/v1/stream/stop",
|
||||
"/api/v1/stream/clients",
|
||||
"/api/v1/stream/broadcast"
|
||||
]
|
||||
|
||||
for pattern in protected_patterns:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""Extract JWT token from request."""
|
||||
# Check Authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
return auth_header.split(" ")[1]
|
||||
|
||||
# Check query parameter (for WebSocket connections)
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
return token
|
||||
|
||||
# Check cookie
|
||||
token = request.cookies.get("access_token")
|
||||
if token:
|
||||
return token
|
||||
|
||||
return None
|
||||
|
||||
async def _verify_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Verify JWT token and return user data."""
|
||||
try:
|
||||
# Decode JWT token
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
self.settings.secret_key,
|
||||
algorithms=[self.settings.jwt_algorithm]
|
||||
)
|
||||
|
||||
# Extract user information
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("Token missing user ID")
|
||||
|
||||
# Check token expiration
|
||||
exp = payload.get("exp")
|
||||
if exp and datetime.utcnow() > datetime.fromtimestamp(exp):
|
||||
raise ValueError("Token expired")
|
||||
|
||||
# Build user object
|
||||
user_data = {
|
||||
"id": user_id,
|
||||
"username": payload.get("username"),
|
||||
"email": payload.get("email"),
|
||||
"is_admin": payload.get("is_admin", False),
|
||||
"permissions": payload.get("permissions", []),
|
||||
"accessible_zones": payload.get("accessible_zones", []),
|
||||
"token_issued_at": payload.get("iat"),
|
||||
"token_expires_at": payload.get("exp"),
|
||||
"session_id": payload.get("session_id")
|
||||
}
|
||||
|
||||
return user_data
|
||||
|
||||
except JWTError as e:
|
||||
raise ValueError(f"JWT validation failed: {e}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Token verification error: {e}")
|
||||
|
||||
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
|
||||
"""Log authentication events for security monitoring."""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
log_data = {
|
||||
"event_type": event_type,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"path": request.url.path,
|
||||
"method": request.method
|
||||
}
|
||||
|
||||
if details:
|
||||
log_data.update(details)
|
||||
|
||||
if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
|
||||
logger.warning(f"Auth event: {log_data}")
|
||||
else:
|
||||
logger.info(f"Auth event: {log_data}")
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
"""Simple in-memory token blacklist for logout functionality."""
|
||||
|
||||
def __init__(self):
|
||||
self._blacklisted_tokens = set()
|
||||
self._cleanup_interval = 3600 # 1 hour
|
||||
self._last_cleanup = datetime.utcnow()
|
||||
|
||||
def add_token(self, token: str):
|
||||
"""Add token to blacklist."""
|
||||
self._blacklisted_tokens.add(token)
|
||||
self._cleanup_if_needed()
|
||||
|
||||
def is_blacklisted(self, token: str) -> bool:
|
||||
"""Check if token is blacklisted."""
|
||||
self._cleanup_if_needed()
|
||||
return token in self._blacklisted_tokens
|
||||
|
||||
def _cleanup_if_needed(self):
|
||||
"""Clean up expired tokens from blacklist."""
|
||||
now = datetime.utcnow()
|
||||
if (now - self._last_cleanup).total_seconds() > self._cleanup_interval:
|
||||
# In a real implementation, you would check token expiration
|
||||
# For now, we'll just clear old tokens periodically
|
||||
self._blacklisted_tokens.clear()
|
||||
self._last_cleanup = now
|
||||
|
||||
|
||||
# Global token blacklist instance
|
||||
token_blacklist = TokenBlacklist()
|
||||
|
||||
|
||||
class SecurityHeaders:
|
||||
"""Security headers for API responses."""
|
||||
|
||||
@staticmethod
|
||||
def add_security_headers(response: Response) -> Response:
|
||||
"""Add security headers to response."""
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data:; "
|
||||
"connect-src 'self' ws: wss:;"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class APIKeyAuth:
|
||||
"""Alternative API key authentication for service-to-service communication."""
|
||||
|
||||
def __init__(self, api_keys: Dict[str, Dict[str, Any]] = None):
|
||||
self.api_keys = api_keys or {}
|
||||
|
||||
def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Verify API key and return associated service info."""
|
||||
if api_key in self.api_keys:
|
||||
return self.api_keys[api_key]
|
||||
return None
|
||||
|
||||
def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
|
||||
"""Add new API key."""
|
||||
self.api_keys[api_key] = service_info
|
||||
|
||||
def revoke_api_key(self, api_key: str):
|
||||
"""Revoke API key."""
|
||||
if api_key in self.api_keys:
|
||||
del self.api_keys[api_key]
|
||||
|
||||
|
||||
# Global API key auth instance
|
||||
api_key_auth = APIKeyAuth()
|
||||
429
v1/src/api/middleware/rate_limit.py
Normal file
429
v1/src/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Rate limiting middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware with sliding window algorithm."""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self.settings = get_settings()
|
||||
|
||||
# Rate limit storage (in production, use Redis)
|
||||
self.request_counts = defaultdict(lambda: deque())
|
||||
self.blocked_clients = {}
|
||||
|
||||
# Rate limit configurations
|
||||
self.rate_limits = {
|
||||
"anonymous": {
|
||||
"requests": self.settings.rate_limit_requests,
|
||||
"window": self.settings.rate_limit_window,
|
||||
"burst": 10 # Allow burst of 10 requests
|
||||
},
|
||||
"authenticated": {
|
||||
"requests": self.settings.rate_limit_authenticated_requests,
|
||||
"window": self.settings.rate_limit_window,
|
||||
"burst": 50
|
||||
},
|
||||
"admin": {
|
||||
"requests": 10000, # Very high limit for admins
|
||||
"window": self.settings.rate_limit_window,
|
||||
"burst": 100
|
||||
}
|
||||
}
|
||||
|
||||
# Path-specific rate limits
|
||||
self.path_limits = {
|
||||
"/api/v1/pose/current": {"requests": 60, "window": 60}, # 1 per second
|
||||
"/api/v1/pose/analyze": {"requests": 10, "window": 60}, # 10 per minute
|
||||
"/api/v1/pose/calibrate": {"requests": 1, "window": 300}, # 1 per 5 minutes
|
||||
"/api/v1/stream/start": {"requests": 5, "window": 60}, # 5 per minute
|
||||
"/api/v1/stream/stop": {"requests": 5, "window": 60}, # 5 per minute
|
||||
}
|
||||
|
||||
# Exempt paths from rate limiting
|
||||
self.exempt_paths = {
|
||||
"/health",
|
||||
"/ready",
|
||||
"/live",
|
||||
"/version",
|
||||
"/metrics"
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process request through rate limiting middleware."""
|
||||
|
||||
# Skip rate limiting for exempt paths
|
||||
if self._is_exempt_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Get client identifier
|
||||
client_id = self._get_client_id(request)
|
||||
|
||||
# Check if client is temporarily blocked
|
||||
if self._is_client_blocked(client_id):
|
||||
return self._create_rate_limit_response(
|
||||
"Client temporarily blocked due to excessive requests"
|
||||
)
|
||||
|
||||
# Get user type for rate limiting
|
||||
user_type = self._get_user_type(request)
|
||||
|
||||
# Check rate limits
|
||||
rate_limit_result = self._check_rate_limits(
|
||||
client_id,
|
||||
request.url.path,
|
||||
user_type
|
||||
)
|
||||
|
||||
if not rate_limit_result["allowed"]:
|
||||
# Log rate limit violation
|
||||
self._log_rate_limit_violation(request, client_id, rate_limit_result)
|
||||
|
||||
# Check if client should be temporarily blocked
|
||||
if rate_limit_result.get("violations", 0) > 5:
|
||||
self._block_client(client_id, duration=300) # 5 minutes
|
||||
|
||||
return self._create_rate_limit_response(
|
||||
rate_limit_result["message"],
|
||||
retry_after=rate_limit_result.get("retry_after", 60)
|
||||
)
|
||||
|
||||
# Record the request
|
||||
self._record_request(client_id, request.url.path)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers
|
||||
self._add_rate_limit_headers(response, client_id, user_type)
|
||||
|
||||
return response
|
||||
|
||||
def _is_exempt_path(self, path: str) -> bool:
|
||||
"""Check if path is exempt from rate limiting."""
|
||||
return path in self.exempt_paths
|
||||
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""Get unique client identifier for rate limiting."""
|
||||
# Try to get user ID from request state (set by auth middleware)
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
return f"user:{request.state.user['id']}"
|
||||
|
||||
# Fall back to IP address
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# Include user agent for better identification
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
user_agent_hash = str(hash(user_agent))[:8]
|
||||
|
||||
return f"ip:{client_ip}:{user_agent_hash}"
|
||||
|
||||
def _get_user_type(self, request: Request) -> str:
|
||||
"""Determine user type for rate limiting."""
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
if request.state.user.get("is_admin", False):
|
||||
return "admin"
|
||||
return "authenticated"
|
||||
return "anonymous"
|
||||
|
||||
def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
|
||||
"""Check if request is within rate limits."""
|
||||
now = time.time()
|
||||
|
||||
# Get applicable rate limits
|
||||
general_limit = self.rate_limits[user_type]
|
||||
path_limit = self.path_limits.get(path)
|
||||
|
||||
# Check general rate limit
|
||||
general_result = self._check_limit(
|
||||
client_id,
|
||||
"general",
|
||||
general_limit["requests"],
|
||||
general_limit["window"],
|
||||
now
|
||||
)
|
||||
|
||||
if not general_result["allowed"]:
|
||||
return general_result
|
||||
|
||||
# Check path-specific rate limit if exists
|
||||
if path_limit:
|
||||
path_result = self._check_limit(
|
||||
client_id,
|
||||
f"path:{path}",
|
||||
path_limit["requests"],
|
||||
path_limit["window"],
|
||||
now
|
||||
)
|
||||
|
||||
if not path_result["allowed"]:
|
||||
return path_result
|
||||
|
||||
return {"allowed": True}
|
||||
|
||||
def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
|
||||
"""Check specific rate limit using sliding window."""
|
||||
key = f"{client_id}:{limit_type}"
|
||||
requests = self.request_counts[key]
|
||||
|
||||
# Remove old requests outside the window
|
||||
cutoff = now - window
|
||||
while requests and requests[0] <= cutoff:
|
||||
requests.popleft()
|
||||
|
||||
# Check if limit exceeded
|
||||
if len(requests) >= max_requests:
|
||||
# Calculate retry after time
|
||||
oldest_request = requests[0] if requests else now
|
||||
retry_after = int(oldest_request + window - now) + 1
|
||||
|
||||
return {
|
||||
"allowed": False,
|
||||
"message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
|
||||
"retry_after": retry_after,
|
||||
"current_count": len(requests),
|
||||
"limit": max_requests,
|
||||
"window": window
|
||||
}
|
||||
|
||||
return {
|
||||
"allowed": True,
|
||||
"current_count": len(requests),
|
||||
"limit": max_requests,
|
||||
"window": window
|
||||
}
|
||||
|
||||
def _record_request(self, client_id: str, path: str):
|
||||
"""Record a request for rate limiting."""
|
||||
now = time.time()
|
||||
|
||||
# Record general request
|
||||
general_key = f"{client_id}:general"
|
||||
self.request_counts[general_key].append(now)
|
||||
|
||||
# Record path-specific request if path has specific limits
|
||||
if path in self.path_limits:
|
||||
path_key = f"{client_id}:path:{path}"
|
||||
self.request_counts[path_key].append(now)
|
||||
|
||||
def _is_client_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is temporarily blocked."""
|
||||
if client_id in self.blocked_clients:
|
||||
block_until = self.blocked_clients[client_id]
|
||||
if time.time() < block_until:
|
||||
return True
|
||||
else:
|
||||
# Block expired, remove it
|
||||
del self.blocked_clients[client_id]
|
||||
return False
|
||||
|
||||
def _block_client(self, client_id: str, duration: int):
|
||||
"""Temporarily block a client."""
|
||||
self.blocked_clients[client_id] = time.time() + duration
|
||||
logger.warning(f"Client {client_id} blocked for {duration} seconds due to rate limit violations")
|
||||
|
||||
def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
|
||||
"""Create rate limit exceeded response."""
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": message,
|
||||
"type": "rate_limit_exceeded"
|
||||
}
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"X-RateLimit-Limit": "Exceeded",
|
||||
"X-RateLimit-Remaining": "0"
|
||||
}
|
||||
)
|
||||
|
||||
def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
|
||||
"""Add rate limit headers to response."""
|
||||
try:
|
||||
general_limit = self.rate_limits[user_type]
|
||||
general_key = f"{client_id}:general"
|
||||
current_requests = len(self.request_counts[general_key])
|
||||
|
||||
remaining = max(0, general_limit["requests"] - current_requests)
|
||||
|
||||
response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
response.headers["X-RateLimit-Window"] = str(general_limit["window"])
|
||||
|
||||
# Add reset time
|
||||
if self.request_counts[general_key]:
|
||||
oldest_request = self.request_counts[general_key][0]
|
||||
reset_time = int(oldest_request + general_limit["window"])
|
||||
response.headers["X-RateLimit-Reset"] = str(reset_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding rate limit headers: {e}")
|
||||
|
||||
def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
|
||||
"""Log rate limit violations for monitoring."""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
log_data = {
|
||||
"event_type": "rate_limit_violation",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"client_id": client_id,
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"current_count": result.get("current_count"),
|
||||
"limit": result.get("limit"),
|
||||
"window": result.get("window")
|
||||
}
|
||||
|
||||
logger.warning(f"Rate limit violation: {log_data}")
|
||||
|
||||
def cleanup_old_data(self):
|
||||
"""Clean up old rate limiting data (call periodically)."""
|
||||
now = time.time()
|
||||
cutoff = now - 3600 # Keep data for 1 hour
|
||||
|
||||
# Clean up request counts
|
||||
for key in list(self.request_counts.keys()):
|
||||
requests = self.request_counts[key]
|
||||
while requests and requests[0] <= cutoff:
|
||||
requests.popleft()
|
||||
|
||||
# Remove empty deques
|
||||
if not requests:
|
||||
del self.request_counts[key]
|
||||
|
||||
# Clean up expired blocks
|
||||
expired_blocks = [
|
||||
client_id for client_id, block_until in self.blocked_clients.items()
|
||||
if now >= block_until
|
||||
]
|
||||
|
||||
for client_id in expired_blocks:
|
||||
del self.blocked_clients[client_id]
|
||||
|
||||
|
||||
class AdaptiveRateLimit:
|
||||
"""Adaptive rate limiting based on system load."""
|
||||
|
||||
def __init__(self):
|
||||
self.base_limits = {}
|
||||
self.current_multiplier = 1.0
|
||||
self.load_history = deque(maxlen=60) # Keep 1 minute of load data
|
||||
|
||||
def update_system_load(self, cpu_percent: float, memory_percent: float):
|
||||
"""Update system load metrics."""
|
||||
load_score = (cpu_percent + memory_percent) / 2
|
||||
self.load_history.append(load_score)
|
||||
|
||||
# Calculate adaptive multiplier
|
||||
if len(self.load_history) >= 10:
|
||||
avg_load = sum(self.load_history) / len(self.load_history)
|
||||
|
||||
if avg_load > 80:
|
||||
self.current_multiplier = 0.5 # Reduce limits by 50%
|
||||
elif avg_load > 60:
|
||||
self.current_multiplier = 0.7 # Reduce limits by 30%
|
||||
elif avg_load < 30:
|
||||
self.current_multiplier = 1.2 # Increase limits by 20%
|
||||
else:
|
||||
self.current_multiplier = 1.0 # Normal limits
|
||||
|
||||
def get_adjusted_limit(self, base_limit: int) -> int:
|
||||
"""Get adjusted rate limit based on system load."""
|
||||
return max(1, int(base_limit * self.current_multiplier))
|
||||
|
||||
|
||||
class RateLimitStorage:
|
||||
"""Abstract interface for rate limit storage (Redis implementation)."""
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count for key within window."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count and return new count."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RedisRateLimitStorage(RateLimitStorage):
|
||||
"""Redis-based rate limit storage for production use."""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
self.redis = redis_client
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count using Redis sliding window."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Count current entries
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[1]
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count using Redis."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Add current request
|
||||
pipeline.zadd(key, {str(now): now})
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Set expiration
|
||||
pipeline.expire(key, window + 1)
|
||||
|
||||
# Get count
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[3]
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
return await self.redis.exists(block_key)
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
await self.redis.setex(block_key, duration, "1")
|
||||
|
||||
|
||||
# Global adaptive rate limiter instance
|
||||
adaptive_rate_limit = AdaptiveRateLimit()
|
||||
7
v1/src/api/routers/__init__.py
Normal file
7
v1/src/api/routers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
API routers package
|
||||
"""
|
||||
|
||||
from . import pose, stream, health
|
||||
|
||||
__all__ = ["pose", "stream", "health"]
|
||||
419
v1/src/api/routers/health.py
Normal file
419
v1/src/api/routers/health.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Health check API endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
import psutil
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import get_current_user
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Response models
|
||||
class ComponentHealth(BaseModel):
|
||||
"""Health status for a system component."""
|
||||
|
||||
name: str = Field(..., description="Component name")
|
||||
status: str = Field(..., description="Health status (healthy, degraded, unhealthy)")
|
||||
message: Optional[str] = Field(default=None, description="Status message")
|
||||
last_check: datetime = Field(..., description="Last health check timestamp")
|
||||
uptime_seconds: Optional[float] = Field(default=None, description="Component uptime")
|
||||
metrics: Optional[Dict[str, Any]] = Field(default=None, description="Component metrics")
|
||||
|
||||
|
||||
class SystemHealth(BaseModel):
|
||||
"""Overall system health status."""
|
||||
|
||||
status: str = Field(..., description="Overall system status")
|
||||
timestamp: datetime = Field(..., description="Health check timestamp")
|
||||
uptime_seconds: float = Field(..., description="System uptime")
|
||||
components: Dict[str, ComponentHealth] = Field(..., description="Component health status")
|
||||
system_metrics: Dict[str, Any] = Field(..., description="System-level metrics")
|
||||
|
||||
|
||||
class ReadinessCheck(BaseModel):
|
||||
"""System readiness check result."""
|
||||
|
||||
ready: bool = Field(..., description="Whether system is ready to serve requests")
|
||||
timestamp: datetime = Field(..., description="Readiness check timestamp")
|
||||
checks: Dict[str, bool] = Field(..., description="Individual readiness checks")
|
||||
message: str = Field(..., description="Readiness status message")
|
||||
|
||||
|
||||
# Health check endpoints
|
||||
@router.get("/health", response_model=SystemHealth)
|
||||
async def health_check(request: Request):
|
||||
"""Comprehensive system health check."""
|
||||
try:
|
||||
# Get services from app state
|
||||
hardware_service = getattr(request.app.state, 'hardware_service', None)
|
||||
pose_service = getattr(request.app.state, 'pose_service', None)
|
||||
stream_service = getattr(request.app.state, 'stream_service', None)
|
||||
|
||||
timestamp = datetime.utcnow()
|
||||
components = {}
|
||||
overall_status = "healthy"
|
||||
|
||||
# Check hardware service
|
||||
if hardware_service:
|
||||
try:
|
||||
hw_health = await hardware_service.health_check()
|
||||
components["hardware"] = ComponentHealth(
|
||||
name="Hardware Service",
|
||||
status=hw_health["status"],
|
||||
message=hw_health.get("message"),
|
||||
last_check=timestamp,
|
||||
uptime_seconds=hw_health.get("uptime_seconds"),
|
||||
metrics=hw_health.get("metrics")
|
||||
)
|
||||
|
||||
if hw_health["status"] != "healthy":
|
||||
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hardware service health check failed: {e}")
|
||||
components["hardware"] = ComponentHealth(
|
||||
name="Hardware Service",
|
||||
status="unhealthy",
|
||||
message=f"Health check failed: {str(e)}",
|
||||
last_check=timestamp
|
||||
)
|
||||
overall_status = "unhealthy"
|
||||
else:
|
||||
components["hardware"] = ComponentHealth(
|
||||
name="Hardware Service",
|
||||
status="unavailable",
|
||||
message="Service not initialized",
|
||||
last_check=timestamp
|
||||
)
|
||||
overall_status = "degraded"
|
||||
|
||||
# Check pose service
|
||||
if pose_service:
|
||||
try:
|
||||
pose_health = await pose_service.health_check()
|
||||
components["pose"] = ComponentHealth(
|
||||
name="Pose Service",
|
||||
status=pose_health["status"],
|
||||
message=pose_health.get("message"),
|
||||
last_check=timestamp,
|
||||
uptime_seconds=pose_health.get("uptime_seconds"),
|
||||
metrics=pose_health.get("metrics")
|
||||
)
|
||||
|
||||
if pose_health["status"] != "healthy":
|
||||
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pose service health check failed: {e}")
|
||||
components["pose"] = ComponentHealth(
|
||||
name="Pose Service",
|
||||
status="unhealthy",
|
||||
message=f"Health check failed: {str(e)}",
|
||||
last_check=timestamp
|
||||
)
|
||||
overall_status = "unhealthy"
|
||||
else:
|
||||
components["pose"] = ComponentHealth(
|
||||
name="Pose Service",
|
||||
status="unavailable",
|
||||
message="Service not initialized",
|
||||
last_check=timestamp
|
||||
)
|
||||
overall_status = "degraded"
|
||||
|
||||
# Check stream service
|
||||
if stream_service:
|
||||
try:
|
||||
stream_health = await stream_service.health_check()
|
||||
components["stream"] = ComponentHealth(
|
||||
name="Stream Service",
|
||||
status=stream_health["status"],
|
||||
message=stream_health.get("message"),
|
||||
last_check=timestamp,
|
||||
uptime_seconds=stream_health.get("uptime_seconds"),
|
||||
metrics=stream_health.get("metrics")
|
||||
)
|
||||
|
||||
if stream_health["status"] != "healthy":
|
||||
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream service health check failed: {e}")
|
||||
components["stream"] = ComponentHealth(
|
||||
name="Stream Service",
|
||||
status="unhealthy",
|
||||
message=f"Health check failed: {str(e)}",
|
||||
last_check=timestamp
|
||||
)
|
||||
overall_status = "unhealthy"
|
||||
else:
|
||||
components["stream"] = ComponentHealth(
|
||||
name="Stream Service",
|
||||
status="unavailable",
|
||||
message="Service not initialized",
|
||||
last_check=timestamp
|
||||
)
|
||||
overall_status = "degraded"
|
||||
|
||||
# Get system metrics
|
||||
system_metrics = get_system_metrics()
|
||||
|
||||
# Calculate system uptime (placeholder - would need actual startup time)
|
||||
uptime_seconds = 0.0 # TODO: Implement actual uptime tracking
|
||||
|
||||
return SystemHealth(
|
||||
status=overall_status,
|
||||
timestamp=timestamp,
|
||||
uptime_seconds=uptime_seconds,
|
||||
components=components,
|
||||
system_metrics=system_metrics
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Health check failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ready", response_model=ReadinessCheck)
|
||||
async def readiness_check(request: Request):
|
||||
"""Check if system is ready to serve requests."""
|
||||
try:
|
||||
timestamp = datetime.utcnow()
|
||||
checks = {}
|
||||
|
||||
# Check if services are available in app state
|
||||
if hasattr(request.app.state, 'pose_service') and request.app.state.pose_service:
|
||||
try:
|
||||
checks["pose_ready"] = await request.app.state.pose_service.is_ready()
|
||||
except Exception as e:
|
||||
logger.warning(f"Pose service readiness check failed: {e}")
|
||||
checks["pose_ready"] = False
|
||||
else:
|
||||
checks["pose_ready"] = False
|
||||
|
||||
if hasattr(request.app.state, 'stream_service') and request.app.state.stream_service:
|
||||
try:
|
||||
checks["stream_ready"] = await request.app.state.stream_service.is_ready()
|
||||
except Exception as e:
|
||||
logger.warning(f"Stream service readiness check failed: {e}")
|
||||
checks["stream_ready"] = False
|
||||
else:
|
||||
checks["stream_ready"] = False
|
||||
|
||||
# Hardware service check (basic availability)
|
||||
checks["hardware_ready"] = True # Basic readiness - API is responding
|
||||
|
||||
# Check system resources
|
||||
checks["memory_available"] = check_memory_availability()
|
||||
checks["disk_space_available"] = check_disk_space()
|
||||
|
||||
# Application is ready if at least the basic services are available
|
||||
# For now, we'll consider it ready if the API is responding
|
||||
ready = True # Basic readiness
|
||||
|
||||
message = "System is ready" if ready else "System is not ready"
|
||||
if not ready:
|
||||
failed_checks = [name for name, status in checks.items() if not status]
|
||||
message += f". Failed checks: {', '.join(failed_checks)}"
|
||||
|
||||
return ReadinessCheck(
|
||||
ready=ready,
|
||||
timestamp=timestamp,
|
||||
checks=checks,
|
||||
message=message
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Readiness check failed: {e}")
|
||||
return ReadinessCheck(
|
||||
ready=False,
|
||||
timestamp=datetime.utcnow(),
|
||||
checks={},
|
||||
message=f"Readiness check failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/live")
|
||||
async def liveness_check():
|
||||
"""Simple liveness check for load balancers."""
|
||||
return {
|
||||
"status": "alive",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_health_metrics(
|
||||
request: Request,
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get detailed system metrics."""
|
||||
try:
|
||||
metrics = get_system_metrics()
|
||||
|
||||
# Add additional metrics if authenticated
|
||||
if current_user:
|
||||
metrics.update(get_detailed_metrics())
|
||||
|
||||
return {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"metrics": metrics
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system metrics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get system metrics: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def get_version_info():
|
||||
"""Get application version information."""
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"debug": settings.debug,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_system_metrics() -> Dict[str, Any]:
|
||||
"""Get basic system metrics."""
|
||||
try:
|
||||
# CPU metrics
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
cpu_count = psutil.cpu_count()
|
||||
|
||||
# Memory metrics
|
||||
memory = psutil.virtual_memory()
|
||||
memory_metrics = {
|
||||
"total_gb": round(memory.total / (1024**3), 2),
|
||||
"available_gb": round(memory.available / (1024**3), 2),
|
||||
"used_gb": round(memory.used / (1024**3), 2),
|
||||
"percent": memory.percent
|
||||
}
|
||||
|
||||
# Disk metrics
|
||||
disk = psutil.disk_usage('/')
|
||||
disk_metrics = {
|
||||
"total_gb": round(disk.total / (1024**3), 2),
|
||||
"free_gb": round(disk.free / (1024**3), 2),
|
||||
"used_gb": round(disk.used / (1024**3), 2),
|
||||
"percent": round((disk.used / disk.total) * 100, 2)
|
||||
}
|
||||
|
||||
# Network metrics (basic)
|
||||
network = psutil.net_io_counters()
|
||||
network_metrics = {
|
||||
"bytes_sent": network.bytes_sent,
|
||||
"bytes_recv": network.bytes_recv,
|
||||
"packets_sent": network.packets_sent,
|
||||
"packets_recv": network.packets_recv
|
||||
}
|
||||
|
||||
return {
|
||||
"cpu": {
|
||||
"percent": cpu_percent,
|
||||
"count": cpu_count
|
||||
},
|
||||
"memory": memory_metrics,
|
||||
"disk": disk_metrics,
|
||||
"network": network_metrics
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system metrics: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_detailed_metrics() -> Dict[str, Any]:
|
||||
"""Get detailed system metrics (requires authentication)."""
|
||||
try:
|
||||
# Process metrics
|
||||
process = psutil.Process()
|
||||
process_metrics = {
|
||||
"pid": process.pid,
|
||||
"cpu_percent": process.cpu_percent(),
|
||||
"memory_mb": round(process.memory_info().rss / (1024**2), 2),
|
||||
"num_threads": process.num_threads(),
|
||||
"create_time": datetime.fromtimestamp(process.create_time()).isoformat()
|
||||
}
|
||||
|
||||
# Load average (Unix-like systems)
|
||||
load_avg = None
|
||||
try:
|
||||
load_avg = psutil.getloadavg()
|
||||
except AttributeError:
|
||||
# Windows doesn't have load average
|
||||
pass
|
||||
|
||||
# Temperature sensors (if available)
|
||||
temperatures = {}
|
||||
try:
|
||||
temps = psutil.sensors_temperatures()
|
||||
for name, entries in temps.items():
|
||||
temperatures[name] = [
|
||||
{"label": entry.label, "current": entry.current}
|
||||
for entry in entries
|
||||
]
|
||||
except AttributeError:
|
||||
# Not available on all systems
|
||||
pass
|
||||
|
||||
detailed = {
|
||||
"process": process_metrics
|
||||
}
|
||||
|
||||
if load_avg:
|
||||
detailed["load_average"] = {
|
||||
"1min": load_avg[0],
|
||||
"5min": load_avg[1],
|
||||
"15min": load_avg[2]
|
||||
}
|
||||
|
||||
if temperatures:
|
||||
detailed["temperatures"] = temperatures
|
||||
|
||||
return detailed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting detailed metrics: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def check_memory_availability() -> bool:
|
||||
"""Check if sufficient memory is available."""
|
||||
try:
|
||||
memory = psutil.virtual_memory()
|
||||
# Consider system ready if less than 90% memory is used
|
||||
return memory.percent < 90.0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def check_disk_space() -> bool:
|
||||
"""Check if sufficient disk space is available."""
|
||||
try:
|
||||
disk = psutil.disk_usage('/')
|
||||
# Consider system ready if more than 1GB free space
|
||||
free_gb = disk.free / (1024**3)
|
||||
return free_gb > 1.0
|
||||
except Exception:
|
||||
return False
|
||||
420
v1/src/api/routers/pose.py
Normal file
420
v1/src/api/routers/pose.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Pose estimation API endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_pose_service,
|
||||
get_hardware_service,
|
||||
get_current_user,
|
||||
require_auth
|
||||
)
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.hardware_service import HardwareService
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class PoseEstimationRequest(BaseModel):
|
||||
"""Request model for pose estimation."""
|
||||
|
||||
zone_ids: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Specific zones to analyze (all zones if not specified)"
|
||||
)
|
||||
confidence_threshold: Optional[float] = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence threshold for detections"
|
||||
)
|
||||
max_persons: Optional[int] = Field(
|
||||
default=None,
|
||||
ge=1,
|
||||
le=50,
|
||||
description="Maximum number of persons to detect"
|
||||
)
|
||||
include_keypoints: bool = Field(
|
||||
default=True,
|
||||
description="Include detailed keypoint data"
|
||||
)
|
||||
include_segmentation: bool = Field(
|
||||
default=False,
|
||||
description="Include DensePose segmentation masks"
|
||||
)
|
||||
|
||||
|
||||
class PersonPose(BaseModel):
|
||||
"""Person pose data model."""
|
||||
|
||||
person_id: str = Field(..., description="Unique person identifier")
|
||||
confidence: float = Field(..., description="Detection confidence score")
|
||||
bounding_box: Dict[str, float] = Field(..., description="Person bounding box")
|
||||
keypoints: Optional[List[Dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
description="Body keypoints with coordinates and confidence"
|
||||
)
|
||||
segmentation: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="DensePose segmentation data"
|
||||
)
|
||||
zone_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Zone where person is detected"
|
||||
)
|
||||
activity: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Detected activity"
|
||||
)
|
||||
timestamp: datetime = Field(..., description="Detection timestamp")
|
||||
|
||||
|
||||
class PoseEstimationResponse(BaseModel):
|
||||
"""Response model for pose estimation."""
|
||||
|
||||
timestamp: datetime = Field(..., description="Analysis timestamp")
|
||||
frame_id: str = Field(..., description="Unique frame identifier")
|
||||
persons: List[PersonPose] = Field(..., description="Detected persons")
|
||||
zone_summary: Dict[str, int] = Field(..., description="Person count per zone")
|
||||
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
|
||||
class HistoricalDataRequest(BaseModel):
|
||||
"""Request model for historical pose data."""
|
||||
|
||||
start_time: datetime = Field(..., description="Start time for data query")
|
||||
end_time: datetime = Field(..., description="End time for data query")
|
||||
zone_ids: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Filter by specific zones"
|
||||
)
|
||||
aggregation_interval: Optional[int] = Field(
|
||||
default=300,
|
||||
ge=60,
|
||||
le=3600,
|
||||
description="Aggregation interval in seconds"
|
||||
)
|
||||
include_raw_data: bool = Field(
|
||||
default=False,
|
||||
description="Include raw detection data"
|
||||
)
|
||||
|
||||
|
||||
# Endpoints
|
||||
@router.get("/current", response_model=PoseEstimationResponse)
|
||||
async def get_current_pose_estimation(
|
||||
request: PoseEstimationRequest = Depends(),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get current pose estimation from WiFi signals."""
|
||||
try:
|
||||
logger.info(f"Processing pose estimation request from user: {current_user.get('id') if current_user else 'anonymous'}")
|
||||
|
||||
# Get current pose estimation
|
||||
result = await pose_service.estimate_poses(
|
||||
zone_ids=request.zone_ids,
|
||||
confidence_threshold=request.confidence_threshold,
|
||||
max_persons=request.max_persons,
|
||||
include_keypoints=request.include_keypoints,
|
||||
include_segmentation=request.include_segmentation
|
||||
)
|
||||
|
||||
return PoseEstimationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pose estimation: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Pose estimation failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=PoseEstimationResponse)
|
||||
async def analyze_pose_data(
|
||||
request: PoseEstimationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Trigger pose analysis with custom parameters."""
|
||||
try:
|
||||
logger.info(f"Custom pose analysis requested by user: {current_user['id']}")
|
||||
|
||||
# Trigger analysis
|
||||
result = await pose_service.analyze_with_params(
|
||||
zone_ids=request.zone_ids,
|
||||
confidence_threshold=request.confidence_threshold,
|
||||
max_persons=request.max_persons,
|
||||
include_keypoints=request.include_keypoints,
|
||||
include_segmentation=request.include_segmentation
|
||||
)
|
||||
|
||||
# Schedule background processing if needed
|
||||
if request.include_segmentation:
|
||||
background_tasks.add_task(
|
||||
pose_service.process_segmentation_data,
|
||||
result["frame_id"]
|
||||
)
|
||||
|
||||
return PoseEstimationResponse(**result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pose analysis: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Pose analysis failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/zones/{zone_id}/occupancy")
|
||||
async def get_zone_occupancy(
|
||||
zone_id: str,
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get current occupancy for a specific zone."""
|
||||
try:
|
||||
occupancy = await pose_service.get_zone_occupancy(zone_id)
|
||||
|
||||
if occupancy is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Zone '{zone_id}' not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"zone_id": zone_id,
|
||||
"current_occupancy": occupancy["count"],
|
||||
"max_occupancy": occupancy.get("max_occupancy"),
|
||||
"persons": occupancy["persons"],
|
||||
"timestamp": occupancy["timestamp"]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting zone occupancy: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get zone occupancy: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/zones/summary")
|
||||
async def get_zones_summary(
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get occupancy summary for all zones."""
|
||||
try:
|
||||
summary = await pose_service.get_zones_summary()
|
||||
|
||||
return {
|
||||
"timestamp": datetime.utcnow(),
|
||||
"total_persons": summary["total_persons"],
|
||||
"zones": summary["zones"],
|
||||
"active_zones": summary["active_zones"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting zones summary: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get zones summary: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/historical")
|
||||
async def get_historical_data(
|
||||
request: HistoricalDataRequest,
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Get historical pose estimation data."""
|
||||
try:
|
||||
# Validate time range
|
||||
if request.end_time <= request.start_time:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="End time must be after start time"
|
||||
)
|
||||
|
||||
# Limit query range to prevent excessive data
|
||||
max_range = timedelta(days=7)
|
||||
if request.end_time - request.start_time > max_range:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Query range cannot exceed 7 days"
|
||||
)
|
||||
|
||||
data = await pose_service.get_historical_data(
|
||||
start_time=request.start_time,
|
||||
end_time=request.end_time,
|
||||
zone_ids=request.zone_ids,
|
||||
aggregation_interval=request.aggregation_interval,
|
||||
include_raw_data=request.include_raw_data
|
||||
)
|
||||
|
||||
return {
|
||||
"query": {
|
||||
"start_time": request.start_time,
|
||||
"end_time": request.end_time,
|
||||
"zone_ids": request.zone_ids,
|
||||
"aggregation_interval": request.aggregation_interval
|
||||
},
|
||||
"data": data["aggregated_data"],
|
||||
"raw_data": data.get("raw_data") if request.include_raw_data else None,
|
||||
"total_records": data["total_records"]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting historical data: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get historical data: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/activities")
|
||||
async def get_detected_activities(
|
||||
zone_id: Optional[str] = Query(None, description="Filter by zone ID"),
|
||||
limit: int = Query(10, ge=1, le=100, description="Maximum number of activities"),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get recently detected activities."""
|
||||
try:
|
||||
activities = await pose_service.get_recent_activities(
|
||||
zone_id=zone_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"activities": activities,
|
||||
"total_count": len(activities),
|
||||
"zone_id": zone_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting activities: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get activities: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/calibrate")
|
||||
async def calibrate_pose_system(
|
||||
background_tasks: BackgroundTasks,
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
hardware_service: HardwareService = Depends(get_hardware_service),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Calibrate the pose estimation system."""
|
||||
try:
|
||||
logger.info(f"Pose system calibration initiated by user: {current_user['id']}")
|
||||
|
||||
# Check if calibration is already in progress
|
||||
if await pose_service.is_calibrating():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Calibration already in progress"
|
||||
)
|
||||
|
||||
# Start calibration process
|
||||
calibration_id = await pose_service.start_calibration()
|
||||
|
||||
# Schedule background calibration task
|
||||
background_tasks.add_task(
|
||||
pose_service.run_calibration,
|
||||
calibration_id
|
||||
)
|
||||
|
||||
return {
|
||||
"calibration_id": calibration_id,
|
||||
"status": "started",
|
||||
"estimated_duration_minutes": 5,
|
||||
"message": "Calibration process started"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting calibration: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to start calibration: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/calibration/status")
|
||||
async def get_calibration_status(
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Get current calibration status."""
|
||||
try:
|
||||
status = await pose_service.get_calibration_status()
|
||||
|
||||
return {
|
||||
"is_calibrating": status["is_calibrating"],
|
||||
"calibration_id": status.get("calibration_id"),
|
||||
"progress_percent": status.get("progress_percent", 0),
|
||||
"current_step": status.get("current_step"),
|
||||
"estimated_remaining_minutes": status.get("estimated_remaining_minutes"),
|
||||
"last_calibration": status.get("last_calibration")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting calibration status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get calibration status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_pose_statistics(
|
||||
hours: int = Query(24, ge=1, le=168, description="Hours of data to analyze"),
|
||||
pose_service: PoseService = Depends(get_pose_service),
|
||||
current_user: Optional[Dict] = Depends(get_current_user)
|
||||
):
|
||||
"""Get pose estimation statistics."""
|
||||
try:
|
||||
end_time = datetime.utcnow()
|
||||
start_time = end_time - timedelta(hours=hours)
|
||||
|
||||
stats = await pose_service.get_statistics(
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
|
||||
return {
|
||||
"period": {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"hours": hours
|
||||
},
|
||||
"statistics": stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting statistics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get statistics: {str(e)}"
|
||||
)
|
||||
468
v1/src/api/routers/stream.py
Normal file
468
v1/src/api/routers/stream.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
WebSocket streaming API endpoints
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_stream_service,
|
||||
get_pose_service,
|
||||
get_current_user_ws,
|
||||
require_auth
|
||||
)
|
||||
from src.api.websocket.connection_manager import ConnectionManager
|
||||
from src.services.stream_service import StreamService
|
||||
from src.services.pose_service import PoseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize connection manager
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class StreamSubscriptionRequest(BaseModel):
|
||||
"""Request model for stream subscription."""
|
||||
|
||||
zone_ids: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Zones to subscribe to (all zones if not specified)"
|
||||
)
|
||||
stream_types: List[str] = Field(
|
||||
default=["pose_data"],
|
||||
description="Types of data to stream"
|
||||
)
|
||||
min_confidence: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Minimum confidence threshold for streaming"
|
||||
)
|
||||
max_fps: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=60,
|
||||
description="Maximum frames per second"
|
||||
)
|
||||
include_metadata: bool = Field(
|
||||
default=True,
|
||||
description="Include metadata in stream"
|
||||
)
|
||||
|
||||
|
||||
class StreamStatus(BaseModel):
|
||||
"""Stream status model."""
|
||||
|
||||
is_active: bool = Field(..., description="Whether streaming is active")
|
||||
connected_clients: int = Field(..., description="Number of connected clients")
|
||||
streams: List[Dict[str, Any]] = Field(..., description="Active streams")
|
||||
uptime_seconds: float = Field(..., description="Stream uptime in seconds")
|
||||
|
||||
|
||||
# WebSocket endpoints
|
||||
@router.websocket("/pose")
|
||||
async def websocket_pose_stream(
|
||||
websocket: WebSocket,
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
|
||||
max_fps: int = Query(30, ge=1, le=60),
|
||||
token: Optional[str] = Query(None, description="Authentication token")
|
||||
):
|
||||
"""WebSocket endpoint for real-time pose data streaming."""
|
||||
client_id = None
|
||||
|
||||
try:
|
||||
# Accept WebSocket connection
|
||||
await websocket.accept()
|
||||
|
||||
# Check authentication if enabled
|
||||
from src.config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.enable_authentication and not token:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Authentication token required"
|
||||
})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Parse zone IDs
|
||||
zone_list = None
|
||||
if zone_ids:
|
||||
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
|
||||
|
||||
# Register client with connection manager
|
||||
client_id = await connection_manager.connect(
|
||||
websocket=websocket,
|
||||
stream_type="pose",
|
||||
zone_ids=zone_list,
|
||||
min_confidence=min_confidence,
|
||||
max_fps=max_fps
|
||||
)
|
||||
|
||||
logger.info(f"WebSocket client {client_id} connected for pose streaming")
|
||||
|
||||
# Send initial connection confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connection_established",
|
||||
"client_id": client_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"config": {
|
||||
"zone_ids": zone_list,
|
||||
"min_confidence": min_confidence,
|
||||
"max_fps": max_fps
|
||||
}
|
||||
})
|
||||
|
||||
# Keep connection alive and handle incoming messages
|
||||
while True:
|
||||
try:
|
||||
# Wait for client messages (ping, config updates, etc.)
|
||||
message = await websocket.receive_text()
|
||||
data = json.loads(message)
|
||||
|
||||
await handle_websocket_message(client_id, data, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Invalid JSON format"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling WebSocket message: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Internal server error"
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket client {client_id} disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
finally:
|
||||
if client_id:
|
||||
await connection_manager.disconnect(client_id)
|
||||
|
||||
|
||||
@router.websocket("/events")
|
||||
async def websocket_events_stream(
|
||||
websocket: WebSocket,
|
||||
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
|
||||
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
|
||||
token: Optional[str] = Query(None, description="Authentication token")
|
||||
):
|
||||
"""WebSocket endpoint for real-time event streaming."""
|
||||
client_id = None
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
|
||||
# Check authentication if enabled
|
||||
from src.config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if settings.enable_authentication and not token:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "Authentication token required"
|
||||
})
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# Parse parameters
|
||||
event_list = None
|
||||
if event_types:
|
||||
event_list = [event.strip() for event in event_types.split(",") if event.strip()]
|
||||
|
||||
zone_list = None
|
||||
if zone_ids:
|
||||
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
|
||||
|
||||
# Register client
|
||||
client_id = await connection_manager.connect(
|
||||
websocket=websocket,
|
||||
stream_type="events",
|
||||
zone_ids=zone_list,
|
||||
event_types=event_list
|
||||
)
|
||||
|
||||
logger.info(f"WebSocket client {client_id} connected for event streaming")
|
||||
|
||||
# Send confirmation
|
||||
await websocket.send_json({
|
||||
"type": "connection_established",
|
||||
"client_id": client_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"config": {
|
||||
"event_types": event_list,
|
||||
"zone_ids": zone_list
|
||||
}
|
||||
})
|
||||
|
||||
# Handle messages
|
||||
while True:
|
||||
try:
|
||||
message = await websocket.receive_text()
|
||||
data = json.loads(message)
|
||||
await handle_websocket_message(client_id, data, websocket)
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in events WebSocket: {e}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"Events WebSocket client {client_id} disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Events WebSocket error: {e}")
|
||||
finally:
|
||||
if client_id:
|
||||
await connection_manager.disconnect(client_id)
|
||||
|
||||
|
||||
async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket: WebSocket):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
message_type = data.get("type")
|
||||
|
||||
if message_type == "ping":
|
||||
await websocket.send_json({
|
||||
"type": "pong",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
elif message_type == "update_config":
|
||||
# Update client configuration
|
||||
config = data.get("config", {})
|
||||
await connection_manager.update_client_config(client_id, config)
|
||||
|
||||
await websocket.send_json({
|
||||
"type": "config_updated",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"config": config
|
||||
})
|
||||
|
||||
elif message_type == "get_status":
|
||||
# Send current status
|
||||
status = await connection_manager.get_client_status(client_id)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"status": status
|
||||
})
|
||||
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": f"Unknown message type: {message_type}"
|
||||
})
|
||||
|
||||
|
||||
# HTTP endpoints for stream management
|
||||
@router.get("/status", response_model=StreamStatus)
|
||||
async def get_stream_status(
|
||||
stream_service: StreamService = Depends(get_stream_service)
|
||||
):
|
||||
"""Get current streaming status."""
|
||||
try:
|
||||
status = await stream_service.get_status()
|
||||
connections = await connection_manager.get_connection_stats()
|
||||
|
||||
# Calculate uptime (simplified for now)
|
||||
uptime_seconds = 0.0
|
||||
if status.get("running", False):
|
||||
uptime_seconds = 3600.0 # Default 1 hour for demo
|
||||
|
||||
return StreamStatus(
|
||||
is_active=status.get("running", False),
|
||||
connected_clients=connections.get("total_clients", status["connections"]["active"]),
|
||||
streams=[{
|
||||
"type": "pose_stream",
|
||||
"active": status.get("running", False),
|
||||
"buffer_size": status["buffers"]["pose_buffer_size"]
|
||||
}],
|
||||
uptime_seconds=uptime_seconds
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting stream status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get stream status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/start")
|
||||
async def start_streaming(
|
||||
stream_service: StreamService = Depends(get_stream_service),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Start the streaming service."""
|
||||
try:
|
||||
logger.info(f"Starting streaming service by user: {current_user['id']}")
|
||||
|
||||
if await stream_service.is_active():
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"message": "Streaming service is already active"}
|
||||
)
|
||||
|
||||
await stream_service.start()
|
||||
|
||||
return {
|
||||
"message": "Streaming service started successfully",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting streaming: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to start streaming: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stop")
|
||||
async def stop_streaming(
|
||||
stream_service: StreamService = Depends(get_stream_service),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Stop the streaming service."""
|
||||
try:
|
||||
logger.info(f"Stopping streaming service by user: {current_user['id']}")
|
||||
|
||||
await stream_service.stop()
|
||||
await connection_manager.disconnect_all()
|
||||
|
||||
return {
|
||||
"message": "Streaming service stopped successfully",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping streaming: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to stop streaming: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/clients")
|
||||
async def get_connected_clients(
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Get list of connected WebSocket clients."""
|
||||
try:
|
||||
clients = await connection_manager.get_connected_clients()
|
||||
|
||||
return {
|
||||
"total_clients": len(clients),
|
||||
"clients": clients,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting connected clients: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get connected clients: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/clients/{client_id}")
|
||||
async def disconnect_client(
|
||||
client_id: str,
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Disconnect a specific WebSocket client."""
|
||||
try:
|
||||
logger.info(f"Disconnecting client {client_id} by user: {current_user['id']}")
|
||||
|
||||
success = await connection_manager.disconnect(client_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Client {client_id} not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Client {client_id} disconnected successfully",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting client: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to disconnect client: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/broadcast")
|
||||
async def broadcast_message(
|
||||
message: Dict[str, Any],
|
||||
stream_type: Optional[str] = Query(None, description="Target stream type"),
|
||||
zone_ids: Optional[List[str]] = Query(None, description="Target zone IDs"),
|
||||
current_user: Dict = Depends(require_auth)
|
||||
):
|
||||
"""Broadcast a message to connected WebSocket clients."""
|
||||
try:
|
||||
logger.info(f"Broadcasting message by user: {current_user['id']}")
|
||||
|
||||
# Add metadata to message
|
||||
broadcast_data = {
|
||||
**message,
|
||||
"broadcast_timestamp": datetime.utcnow().isoformat(),
|
||||
"sender": current_user["id"]
|
||||
}
|
||||
|
||||
# Broadcast to matching clients
|
||||
sent_count = await connection_manager.broadcast(
|
||||
data=broadcast_data,
|
||||
stream_type=stream_type,
|
||||
zone_ids=zone_ids
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Broadcast sent successfully",
|
||||
"recipients": sent_count,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting message: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to broadcast message: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_streaming_metrics():
|
||||
"""Get streaming performance metrics."""
|
||||
try:
|
||||
metrics = await connection_manager.get_metrics()
|
||||
|
||||
return {
|
||||
"metrics": metrics,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting streaming metrics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get streaming metrics: {str(e)}"
|
||||
)
|
||||
8
v1/src/api/websocket/__init__.py
Normal file
8
v1/src/api/websocket/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
WebSocket handlers package
|
||||
"""
|
||||
|
||||
from .connection_manager import ConnectionManager
|
||||
from .pose_stream import PoseStreamHandler
|
||||
|
||||
__all__ = ["ConnectionManager", "PoseStreamHandler"]
|
||||
461
v1/src/api/websocket/connection_manager.py
Normal file
461
v1/src/api/websocket/connection_manager.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
WebSocket connection manager for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketConnection:
|
||||
"""Represents a WebSocket connection with metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
client_id: str,
|
||||
stream_type: str,
|
||||
zone_ids: Optional[List[str]] = None,
|
||||
**config
|
||||
):
|
||||
self.websocket = websocket
|
||||
self.client_id = client_id
|
||||
self.stream_type = stream_type
|
||||
self.zone_ids = zone_ids or []
|
||||
self.config = config
|
||||
self.connected_at = datetime.utcnow()
|
||||
self.last_ping = datetime.utcnow()
|
||||
self.message_count = 0
|
||||
self.is_active = True
|
||||
|
||||
async def send_json(self, data: Dict[str, Any]):
|
||||
"""Send JSON data to client."""
|
||||
try:
|
||||
await self.websocket.send_json(data)
|
||||
self.message_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending to client {self.client_id}: {e}")
|
||||
self.is_active = False
|
||||
raise
|
||||
|
||||
async def send_text(self, message: str):
|
||||
"""Send text message to client."""
|
||||
try:
|
||||
await self.websocket.send_text(message)
|
||||
self.message_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending text to client {self.client_id}: {e}")
|
||||
self.is_active = False
|
||||
raise
|
||||
|
||||
def update_config(self, config: Dict[str, Any]):
|
||||
"""Update connection configuration."""
|
||||
self.config.update(config)
|
||||
|
||||
# Update zone IDs if provided
|
||||
if "zone_ids" in config:
|
||||
self.zone_ids = config["zone_ids"] or []
|
||||
|
||||
def matches_filter(
|
||||
self,
|
||||
stream_type: Optional[str] = None,
|
||||
zone_ids: Optional[List[str]] = None,
|
||||
**filters
|
||||
) -> bool:
|
||||
"""Check if connection matches given filters."""
|
||||
# Check stream type
|
||||
if stream_type and self.stream_type != stream_type:
|
||||
return False
|
||||
|
||||
# Check zone IDs
|
||||
if zone_ids:
|
||||
if not self.zone_ids: # Connection listens to all zones
|
||||
return True
|
||||
# Check if any requested zone is in connection's zones
|
||||
if not any(zone in self.zone_ids for zone in zone_ids):
|
||||
return False
|
||||
|
||||
# Check additional filters
|
||||
for key, value in filters.items():
|
||||
if key in self.config and self.config[key] != value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""Get connection information."""
|
||||
return {
|
||||
"client_id": self.client_id,
|
||||
"stream_type": self.stream_type,
|
||||
"zone_ids": self.zone_ids,
|
||||
"config": self.config,
|
||||
"connected_at": self.connected_at.isoformat(),
|
||||
"last_ping": self.last_ping.isoformat(),
|
||||
"message_count": self.message_count,
|
||||
"is_active": self.is_active,
|
||||
"uptime_seconds": (datetime.utcnow() - self.connected_at).total_seconds()
|
||||
}
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections for real-time streaming."""
|
||||
|
||||
def __init__(self):
|
||||
self.connections: Dict[str, WebSocketConnection] = {}
|
||||
self.connections_by_type: Dict[str, Set[str]] = defaultdict(set)
|
||||
self.connections_by_zone: Dict[str, Set[str]] = defaultdict(set)
|
||||
self.metrics = {
|
||||
"total_connections": 0,
|
||||
"active_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"errors": 0,
|
||||
"start_time": datetime.utcnow()
|
||||
}
|
||||
self._cleanup_task = None
|
||||
self._started = False
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
stream_type: str,
|
||||
zone_ids: Optional[List[str]] = None,
|
||||
**config
|
||||
) -> str:
|
||||
"""Register a new WebSocket connection."""
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Create connection object
|
||||
connection = WebSocketConnection(
|
||||
websocket=websocket,
|
||||
client_id=client_id,
|
||||
stream_type=stream_type,
|
||||
zone_ids=zone_ids,
|
||||
**config
|
||||
)
|
||||
|
||||
# Store connection
|
||||
self.connections[client_id] = connection
|
||||
self.connections_by_type[stream_type].add(client_id)
|
||||
|
||||
# Index by zones
|
||||
if zone_ids:
|
||||
for zone_id in zone_ids:
|
||||
self.connections_by_zone[zone_id].add(client_id)
|
||||
|
||||
# Update metrics
|
||||
self.metrics["total_connections"] += 1
|
||||
self.metrics["active_connections"] = len(self.connections)
|
||||
|
||||
logger.info(f"WebSocket client {client_id} connected for {stream_type}")
|
||||
|
||||
return client_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting WebSocket client: {e}")
|
||||
raise
|
||||
|
||||
async def disconnect(self, client_id: str) -> bool:
|
||||
"""Disconnect a WebSocket client."""
|
||||
if client_id not in self.connections:
|
||||
return False
|
||||
|
||||
try:
|
||||
connection = self.connections[client_id]
|
||||
|
||||
# Remove from indexes
|
||||
self.connections_by_type[connection.stream_type].discard(client_id)
|
||||
|
||||
for zone_id in connection.zone_ids:
|
||||
self.connections_by_zone[zone_id].discard(client_id)
|
||||
|
||||
# Close WebSocket if still active
|
||||
if connection.is_active:
|
||||
try:
|
||||
await connection.websocket.close()
|
||||
except:
|
||||
pass # Connection might already be closed
|
||||
|
||||
# Remove connection
|
||||
del self.connections[client_id]
|
||||
|
||||
# Update metrics
|
||||
self.metrics["active_connections"] = len(self.connections)
|
||||
|
||||
logger.info(f"WebSocket client {client_id} disconnected")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error disconnecting client {client_id}: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""Disconnect all WebSocket clients."""
|
||||
client_ids = list(self.connections.keys())
|
||||
|
||||
for client_id in client_ids:
|
||||
await self.disconnect(client_id)
|
||||
|
||||
logger.info("All WebSocket clients disconnected")
|
||||
|
||||
async def send_to_client(self, client_id: str, data: Dict[str, Any]) -> bool:
|
||||
"""Send data to a specific client."""
|
||||
if client_id not in self.connections:
|
||||
return False
|
||||
|
||||
connection = self.connections[client_id]
|
||||
|
||||
try:
|
||||
await connection.send_json(data)
|
||||
self.metrics["messages_sent"] += 1
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending to client {client_id}: {e}")
|
||||
self.metrics["errors"] += 1
|
||||
|
||||
# Mark connection as inactive and schedule for cleanup
|
||||
connection.is_active = False
|
||||
return False
|
||||
|
||||
async def broadcast(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
stream_type: Optional[str] = None,
|
||||
zone_ids: Optional[List[str]] = None,
|
||||
**filters
|
||||
) -> int:
|
||||
"""Broadcast data to matching clients."""
|
||||
sent_count = 0
|
||||
failed_clients = []
|
||||
|
||||
# Get matching connections
|
||||
matching_clients = self._get_matching_clients(
|
||||
stream_type=stream_type,
|
||||
zone_ids=zone_ids,
|
||||
**filters
|
||||
)
|
||||
|
||||
# Send to all matching clients
|
||||
for client_id in matching_clients:
|
||||
try:
|
||||
success = await self.send_to_client(client_id, data)
|
||||
if success:
|
||||
sent_count += 1
|
||||
else:
|
||||
failed_clients.append(client_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting to client {client_id}: {e}")
|
||||
failed_clients.append(client_id)
|
||||
|
||||
# Clean up failed connections
|
||||
for client_id in failed_clients:
|
||||
await self.disconnect(client_id)
|
||||
|
||||
return sent_count
|
||||
|
||||
async def update_client_config(self, client_id: str, config: Dict[str, Any]) -> bool:
|
||||
"""Update client configuration."""
|
||||
if client_id not in self.connections:
|
||||
return False
|
||||
|
||||
connection = self.connections[client_id]
|
||||
old_zones = set(connection.zone_ids)
|
||||
|
||||
# Update configuration
|
||||
connection.update_config(config)
|
||||
|
||||
# Update zone indexes if zones changed
|
||||
new_zones = set(connection.zone_ids)
|
||||
|
||||
# Remove from old zones
|
||||
for zone_id in old_zones - new_zones:
|
||||
self.connections_by_zone[zone_id].discard(client_id)
|
||||
|
||||
# Add to new zones
|
||||
for zone_id in new_zones - old_zones:
|
||||
self.connections_by_zone[zone_id].add(client_id)
|
||||
|
||||
return True
|
||||
|
||||
async def get_client_status(self, client_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get status of a specific client."""
|
||||
if client_id not in self.connections:
|
||||
return None
|
||||
|
||||
return self.connections[client_id].get_info()
|
||||
|
||||
async def get_connected_clients(self) -> List[Dict[str, Any]]:
|
||||
"""Get list of all connected clients."""
|
||||
return [conn.get_info() for conn in self.connections.values()]
|
||||
|
||||
async def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""Get connection statistics."""
|
||||
stats = {
|
||||
"total_clients": len(self.connections),
|
||||
"clients_by_type": {
|
||||
stream_type: len(clients)
|
||||
for stream_type, clients in self.connections_by_type.items()
|
||||
},
|
||||
"clients_by_zone": {
|
||||
zone_id: len(clients)
|
||||
for zone_id, clients in self.connections_by_zone.items()
|
||||
if clients # Only include zones with active clients
|
||||
},
|
||||
"active_clients": sum(1 for conn in self.connections.values() if conn.is_active),
|
||||
"inactive_clients": sum(1 for conn in self.connections.values() if not conn.is_active)
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get detailed metrics."""
|
||||
uptime = (datetime.utcnow() - self.metrics["start_time"]).total_seconds()
|
||||
|
||||
return {
|
||||
**self.metrics,
|
||||
"active_connections": len(self.connections),
|
||||
"uptime_seconds": uptime,
|
||||
"messages_per_second": self.metrics["messages_sent"] / max(uptime, 1),
|
||||
"error_rate": self.metrics["errors"] / max(self.metrics["messages_sent"], 1)
|
||||
}
|
||||
|
||||
def _get_matching_clients(
|
||||
self,
|
||||
stream_type: Optional[str] = None,
|
||||
zone_ids: Optional[List[str]] = None,
|
||||
**filters
|
||||
) -> List[str]:
|
||||
"""Get client IDs that match the given filters."""
|
||||
candidates = set(self.connections.keys())
|
||||
|
||||
# Filter by stream type
|
||||
if stream_type:
|
||||
type_clients = self.connections_by_type.get(stream_type, set())
|
||||
candidates &= type_clients
|
||||
|
||||
# Filter by zones
|
||||
if zone_ids:
|
||||
zone_clients = set()
|
||||
for zone_id in zone_ids:
|
||||
zone_clients.update(self.connections_by_zone.get(zone_id, set()))
|
||||
|
||||
# Also include clients listening to all zones (empty zone list)
|
||||
all_zone_clients = {
|
||||
client_id for client_id, conn in self.connections.items()
|
||||
if not conn.zone_ids
|
||||
}
|
||||
zone_clients.update(all_zone_clients)
|
||||
|
||||
candidates &= zone_clients
|
||||
|
||||
# Apply additional filters
|
||||
matching_clients = []
|
||||
for client_id in candidates:
|
||||
connection = self.connections[client_id]
|
||||
if connection.is_active and connection.matches_filter(**filters):
|
||||
matching_clients.append(client_id)
|
||||
|
||||
return matching_clients
|
||||
|
||||
async def ping_clients(self):
|
||||
"""Send ping to all connected clients."""
|
||||
ping_data = {
|
||||
"type": "ping",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
failed_clients = []
|
||||
|
||||
for client_id, connection in self.connections.items():
|
||||
try:
|
||||
await connection.send_json(ping_data)
|
||||
connection.last_ping = datetime.utcnow()
|
||||
except Exception as e:
|
||||
logger.warning(f"Ping failed for client {client_id}: {e}")
|
||||
failed_clients.append(client_id)
|
||||
|
||||
# Clean up failed connections
|
||||
for client_id in failed_clients:
|
||||
await self.disconnect(client_id)
|
||||
|
||||
async def cleanup_inactive_connections(self):
|
||||
"""Clean up inactive or stale connections."""
|
||||
now = datetime.utcnow()
|
||||
stale_threshold = timedelta(minutes=5) # 5 minutes without ping
|
||||
|
||||
stale_clients = []
|
||||
|
||||
for client_id, connection in self.connections.items():
|
||||
# Check if connection is inactive
|
||||
if not connection.is_active:
|
||||
stale_clients.append(client_id)
|
||||
continue
|
||||
|
||||
# Check if connection is stale (no ping response)
|
||||
if now - connection.last_ping > stale_threshold:
|
||||
logger.warning(f"Client {client_id} appears stale, disconnecting")
|
||||
stale_clients.append(client_id)
|
||||
|
||||
# Clean up stale connections
|
||||
for client_id in stale_clients:
|
||||
await self.disconnect(client_id)
|
||||
|
||||
if stale_clients:
|
||||
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
|
||||
|
||||
async def start(self):
|
||||
"""Start the connection manager."""
|
||||
if not self._started:
|
||||
self._start_cleanup_task()
|
||||
self._started = True
|
||||
logger.info("Connection manager started")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""Start background cleanup task."""
|
||||
async def cleanup_loop():
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(60) # Run every minute
|
||||
await self.cleanup_inactive_connections()
|
||||
|
||||
# Send periodic ping every 2 minutes
|
||||
if datetime.utcnow().minute % 2 == 0:
|
||||
await self.ping_clients()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup task: {e}")
|
||||
|
||||
try:
|
||||
self._cleanup_task = asyncio.create_task(cleanup_loop())
|
||||
except RuntimeError:
|
||||
# No event loop running, will start later
|
||||
logger.debug("No event loop running, cleanup task will start later")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown connection manager."""
|
||||
# Cancel cleanup task
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Disconnect all clients
|
||||
await self.disconnect_all()
|
||||
|
||||
logger.info("Connection manager shutdown complete")
|
||||
|
||||
|
||||
# Global connection manager instance
|
||||
connection_manager = ConnectionManager()
|
||||
384
v1/src/api/websocket/pose_stream.py
Normal file
384
v1/src/api/websocket/pose_stream.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Pose streaming WebSocket handler
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.websocket.connection_manager import ConnectionManager
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.stream_service import StreamService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoseStreamData(BaseModel):
|
||||
"""Pose stream data model."""
|
||||
|
||||
timestamp: datetime = Field(..., description="Data timestamp")
|
||||
zone_id: str = Field(..., description="Zone identifier")
|
||||
pose_data: Dict[str, Any] = Field(..., description="Pose estimation data")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
|
||||
activity: Optional[str] = Field(default=None, description="Detected activity")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
|
||||
|
||||
|
||||
class PoseStreamHandler:
|
||||
"""Handles pose data streaming to WebSocket clients."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_manager: ConnectionManager,
|
||||
pose_service: PoseService,
|
||||
stream_service: StreamService
|
||||
):
|
||||
self.connection_manager = connection_manager
|
||||
self.pose_service = pose_service
|
||||
self.stream_service = stream_service
|
||||
self.is_streaming = False
|
||||
self.stream_task = None
|
||||
self.subscribers = {}
|
||||
self.stream_config = {
|
||||
"fps": 30,
|
||||
"min_confidence": 0.5,
|
||||
"include_metadata": True,
|
||||
"buffer_size": 100
|
||||
}
|
||||
|
||||
async def start_streaming(self):
|
||||
"""Start pose data streaming."""
|
||||
if self.is_streaming:
|
||||
logger.warning("Pose streaming already active")
|
||||
return
|
||||
|
||||
self.is_streaming = True
|
||||
self.stream_task = asyncio.create_task(self._stream_loop())
|
||||
logger.info("Pose streaming started")
|
||||
|
||||
async def stop_streaming(self):
|
||||
"""Stop pose data streaming."""
|
||||
if not self.is_streaming:
|
||||
return
|
||||
|
||||
self.is_streaming = False
|
||||
|
||||
if self.stream_task:
|
||||
self.stream_task.cancel()
|
||||
try:
|
||||
await self.stream_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Pose streaming stopped")
|
||||
|
||||
async def _stream_loop(self):
|
||||
"""Main streaming loop."""
|
||||
try:
|
||||
logger.info("🚀 Starting pose streaming loop")
|
||||
while self.is_streaming:
|
||||
try:
|
||||
# Get current pose data from all zones
|
||||
logger.debug("📡 Getting current pose data...")
|
||||
pose_data = await self.pose_service.get_current_pose_data()
|
||||
logger.debug(f"📊 Received pose data: {pose_data}")
|
||||
|
||||
if pose_data:
|
||||
logger.debug("📤 Broadcasting pose data...")
|
||||
await self._process_and_broadcast_pose_data(pose_data)
|
||||
else:
|
||||
logger.debug("⚠️ No pose data received")
|
||||
|
||||
# Control streaming rate
|
||||
await asyncio.sleep(1.0 / self.stream_config["fps"])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pose streaming loop: {e}")
|
||||
await asyncio.sleep(1.0) # Brief pause on error
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Pose streaming loop cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in pose streaming loop: {e}")
|
||||
finally:
|
||||
logger.info("🛑 Pose streaming loop stopped")
|
||||
self.is_streaming = False
|
||||
|
||||
async def _process_and_broadcast_pose_data(self, raw_pose_data: Dict[str, Any]):
|
||||
"""Process and broadcast pose data to subscribers."""
|
||||
try:
|
||||
# Process data for each zone
|
||||
for zone_id, zone_data in raw_pose_data.items():
|
||||
if not zone_data:
|
||||
continue
|
||||
|
||||
# Create structured pose data
|
||||
pose_stream_data = PoseStreamData(
|
||||
timestamp=datetime.utcnow(),
|
||||
zone_id=zone_id,
|
||||
pose_data=zone_data.get("pose", {}),
|
||||
confidence=zone_data.get("confidence", 0.0),
|
||||
activity=zone_data.get("activity"),
|
||||
metadata=zone_data.get("metadata") if self.stream_config["include_metadata"] else None
|
||||
)
|
||||
|
||||
# Filter by minimum confidence
|
||||
if pose_stream_data.confidence < self.stream_config["min_confidence"]:
|
||||
continue
|
||||
|
||||
# Broadcast to subscribers
|
||||
await self._broadcast_pose_data(pose_stream_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing pose data: {e}")
|
||||
|
||||
async def _broadcast_pose_data(self, pose_data: PoseStreamData):
|
||||
"""Broadcast pose data to matching WebSocket clients."""
|
||||
try:
|
||||
logger.debug(f"📡 Preparing to broadcast pose data for zone {pose_data.zone_id}")
|
||||
|
||||
# Prepare broadcast data
|
||||
broadcast_data = {
|
||||
"type": "pose_data",
|
||||
"timestamp": pose_data.timestamp.isoformat(),
|
||||
"zone_id": pose_data.zone_id,
|
||||
"data": {
|
||||
"pose": pose_data.pose_data,
|
||||
"confidence": pose_data.confidence,
|
||||
"activity": pose_data.activity
|
||||
}
|
||||
}
|
||||
|
||||
# Add metadata if enabled
|
||||
if pose_data.metadata and self.stream_config["include_metadata"]:
|
||||
broadcast_data["metadata"] = pose_data.metadata
|
||||
|
||||
logger.debug(f"📤 Broadcasting data: {broadcast_data}")
|
||||
|
||||
# Broadcast to pose stream subscribers
|
||||
sent_count = await self.connection_manager.broadcast(
|
||||
data=broadcast_data,
|
||||
stream_type="pose",
|
||||
zone_ids=[pose_data.zone_id]
|
||||
)
|
||||
|
||||
logger.info(f"✅ Broadcasted pose data for zone {pose_data.zone_id} to {sent_count} clients")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting pose data: {e}")
|
||||
|
||||
async def handle_client_subscription(
|
||||
self,
|
||||
client_id: str,
|
||||
subscription_config: Dict[str, Any]
|
||||
):
|
||||
"""Handle client subscription configuration."""
|
||||
try:
|
||||
# Store client subscription config
|
||||
self.subscribers[client_id] = {
|
||||
"zone_ids": subscription_config.get("zone_ids", []),
|
||||
"min_confidence": subscription_config.get("min_confidence", 0.5),
|
||||
"max_fps": subscription_config.get("max_fps", 30),
|
||||
"include_metadata": subscription_config.get("include_metadata", True),
|
||||
"stream_types": subscription_config.get("stream_types", ["pose_data"]),
|
||||
"subscribed_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
logger.info(f"Updated subscription for client {client_id}")
|
||||
|
||||
# Send confirmation
|
||||
confirmation = {
|
||||
"type": "subscription_updated",
|
||||
"client_id": client_id,
|
||||
"config": self.subscribers[client_id],
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, confirmation)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling client subscription: {e}")
|
||||
|
||||
async def handle_client_disconnect(self, client_id: str):
|
||||
"""Handle client disconnection."""
|
||||
if client_id in self.subscribers:
|
||||
del self.subscribers[client_id]
|
||||
logger.info(f"Removed subscription for disconnected client {client_id}")
|
||||
|
||||
async def send_historical_data(
|
||||
self,
|
||||
client_id: str,
|
||||
zone_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
limit: int = 100
|
||||
):
|
||||
"""Send historical pose data to client."""
|
||||
try:
|
||||
# Get historical data from pose service
|
||||
historical_data = await self.pose_service.get_historical_data(
|
||||
zone_id=zone_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Send data in chunks to avoid overwhelming the client
|
||||
chunk_size = 10
|
||||
for i in range(0, len(historical_data), chunk_size):
|
||||
chunk = historical_data[i:i + chunk_size]
|
||||
|
||||
message = {
|
||||
"type": "historical_data",
|
||||
"zone_id": zone_id,
|
||||
"chunk_index": i // chunk_size,
|
||||
"total_chunks": (len(historical_data) + chunk_size - 1) // chunk_size,
|
||||
"data": chunk,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, message)
|
||||
|
||||
# Small delay between chunks
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Send completion message
|
||||
completion_message = {
|
||||
"type": "historical_data_complete",
|
||||
"zone_id": zone_id,
|
||||
"total_records": len(historical_data),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, completion_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending historical data: {e}")
|
||||
|
||||
# Send error message to client
|
||||
error_message = {
|
||||
"type": "error",
|
||||
"message": f"Failed to retrieve historical data: {str(e)}",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, error_message)
|
||||
|
||||
async def send_zone_statistics(self, client_id: str, zone_id: str):
|
||||
"""Send zone statistics to client."""
|
||||
try:
|
||||
# Get zone statistics
|
||||
stats = await self.pose_service.get_zone_statistics(zone_id)
|
||||
|
||||
message = {
|
||||
"type": "zone_statistics",
|
||||
"zone_id": zone_id,
|
||||
"statistics": stats,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending zone statistics: {e}")
|
||||
|
||||
async def broadcast_system_event(self, event_type: str, event_data: Dict[str, Any]):
|
||||
"""Broadcast system events to all connected clients."""
|
||||
try:
|
||||
message = {
|
||||
"type": "system_event",
|
||||
"event_type": event_type,
|
||||
"data": event_data,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Broadcast to all pose stream clients
|
||||
sent_count = await self.connection_manager.broadcast(
|
||||
data=message,
|
||||
stream_type="pose"
|
||||
)
|
||||
|
||||
logger.info(f"Broadcasted system event '{event_type}' to {sent_count} clients")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting system event: {e}")
|
||||
|
||||
async def update_stream_config(self, config: Dict[str, Any]):
|
||||
"""Update streaming configuration."""
|
||||
try:
|
||||
# Validate and update configuration
|
||||
if "fps" in config:
|
||||
fps = max(1, min(60, config["fps"]))
|
||||
self.stream_config["fps"] = fps
|
||||
|
||||
if "min_confidence" in config:
|
||||
confidence = max(0.0, min(1.0, config["min_confidence"]))
|
||||
self.stream_config["min_confidence"] = confidence
|
||||
|
||||
if "include_metadata" in config:
|
||||
self.stream_config["include_metadata"] = bool(config["include_metadata"])
|
||||
|
||||
if "buffer_size" in config:
|
||||
buffer_size = max(10, min(1000, config["buffer_size"]))
|
||||
self.stream_config["buffer_size"] = buffer_size
|
||||
|
||||
logger.info(f"Updated stream configuration: {self.stream_config}")
|
||||
|
||||
# Broadcast configuration update to clients
|
||||
await self.broadcast_system_event("stream_config_updated", {
|
||||
"new_config": self.stream_config
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating stream configuration: {e}")
|
||||
|
||||
def get_stream_status(self) -> Dict[str, Any]:
|
||||
"""Get current streaming status."""
|
||||
return {
|
||||
"is_streaming": self.is_streaming,
|
||||
"config": self.stream_config,
|
||||
"subscriber_count": len(self.subscribers),
|
||||
"subscribers": {
|
||||
client_id: {
|
||||
"zone_ids": sub["zone_ids"],
|
||||
"min_confidence": sub["min_confidence"],
|
||||
"subscribed_at": sub["subscribed_at"].isoformat()
|
||||
}
|
||||
for client_id, sub in self.subscribers.items()
|
||||
}
|
||||
}
|
||||
|
||||
async def get_performance_metrics(self) -> Dict[str, Any]:
|
||||
"""Get streaming performance metrics."""
|
||||
try:
|
||||
# Get connection manager metrics
|
||||
conn_metrics = await self.connection_manager.get_metrics()
|
||||
|
||||
# Get pose service metrics
|
||||
pose_metrics = await self.pose_service.get_performance_metrics()
|
||||
|
||||
return {
|
||||
"streaming": {
|
||||
"is_active": self.is_streaming,
|
||||
"fps": self.stream_config["fps"],
|
||||
"subscriber_count": len(self.subscribers)
|
||||
},
|
||||
"connections": conn_metrics,
|
||||
"pose_service": pose_metrics,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting performance metrics: {e}")
|
||||
return {}
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown pose stream handler."""
|
||||
await self.stop_streaming()
|
||||
self.subscribers.clear()
|
||||
logger.info("Pose stream handler shutdown complete")
|
||||
328
v1/src/app.py
Normal file
328
v1/src/app.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
FastAPI application factory and configuration
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.middleware.auth import AuthenticationMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlingMiddleware
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting WiFi-DensePose API...")
|
||||
|
||||
try:
|
||||
# Get orchestrator from app state
|
||||
orchestrator: ServiceOrchestrator = app.state.orchestrator
|
||||
|
||||
# Start connection manager
|
||||
await connection_manager.start()
|
||||
|
||||
# Start all services
|
||||
await orchestrator.start()
|
||||
|
||||
logger.info("WiFi-DensePose API started successfully")
|
||||
|
||||
yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start application: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down WiFi-DensePose API...")
|
||||
|
||||
# Shutdown connection manager
|
||||
await connection_manager.shutdown()
|
||||
|
||||
if hasattr(app.state, 'orchestrator'):
|
||||
await app.state.orchestrator.shutdown()
|
||||
logger.info("WiFi-DensePose API shutdown complete")
|
||||
|
||||
|
||||
def create_app(settings: Settings, orchestrator: ServiceOrchestrator) -> FastAPI:
|
||||
"""Create and configure FastAPI application."""
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.version,
|
||||
description="WiFi-based human pose estimation and activity recognition API",
|
||||
docs_url=settings.docs_url if not settings.is_production else None,
|
||||
redoc_url=settings.redoc_url if not settings.is_production else None,
|
||||
openapi_url=settings.openapi_url if not settings.is_production else None,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Store orchestrator in app state
|
||||
app.state.orchestrator = orchestrator
|
||||
app.state.settings = settings
|
||||
|
||||
# Add middleware in reverse order (last added = first executed)
|
||||
setup_middleware(app, settings)
|
||||
|
||||
# Add exception handlers
|
||||
setup_exception_handlers(app)
|
||||
|
||||
# Include routers
|
||||
setup_routers(app, settings)
|
||||
|
||||
# Add root endpoints
|
||||
setup_root_endpoints(app, settings)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def setup_middleware(app: FastAPI, settings: Settings):
|
||||
"""Setup application middleware."""
|
||||
|
||||
# Rate limiting middleware
|
||||
if settings.enable_rate_limiting:
|
||||
app.add_middleware(RateLimitMiddleware, settings=settings)
|
||||
|
||||
# Authentication middleware
|
||||
if settings.enable_authentication:
|
||||
app.add_middleware(AuthenticationMiddleware, settings=settings)
|
||||
|
||||
# CORS middleware
|
||||
if settings.cors_enabled:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=settings.cors_allow_credentials,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Trusted host middleware for production
|
||||
if settings.is_production:
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=settings.allowed_hosts
|
||||
)
|
||||
|
||||
|
||||
def setup_exception_handlers(app: FastAPI):
|
||||
"""Setup global exception handlers."""
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
"""Handle HTTP exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": exc.status_code,
|
||||
"message": exc.detail,
|
||||
"type": "http_error",
|
||||
"path": str(request.url.path)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle request validation errors."""
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"error": {
|
||||
"code": 422,
|
||||
"message": "Validation error",
|
||||
"type": "validation_error",
|
||||
"path": str(request.url.path),
|
||||
"details": exc.errors()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle general exceptions."""
|
||||
logger.error(f"Unhandled exception on {request.url.path}: {exc}", exc_info=True)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": 500,
|
||||
"message": "Internal server error",
|
||||
"type": "internal_error",
|
||||
"path": str(request.url.path)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def setup_routers(app: FastAPI, settings: Settings):
|
||||
"""Setup API routers."""
|
||||
|
||||
# Health check router (no prefix)
|
||||
app.include_router(
|
||||
health.router,
|
||||
prefix="/health",
|
||||
tags=["Health"]
|
||||
)
|
||||
|
||||
# API routers with prefix
|
||||
app.include_router(
|
||||
pose.router,
|
||||
prefix=f"{settings.api_prefix}/pose",
|
||||
tags=["Pose Estimation"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
stream.router,
|
||||
prefix=f"{settings.api_prefix}/stream",
|
||||
tags=["Streaming"]
|
||||
)
|
||||
|
||||
|
||||
def setup_root_endpoints(app: FastAPI, settings: Settings):
|
||||
"""Setup root application endpoints."""
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"docs_url": settings.docs_url,
|
||||
"api_prefix": settings.api_prefix,
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"real_time_processing": settings.enable_real_time_processing
|
||||
}
|
||||
}
|
||||
|
||||
@app.get(f"{settings.api_prefix}/info")
|
||||
async def api_info(request: Request):
|
||||
"""Get detailed API information."""
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
return {
|
||||
"api": {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"prefix": settings.api_prefix
|
||||
},
|
||||
"services": await orchestrator.get_service_info(),
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"real_time_processing": settings.enable_real_time_processing,
|
||||
"historical_data": settings.enable_historical_data
|
||||
},
|
||||
"limits": {
|
||||
"rate_limit_requests": settings.rate_limit_requests,
|
||||
"rate_limit_window": settings.rate_limit_window
|
||||
}
|
||||
}
|
||||
|
||||
@app.get(f"{settings.api_prefix}/status")
|
||||
async def api_status(request: Request):
|
||||
"""Get current API status."""
|
||||
try:
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
status = {
|
||||
"api": {
|
||||
"status": "healthy",
|
||||
"version": settings.version,
|
||||
"environment": settings.environment
|
||||
},
|
||||
"services": await orchestrator.get_service_status(),
|
||||
"connections": await connection_manager.get_connection_stats()
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API status: {e}")
|
||||
return {
|
||||
"api": {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
# Metrics endpoint (if enabled)
|
||||
if settings.metrics_enabled:
|
||||
@app.get(f"{settings.api_prefix}/metrics")
|
||||
async def api_metrics(request: Request):
|
||||
"""Get API metrics."""
|
||||
try:
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
|
||||
metrics = {
|
||||
"connections": await connection_manager.get_metrics(),
|
||||
"services": await orchestrator.get_service_metrics()
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting metrics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
# Development endpoints (only in development)
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"environment_variables": dict(os.environ)
|
||||
}
|
||||
|
||||
@app.post(f"{settings.api_prefix}/dev/reset")
|
||||
async def dev_reset(request: Request):
|
||||
"""Reset services (development only)."""
|
||||
try:
|
||||
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
|
||||
await orchestrator.reset_services()
|
||||
return {"message": "Services reset successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting services: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Create default app instance for uvicorn
|
||||
def get_app() -> FastAPI:
|
||||
"""Get the default application instance."""
|
||||
from src.config.settings import get_settings
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
|
||||
settings = get_settings()
|
||||
orchestrator = ServiceOrchestrator(settings)
|
||||
return create_app(settings, orchestrator)
|
||||
|
||||
|
||||
# Default app instance for uvicorn
|
||||
app = get_app()
|
||||
621
v1/src/cli.py
Normal file
621
v1/src/cli.py
Normal file
@@ -0,0 +1,621 @@
|
||||
"""
|
||||
Command-line interface for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import click
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config.settings import get_settings, load_settings_from_file
|
||||
from src.logger import setup_logging, get_logger
|
||||
from src.commands.start import start_command
|
||||
from src.commands.stop import stop_command
|
||||
from src.commands.status import status_command
|
||||
|
||||
# Get default settings and setup logging for CLI
|
||||
settings = get_settings()
|
||||
setup_logging(settings)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_settings_with_config(config_file: Optional[str] = None):
|
||||
"""Get settings with optional config file."""
|
||||
if config_file:
|
||||
return load_settings_from_file(config_file)
|
||||
else:
|
||||
return get_settings()
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
'--config',
|
||||
'-c',
|
||||
type=click.Path(exists=True),
|
||||
help='Path to configuration file'
|
||||
)
|
||||
@click.option(
|
||||
'--verbose',
|
||||
'-v',
|
||||
is_flag=True,
|
||||
help='Enable verbose logging'
|
||||
)
|
||||
@click.option(
|
||||
'--debug',
|
||||
is_flag=True,
|
||||
help='Enable debug mode'
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx, config: Optional[str], verbose: bool, debug: bool):
|
||||
"""WiFi-DensePose API Command Line Interface."""
|
||||
|
||||
# Ensure context object exists
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
# Store CLI options in context
|
||||
ctx.obj['config_file'] = config
|
||||
ctx.obj['verbose'] = verbose
|
||||
ctx.obj['debug'] = debug
|
||||
|
||||
# Setup logging level
|
||||
if debug:
|
||||
import logging
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
logger.debug("Debug mode enabled")
|
||||
elif verbose:
|
||||
import logging
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
logger.info("Verbose mode enabled")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
'--host',
|
||||
default='0.0.0.0',
|
||||
help='Host to bind to (default: 0.0.0.0)'
|
||||
)
|
||||
@click.option(
|
||||
'--port',
|
||||
default=8000,
|
||||
type=int,
|
||||
help='Port to bind to (default: 8000)'
|
||||
)
|
||||
@click.option(
|
||||
'--workers',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Number of worker processes (default: 1)'
|
||||
)
|
||||
@click.option(
|
||||
'--reload',
|
||||
is_flag=True,
|
||||
help='Enable auto-reload for development'
|
||||
)
|
||||
@click.option(
|
||||
'--daemon',
|
||||
'-d',
|
||||
is_flag=True,
|
||||
help='Run as daemon (background process)'
|
||||
)
|
||||
@click.pass_context
|
||||
def start(ctx, host: str, port: int, workers: int, reload: bool, daemon: bool):
|
||||
"""Start the WiFi-DensePose API server."""
|
||||
|
||||
try:
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
# Override settings with CLI options
|
||||
if ctx.obj.get('debug'):
|
||||
settings.debug = True
|
||||
|
||||
# Run start command
|
||||
asyncio.run(start_command(
|
||||
settings=settings,
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
reload=reload,
|
||||
daemon=daemon
|
||||
))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt signal, shutting down...")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start server: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
'--force',
|
||||
'-f',
|
||||
is_flag=True,
|
||||
help='Force stop without graceful shutdown'
|
||||
)
|
||||
@click.option(
|
||||
'--timeout',
|
||||
default=30,
|
||||
type=int,
|
||||
help='Timeout for graceful shutdown (default: 30 seconds)'
|
||||
)
|
||||
@click.pass_context
|
||||
def stop(ctx, force: bool, timeout: int):
|
||||
"""Stop the WiFi-DensePose API server."""
|
||||
|
||||
try:
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
# Run stop command
|
||||
asyncio.run(stop_command(
|
||||
settings=settings,
|
||||
force=force,
|
||||
timeout=timeout
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop server: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
'--format',
|
||||
type=click.Choice(['text', 'json']),
|
||||
default='text',
|
||||
help='Output format (default: text)'
|
||||
)
|
||||
@click.option(
|
||||
'--detailed',
|
||||
is_flag=True,
|
||||
help='Show detailed status information'
|
||||
)
|
||||
@click.pass_context
|
||||
def status(ctx, format: str, detailed: bool):
|
||||
"""Show the status of the WiFi-DensePose API server."""
|
||||
|
||||
try:
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
# Run status command
|
||||
asyncio.run(status_command(
|
||||
settings=settings,
|
||||
output_format=format,
|
||||
detailed=detailed
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def db():
|
||||
"""Database management commands."""
|
||||
pass
|
||||
|
||||
|
||||
@db.command()
|
||||
@click.option(
|
||||
'--url',
|
||||
help='Database URL (overrides config)'
|
||||
)
|
||||
@click.pass_context
|
||||
def init(ctx, url: Optional[str]):
|
||||
"""Initialize the database schema."""
|
||||
|
||||
try:
|
||||
from src.database.connection import get_database_manager
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
import os
|
||||
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
if url:
|
||||
settings.database_url = url
|
||||
|
||||
# Initialize database
|
||||
db_manager = get_database_manager(settings)
|
||||
|
||||
async def init_db():
|
||||
await db_manager.initialize()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
asyncio.run(init_db())
|
||||
|
||||
# Run migrations if alembic.ini exists
|
||||
alembic_ini_path = "alembic.ini"
|
||||
if os.path.exists(alembic_ini_path):
|
||||
try:
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
# Set the database URL in the config
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", settings.get_database_url())
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
logger.info("Database migrations applied successfully")
|
||||
except Exception as migration_error:
|
||||
logger.warning(f"Migration failed, but database is initialized: {migration_error}")
|
||||
else:
|
||||
logger.info("No alembic.ini found, skipping migrations")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@db.command()
|
||||
@click.option(
|
||||
'--revision',
|
||||
default='head',
|
||||
help='Target revision (default: head)'
|
||||
)
|
||||
@click.pass_context
|
||||
def migrate(ctx, revision: str):
|
||||
"""Run database migrations."""
|
||||
|
||||
try:
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
|
||||
# Run migrations
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
command.upgrade(alembic_cfg, revision)
|
||||
logger.info(f"Database migrated to revision: {revision}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run migrations: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@db.command()
|
||||
@click.option(
|
||||
'--steps',
|
||||
default=1,
|
||||
type=int,
|
||||
help='Number of steps to rollback (default: 1)'
|
||||
)
|
||||
@click.pass_context
|
||||
def rollback(ctx, steps: int):
|
||||
"""Rollback database migrations."""
|
||||
|
||||
try:
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
|
||||
# Rollback migrations
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
command.downgrade(alembic_cfg, f"-{steps}")
|
||||
logger.info(f"Database rolled back {steps} step(s)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rollback database: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def tasks():
|
||||
"""Background task management commands."""
|
||||
pass
|
||||
|
||||
|
||||
@tasks.command()
|
||||
@click.option(
|
||||
'--task',
|
||||
type=click.Choice(['cleanup', 'monitoring', 'backup']),
|
||||
help='Specific task to run'
|
||||
)
|
||||
@click.pass_context
|
||||
def run(ctx, task: Optional[str]):
|
||||
"""Run background tasks."""
|
||||
|
||||
try:
|
||||
from src.tasks.cleanup import get_cleanup_manager
|
||||
from src.tasks.monitoring import get_monitoring_manager
|
||||
from src.tasks.backup import get_backup_manager
|
||||
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
async def run_tasks():
|
||||
if task == 'cleanup' or task is None:
|
||||
cleanup_manager = get_cleanup_manager(settings)
|
||||
result = await cleanup_manager.run_all_tasks()
|
||||
logger.info(f"Cleanup result: {result}")
|
||||
|
||||
if task == 'monitoring' or task is None:
|
||||
monitoring_manager = get_monitoring_manager(settings)
|
||||
result = await monitoring_manager.run_all_tasks()
|
||||
logger.info(f"Monitoring result: {result}")
|
||||
|
||||
if task == 'backup' or task is None:
|
||||
backup_manager = get_backup_manager(settings)
|
||||
result = await backup_manager.run_all_tasks()
|
||||
logger.info(f"Backup result: {result}")
|
||||
|
||||
asyncio.run(run_tasks())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run tasks: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@tasks.command()
|
||||
@click.pass_context
|
||||
def status(ctx):
|
||||
"""Show background task status."""
|
||||
|
||||
try:
|
||||
from src.tasks.cleanup import get_cleanup_manager
|
||||
from src.tasks.monitoring import get_monitoring_manager
|
||||
from src.tasks.backup import get_backup_manager
|
||||
import json
|
||||
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
# Get task managers
|
||||
cleanup_manager = get_cleanup_manager(settings)
|
||||
monitoring_manager = get_monitoring_manager(settings)
|
||||
backup_manager = get_backup_manager(settings)
|
||||
|
||||
# Collect status
|
||||
status_data = {
|
||||
"cleanup": cleanup_manager.get_stats(),
|
||||
"monitoring": monitoring_manager.get_stats(),
|
||||
"backup": backup_manager.get_stats(),
|
||||
}
|
||||
|
||||
# Print status
|
||||
click.echo(json.dumps(status_data, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get task status: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def config():
|
||||
"""Configuration management commands."""
|
||||
pass
|
||||
|
||||
|
||||
@config.command()
|
||||
@click.pass_context
|
||||
def show(ctx):
|
||||
"""Show current configuration."""
|
||||
|
||||
try:
|
||||
import json
|
||||
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
# Convert settings to dict (excluding sensitive data)
|
||||
config_dict = {
|
||||
"app_name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"debug": settings.debug,
|
||||
"host": settings.host,
|
||||
"port": settings.port,
|
||||
"api_prefix": settings.api_prefix,
|
||||
"docs_url": settings.docs_url,
|
||||
"redoc_url": settings.redoc_url,
|
||||
"log_level": settings.log_level,
|
||||
"log_file": settings.log_file,
|
||||
"data_storage_path": settings.data_storage_path,
|
||||
"model_storage_path": settings.model_storage_path,
|
||||
"temp_storage_path": settings.temp_storage_path,
|
||||
"wifi_interface": settings.wifi_interface,
|
||||
"csi_buffer_size": settings.csi_buffer_size,
|
||||
"pose_confidence_threshold": settings.pose_confidence_threshold,
|
||||
"stream_fps": settings.stream_fps,
|
||||
"websocket_ping_interval": settings.websocket_ping_interval,
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"historical_data": settings.enable_historical_data,
|
||||
"real_time_processing": settings.enable_real_time_processing,
|
||||
"cors": settings.cors_enabled,
|
||||
}
|
||||
}
|
||||
|
||||
click.echo(json.dumps(config_dict, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to show configuration: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@config.command()
|
||||
@click.pass_context
|
||||
def validate(ctx):
|
||||
"""Validate configuration."""
|
||||
|
||||
try:
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
# Validate database connection
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
async def validate_config():
|
||||
db_manager = get_database_manager(settings)
|
||||
|
||||
try:
|
||||
await db_manager.test_connection()
|
||||
click.echo("✓ Database connection: OK")
|
||||
except Exception as e:
|
||||
click.echo(f"✗ Database connection: FAILED - {e}")
|
||||
return False
|
||||
|
||||
# Validate Redis connection (if configured)
|
||||
redis_url = settings.get_redis_url()
|
||||
if redis_url:
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
redis_client = redis.from_url(redis_url)
|
||||
await redis_client.ping()
|
||||
click.echo("✓ Redis connection: OK")
|
||||
await redis_client.close()
|
||||
except Exception as e:
|
||||
click.echo(f"✗ Redis connection: FAILED - {e}")
|
||||
return False
|
||||
else:
|
||||
click.echo("- Redis connection: NOT CONFIGURED")
|
||||
|
||||
# Validate directories
|
||||
from pathlib import Path
|
||||
|
||||
directories = [
|
||||
("Data storage", settings.data_storage_path),
|
||||
("Model storage", settings.model_storage_path),
|
||||
("Temp storage", settings.temp_storage_path),
|
||||
]
|
||||
|
||||
for name, directory in directories:
|
||||
path = Path(directory)
|
||||
if path.exists() and path.is_dir():
|
||||
click.echo(f"✓ {name}: OK")
|
||||
else:
|
||||
try:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"✓ {name}: CREATED - {directory}")
|
||||
except Exception as e:
|
||||
click.echo(f"✗ {name}: FAILED TO CREATE - {directory} ({e})")
|
||||
return False
|
||||
|
||||
click.echo("\n✓ Configuration validation passed")
|
||||
return True
|
||||
|
||||
result = asyncio.run(validate_config())
|
||||
if not result:
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate configuration: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@config.command()
|
||||
@click.option(
|
||||
'--format',
|
||||
type=click.Choice(['text', 'json']),
|
||||
default='text',
|
||||
help='Output format (default: text)'
|
||||
)
|
||||
@click.pass_context
|
||||
def failsafe(ctx, format: str):
|
||||
"""Show failsafe status and configuration."""
|
||||
|
||||
try:
|
||||
import json
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
# Get settings
|
||||
settings = get_settings_with_config(ctx.obj.get('config_file'))
|
||||
|
||||
async def check_failsafe_status():
|
||||
db_manager = get_database_manager(settings)
|
||||
|
||||
# Initialize database to check current state
|
||||
try:
|
||||
await db_manager.initialize()
|
||||
except Exception as e:
|
||||
logger.warning(f"Database initialization failed: {e}")
|
||||
|
||||
# Collect failsafe status
|
||||
failsafe_status = {
|
||||
"database": {
|
||||
"failsafe_enabled": settings.enable_database_failsafe,
|
||||
"using_sqlite_fallback": db_manager.is_using_sqlite_fallback(),
|
||||
"sqlite_fallback_path": settings.sqlite_fallback_path,
|
||||
"primary_database_url": settings.get_database_url() if not db_manager.is_using_sqlite_fallback() else None,
|
||||
},
|
||||
"redis": {
|
||||
"failsafe_enabled": settings.enable_redis_failsafe,
|
||||
"redis_enabled": settings.redis_enabled,
|
||||
"redis_required": settings.redis_required,
|
||||
"redis_available": db_manager.is_redis_available(),
|
||||
"redis_url": settings.get_redis_url() if settings.redis_enabled else None,
|
||||
},
|
||||
"overall_status": "healthy"
|
||||
}
|
||||
|
||||
# Determine overall status
|
||||
if failsafe_status["database"]["using_sqlite_fallback"] or not failsafe_status["redis"]["redis_available"]:
|
||||
failsafe_status["overall_status"] = "degraded"
|
||||
|
||||
# Output results
|
||||
if format == 'json':
|
||||
click.echo(json.dumps(failsafe_status, indent=2))
|
||||
else:
|
||||
click.echo("=== Failsafe Status ===\n")
|
||||
|
||||
# Database status
|
||||
click.echo("Database:")
|
||||
if failsafe_status["database"]["using_sqlite_fallback"]:
|
||||
click.echo(" ⚠️ Using SQLite fallback database")
|
||||
click.echo(f" Path: {failsafe_status['database']['sqlite_fallback_path']}")
|
||||
else:
|
||||
click.echo(" ✓ Using primary database (PostgreSQL)")
|
||||
|
||||
click.echo(f" Failsafe enabled: {'Yes' if failsafe_status['database']['failsafe_enabled'] else 'No'}")
|
||||
|
||||
# Redis status
|
||||
click.echo("\nRedis:")
|
||||
if not failsafe_status["redis"]["redis_enabled"]:
|
||||
click.echo(" - Redis disabled")
|
||||
elif not failsafe_status["redis"]["redis_available"]:
|
||||
click.echo(" ⚠️ Redis unavailable (failsafe active)")
|
||||
else:
|
||||
click.echo(" ✓ Redis available")
|
||||
|
||||
click.echo(f" Failsafe enabled: {'Yes' if failsafe_status['redis']['failsafe_enabled'] else 'No'}")
|
||||
click.echo(f" Required: {'Yes' if failsafe_status['redis']['redis_required'] else 'No'}")
|
||||
|
||||
# Overall status
|
||||
status_icon = "✓" if failsafe_status["overall_status"] == "healthy" else "⚠️"
|
||||
click.echo(f"\nOverall Status: {status_icon} {failsafe_status['overall_status'].upper()}")
|
||||
|
||||
if failsafe_status["overall_status"] == "degraded":
|
||||
click.echo("\nNote: System is running in degraded mode using failsafe configurations.")
|
||||
|
||||
asyncio.run(check_failsafe_status())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check failsafe status: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def version():
|
||||
"""Show version information."""
|
||||
|
||||
try:
|
||||
from src.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
click.echo(f"WiFi-DensePose API v{settings.version}")
|
||||
click.echo(f"Environment: {settings.environment}")
|
||||
click.echo(f"Python: {sys.version}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get version: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def create_cli(orchestrator=None):
|
||||
"""Create CLI interface for the application."""
|
||||
return cli
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
359
v1/src/commands/start.py
Normal file
359
v1/src/commands/start.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Start command implementation for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import uvicorn
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def start_command(
|
||||
settings: Settings,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
workers: int = 1,
|
||||
reload: bool = False,
|
||||
daemon: bool = False
|
||||
) -> None:
|
||||
"""Start the WiFi-DensePose API server."""
|
||||
|
||||
logger.info(f"Starting WiFi-DensePose API server...")
|
||||
logger.info(f"Environment: {settings.environment}")
|
||||
logger.info(f"Debug mode: {settings.debug}")
|
||||
logger.info(f"Host: {host}")
|
||||
logger.info(f"Port: {port}")
|
||||
logger.info(f"Workers: {workers}")
|
||||
|
||||
# Validate settings
|
||||
await _validate_startup_requirements(settings)
|
||||
|
||||
# Setup signal handlers
|
||||
_setup_signal_handlers()
|
||||
|
||||
# Create PID file if running as daemon
|
||||
pid_file = None
|
||||
if daemon:
|
||||
pid_file = _create_pid_file(settings)
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
await _initialize_database(settings)
|
||||
|
||||
# Start background tasks
|
||||
background_tasks = await _start_background_tasks(settings)
|
||||
|
||||
# Configure uvicorn
|
||||
uvicorn_config = {
|
||||
"app": "src.app:app",
|
||||
"host": host,
|
||||
"port": port,
|
||||
"reload": reload,
|
||||
"workers": workers if not reload else 1, # Reload doesn't work with multiple workers
|
||||
"log_level": "debug" if settings.debug else "info",
|
||||
"access_log": True,
|
||||
"use_colors": not daemon,
|
||||
}
|
||||
|
||||
if daemon:
|
||||
# Run as daemon
|
||||
await _run_as_daemon(uvicorn_config, pid_file)
|
||||
else:
|
||||
# Run in foreground
|
||||
await _run_server(uvicorn_config)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt signal, shutting down...")
|
||||
except Exception as e:
|
||||
logger.error(f"Server startup failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup
|
||||
if pid_file and pid_file.exists():
|
||||
pid_file.unlink()
|
||||
|
||||
# Stop background tasks
|
||||
if 'background_tasks' in locals():
|
||||
await _stop_background_tasks(background_tasks)
|
||||
|
||||
|
||||
async def _validate_startup_requirements(settings: Settings) -> None:
|
||||
"""Validate that all startup requirements are met."""
|
||||
|
||||
logger.info("Validating startup requirements...")
|
||||
|
||||
# Check database connection
|
||||
try:
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager(settings)
|
||||
await db_manager.test_connection()
|
||||
logger.info("✓ Database connection validated")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"✗ Database connection failed: {e}")
|
||||
raise
|
||||
|
||||
# Check Redis connection (if enabled)
|
||||
if settings.redis_enabled:
|
||||
try:
|
||||
redis_stats = await db_manager.get_connection_stats()
|
||||
if "redis" in redis_stats and not redis_stats["redis"].get("error"):
|
||||
logger.info("✓ Redis connection validated")
|
||||
else:
|
||||
logger.warning("⚠ Redis connection failed, continuing without Redis")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠ Redis connection failed: {e}, continuing without Redis")
|
||||
|
||||
# Check required directories
|
||||
directories = [
|
||||
("Log directory", settings.log_directory),
|
||||
("Backup directory", settings.backup_directory),
|
||||
]
|
||||
|
||||
for name, directory in directories:
|
||||
path = Path(directory)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"✓ {name} ready: {directory}")
|
||||
|
||||
logger.info("All startup requirements validated")
|
||||
|
||||
|
||||
async def _initialize_database(settings: Settings) -> None:
|
||||
"""Initialize database connection and run migrations if needed."""
|
||||
|
||||
logger.info("Initializing database...")
|
||||
|
||||
try:
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager(settings)
|
||||
await db_manager.initialize()
|
||||
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database initialization failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _start_background_tasks(settings: Settings) -> dict:
|
||||
"""Start background tasks."""
|
||||
|
||||
logger.info("Starting background tasks...")
|
||||
|
||||
tasks = {}
|
||||
|
||||
try:
|
||||
# Start cleanup task
|
||||
if settings.cleanup_interval_seconds > 0:
|
||||
from src.tasks.cleanup import run_periodic_cleanup
|
||||
|
||||
cleanup_task = asyncio.create_task(run_periodic_cleanup(settings))
|
||||
tasks['cleanup'] = cleanup_task
|
||||
logger.info("✓ Cleanup task started")
|
||||
|
||||
# Start monitoring task
|
||||
if settings.monitoring_interval_seconds > 0:
|
||||
from src.tasks.monitoring import run_periodic_monitoring
|
||||
|
||||
monitoring_task = asyncio.create_task(run_periodic_monitoring(settings))
|
||||
tasks['monitoring'] = monitoring_task
|
||||
logger.info("✓ Monitoring task started")
|
||||
|
||||
# Start backup task
|
||||
if settings.backup_interval_seconds > 0:
|
||||
from src.tasks.backup import run_periodic_backup
|
||||
|
||||
backup_task = asyncio.create_task(run_periodic_backup(settings))
|
||||
tasks['backup'] = backup_task
|
||||
logger.info("✓ Backup task started")
|
||||
|
||||
logger.info(f"Started {len(tasks)} background tasks")
|
||||
return tasks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start background tasks: {e}")
|
||||
# Cancel any started tasks
|
||||
for task in tasks.values():
|
||||
task.cancel()
|
||||
raise
|
||||
|
||||
|
||||
async def _stop_background_tasks(tasks: dict) -> None:
|
||||
"""Stop background tasks gracefully."""
|
||||
|
||||
logger.info("Stopping background tasks...")
|
||||
|
||||
# Cancel all tasks
|
||||
for name, task in tasks.items():
|
||||
if not task.done():
|
||||
logger.info(f"Stopping {name} task...")
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to complete
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks.values(), return_exceptions=True)
|
||||
|
||||
logger.info("Background tasks stopped")
|
||||
|
||||
|
||||
def _setup_signal_handlers() -> None:
|
||||
"""Setup signal handlers for graceful shutdown."""
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||
# The actual shutdown will be handled by the main loop
|
||||
sys.exit(0)
|
||||
|
||||
# Setup signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
if hasattr(signal, 'SIGHUP'):
|
||||
signal.signal(signal.SIGHUP, signal_handler)
|
||||
|
||||
|
||||
def _create_pid_file(settings: Settings) -> Path:
|
||||
"""Create PID file for daemon mode."""
|
||||
|
||||
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
|
||||
|
||||
# Check if PID file already exists
|
||||
if pid_file.exists():
|
||||
try:
|
||||
with open(pid_file, 'r') as f:
|
||||
old_pid = int(f.read().strip())
|
||||
|
||||
# Check if process is still running
|
||||
try:
|
||||
os.kill(old_pid, 0) # Signal 0 just checks if process exists
|
||||
logger.error(f"Server already running with PID {old_pid}")
|
||||
sys.exit(1)
|
||||
except OSError:
|
||||
# Process doesn't exist, remove stale PID file
|
||||
pid_file.unlink()
|
||||
logger.info("Removed stale PID file")
|
||||
|
||||
except (ValueError, IOError):
|
||||
# Invalid PID file, remove it
|
||||
pid_file.unlink()
|
||||
logger.info("Removed invalid PID file")
|
||||
|
||||
# Write current PID
|
||||
with open(pid_file, 'w') as f:
|
||||
f.write(str(os.getpid()))
|
||||
|
||||
logger.info(f"Created PID file: {pid_file}")
|
||||
return pid_file
|
||||
|
||||
|
||||
async def _run_server(config: dict) -> None:
|
||||
"""Run the server in foreground mode."""
|
||||
|
||||
logger.info("Starting server in foreground mode...")
|
||||
|
||||
# Create uvicorn server
|
||||
server = uvicorn.Server(uvicorn.Config(**config))
|
||||
|
||||
# Run server
|
||||
await server.serve()
|
||||
|
||||
|
||||
async def _run_as_daemon(config: dict, pid_file: Path) -> None:
|
||||
"""Run the server as a daemon."""
|
||||
|
||||
logger.info("Starting server in daemon mode...")
|
||||
|
||||
# Fork process
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
# Parent process
|
||||
logger.info(f"Server started as daemon with PID {pid}")
|
||||
sys.exit(0)
|
||||
except OSError as e:
|
||||
logger.error(f"Fork failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Child process continues
|
||||
|
||||
# Decouple from parent environment
|
||||
os.chdir("/")
|
||||
os.setsid()
|
||||
os.umask(0)
|
||||
|
||||
# Second fork
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
# Exit second parent
|
||||
sys.exit(0)
|
||||
except OSError as e:
|
||||
logger.error(f"Second fork failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Update PID file with daemon PID
|
||||
with open(pid_file, 'w') as f:
|
||||
f.write(str(os.getpid()))
|
||||
|
||||
# Redirect standard file descriptors
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
# Redirect stdin, stdout, stderr to /dev/null
|
||||
with open('/dev/null', 'r') as f:
|
||||
os.dup2(f.fileno(), sys.stdin.fileno())
|
||||
|
||||
with open('/dev/null', 'w') as f:
|
||||
os.dup2(f.fileno(), sys.stdout.fileno())
|
||||
os.dup2(f.fileno(), sys.stderr.fileno())
|
||||
|
||||
# Create uvicorn server
|
||||
server = uvicorn.Server(uvicorn.Config(**config))
|
||||
|
||||
# Run server
|
||||
await server.serve()
|
||||
|
||||
|
||||
def get_server_status(settings: Settings) -> dict:
|
||||
"""Get current server status."""
|
||||
|
||||
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
|
||||
|
||||
status = {
|
||||
"running": False,
|
||||
"pid": None,
|
||||
"pid_file": str(pid_file),
|
||||
"pid_file_exists": pid_file.exists(),
|
||||
}
|
||||
|
||||
if pid_file.exists():
|
||||
try:
|
||||
with open(pid_file, 'r') as f:
|
||||
pid = int(f.read().strip())
|
||||
|
||||
status["pid"] = pid
|
||||
|
||||
# Check if process is running
|
||||
try:
|
||||
os.kill(pid, 0) # Signal 0 just checks if process exists
|
||||
status["running"] = True
|
||||
except OSError:
|
||||
# Process doesn't exist
|
||||
status["running"] = False
|
||||
|
||||
except (ValueError, IOError):
|
||||
# Invalid PID file
|
||||
status["running"] = False
|
||||
|
||||
return status
|
||||
501
v1/src/commands/status.py
Normal file
501
v1/src/commands/status.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""
|
||||
Status command implementation for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def status_command(
|
||||
settings: Settings,
|
||||
output_format: str = "text",
|
||||
detailed: bool = False
|
||||
) -> None:
|
||||
"""Show the status of the WiFi-DensePose API server."""
|
||||
|
||||
logger.debug("Gathering server status information...")
|
||||
|
||||
try:
|
||||
# Collect status information
|
||||
status_data = await _collect_status_data(settings, detailed)
|
||||
|
||||
# Output status
|
||||
if output_format == "json":
|
||||
print(json.dumps(status_data, indent=2, default=str))
|
||||
else:
|
||||
_print_text_status(status_data, detailed)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _collect_status_data(settings: Settings, detailed: bool) -> Dict[str, Any]:
|
||||
"""Collect comprehensive status data."""
|
||||
|
||||
status_data = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"server": await _get_server_status(settings),
|
||||
"system": _get_system_status(),
|
||||
"configuration": _get_configuration_status(settings),
|
||||
}
|
||||
|
||||
if detailed:
|
||||
status_data.update({
|
||||
"database": await _get_database_status(settings),
|
||||
"background_tasks": await _get_background_tasks_status(settings),
|
||||
"resources": _get_resource_usage(),
|
||||
"health": await _get_health_status(settings),
|
||||
})
|
||||
|
||||
return status_data
|
||||
|
||||
|
||||
async def _get_server_status(settings: Settings) -> Dict[str, Any]:
|
||||
"""Get server process status."""
|
||||
|
||||
from src.commands.stop import get_server_status
|
||||
|
||||
status = get_server_status(settings)
|
||||
|
||||
server_info = {
|
||||
"running": status["running"],
|
||||
"pid": status["pid"],
|
||||
"pid_file": status["pid_file"],
|
||||
"pid_file_exists": status["pid_file_exists"],
|
||||
}
|
||||
|
||||
if status["running"] and status["pid"]:
|
||||
try:
|
||||
# Get process information
|
||||
process = psutil.Process(status["pid"])
|
||||
|
||||
server_info.update({
|
||||
"start_time": datetime.fromtimestamp(process.create_time()).isoformat(),
|
||||
"uptime_seconds": time.time() - process.create_time(),
|
||||
"memory_usage_mb": process.memory_info().rss / (1024 * 1024),
|
||||
"cpu_percent": process.cpu_percent(),
|
||||
"status": process.status(),
|
||||
"num_threads": process.num_threads(),
|
||||
"connections": len(process.connections()) if hasattr(process, 'connections') else None,
|
||||
})
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
|
||||
server_info["error"] = f"Cannot access process info: {e}"
|
||||
|
||||
return server_info
|
||||
|
||||
|
||||
def _get_system_status() -> Dict[str, Any]:
|
||||
"""Get system status information."""
|
||||
|
||||
uname_info = psutil.os.uname()
|
||||
return {
|
||||
"hostname": uname_info.nodename,
|
||||
"platform": uname_info.sysname,
|
||||
"architecture": uname_info.machine,
|
||||
"python_version": f"{psutil.sys.version_info.major}.{psutil.sys.version_info.minor}.{psutil.sys.version_info.micro}",
|
||||
"boot_time": datetime.fromtimestamp(psutil.boot_time()).isoformat(),
|
||||
"uptime_seconds": time.time() - psutil.boot_time(),
|
||||
}
|
||||
|
||||
|
||||
def _get_configuration_status(settings: Settings) -> Dict[str, Any]:
|
||||
"""Get configuration status."""
|
||||
|
||||
return {
|
||||
"environment": settings.environment,
|
||||
"debug": settings.debug,
|
||||
"version": settings.version,
|
||||
"host": settings.host,
|
||||
"port": settings.port,
|
||||
"database_configured": bool(settings.database_url or (settings.db_host and settings.db_name)),
|
||||
"redis_enabled": settings.redis_enabled,
|
||||
"monitoring_enabled": settings.monitoring_interval_seconds > 0,
|
||||
"cleanup_enabled": settings.cleanup_interval_seconds > 0,
|
||||
"backup_enabled": settings.backup_interval_seconds > 0,
|
||||
}
|
||||
|
||||
|
||||
async def _get_database_status(settings: Settings) -> Dict[str, Any]:
|
||||
"""Get database status."""
|
||||
|
||||
db_status = {
|
||||
"connected": False,
|
||||
"connection_pool": None,
|
||||
"tables": {},
|
||||
"error": None,
|
||||
}
|
||||
|
||||
try:
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager(settings)
|
||||
|
||||
# Test connection
|
||||
await db_manager.test_connection()
|
||||
db_status["connected"] = True
|
||||
|
||||
# Get connection stats
|
||||
connection_stats = await db_manager.get_connection_stats()
|
||||
db_status["connection_pool"] = connection_stats
|
||||
|
||||
# Get table counts
|
||||
async with db_manager.get_async_session() as session:
|
||||
from sqlalchemy import text, func
|
||||
from src.database.models import Device, Session, CSIData, PoseDetection, SystemMetric, AuditLog
|
||||
|
||||
tables = {
|
||||
"devices": Device,
|
||||
"sessions": Session,
|
||||
"csi_data": CSIData,
|
||||
"pose_detections": PoseDetection,
|
||||
"system_metrics": SystemMetric,
|
||||
"audit_logs": AuditLog,
|
||||
}
|
||||
|
||||
for table_name, model in tables.items():
|
||||
try:
|
||||
result = await session.execute(
|
||||
text(f"SELECT COUNT(*) FROM {table_name}")
|
||||
)
|
||||
count = result.scalar()
|
||||
db_status["tables"][table_name] = {"count": count}
|
||||
except Exception as e:
|
||||
db_status["tables"][table_name] = {"error": str(e)}
|
||||
|
||||
except Exception as e:
|
||||
db_status["error"] = str(e)
|
||||
|
||||
return db_status
|
||||
|
||||
|
||||
async def _get_background_tasks_status(settings: Settings) -> Dict[str, Any]:
|
||||
"""Get background tasks status."""
|
||||
|
||||
tasks_status = {}
|
||||
|
||||
try:
|
||||
# Cleanup tasks
|
||||
from src.tasks.cleanup import get_cleanup_manager
|
||||
cleanup_manager = get_cleanup_manager(settings)
|
||||
tasks_status["cleanup"] = cleanup_manager.get_stats()
|
||||
|
||||
except Exception as e:
|
||||
tasks_status["cleanup"] = {"error": str(e)}
|
||||
|
||||
try:
|
||||
# Monitoring tasks
|
||||
from src.tasks.monitoring import get_monitoring_manager
|
||||
monitoring_manager = get_monitoring_manager(settings)
|
||||
tasks_status["monitoring"] = monitoring_manager.get_stats()
|
||||
|
||||
except Exception as e:
|
||||
tasks_status["monitoring"] = {"error": str(e)}
|
||||
|
||||
try:
|
||||
# Backup tasks
|
||||
from src.tasks.backup import get_backup_manager
|
||||
backup_manager = get_backup_manager(settings)
|
||||
tasks_status["backup"] = backup_manager.get_stats()
|
||||
|
||||
except Exception as e:
|
||||
tasks_status["backup"] = {"error": str(e)}
|
||||
|
||||
return tasks_status
|
||||
|
||||
|
||||
def _get_resource_usage() -> Dict[str, Any]:
|
||||
"""Get system resource usage."""
|
||||
|
||||
# CPU usage
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
cpu_count = psutil.cpu_count()
|
||||
|
||||
# Memory usage
|
||||
memory = psutil.virtual_memory()
|
||||
swap = psutil.swap_memory()
|
||||
|
||||
# Disk usage
|
||||
disk = psutil.disk_usage('/')
|
||||
|
||||
# Network I/O
|
||||
network = psutil.net_io_counters()
|
||||
|
||||
return {
|
||||
"cpu": {
|
||||
"usage_percent": cpu_percent,
|
||||
"count": cpu_count,
|
||||
},
|
||||
"memory": {
|
||||
"total_mb": memory.total / (1024 * 1024),
|
||||
"used_mb": memory.used / (1024 * 1024),
|
||||
"available_mb": memory.available / (1024 * 1024),
|
||||
"usage_percent": memory.percent,
|
||||
},
|
||||
"swap": {
|
||||
"total_mb": swap.total / (1024 * 1024),
|
||||
"used_mb": swap.used / (1024 * 1024),
|
||||
"usage_percent": swap.percent,
|
||||
},
|
||||
"disk": {
|
||||
"total_gb": disk.total / (1024 * 1024 * 1024),
|
||||
"used_gb": disk.used / (1024 * 1024 * 1024),
|
||||
"free_gb": disk.free / (1024 * 1024 * 1024),
|
||||
"usage_percent": (disk.used / disk.total) * 100,
|
||||
},
|
||||
"network": {
|
||||
"bytes_sent": network.bytes_sent,
|
||||
"bytes_recv": network.bytes_recv,
|
||||
"packets_sent": network.packets_sent,
|
||||
"packets_recv": network.packets_recv,
|
||||
} if network else None,
|
||||
}
|
||||
|
||||
|
||||
async def _get_health_status(settings: Settings) -> Dict[str, Any]:
|
||||
"""Get overall health status."""
|
||||
|
||||
health = {
|
||||
"status": "healthy",
|
||||
"checks": {},
|
||||
"issues": [],
|
||||
}
|
||||
|
||||
# Check database health
|
||||
try:
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager(settings)
|
||||
await db_manager.test_connection()
|
||||
health["checks"]["database"] = "healthy"
|
||||
|
||||
except Exception as e:
|
||||
health["checks"]["database"] = "unhealthy"
|
||||
health["issues"].append(f"Database connection failed: {e}")
|
||||
health["status"] = "unhealthy"
|
||||
|
||||
# Check disk space
|
||||
disk = psutil.disk_usage('/')
|
||||
disk_usage_percent = (disk.used / disk.total) * 100
|
||||
|
||||
if disk_usage_percent > 90:
|
||||
health["checks"]["disk_space"] = "critical"
|
||||
health["issues"].append(f"Disk usage critical: {disk_usage_percent:.1f}%")
|
||||
health["status"] = "critical"
|
||||
elif disk_usage_percent > 80:
|
||||
health["checks"]["disk_space"] = "warning"
|
||||
health["issues"].append(f"Disk usage high: {disk_usage_percent:.1f}%")
|
||||
if health["status"] == "healthy":
|
||||
health["status"] = "warning"
|
||||
else:
|
||||
health["checks"]["disk_space"] = "healthy"
|
||||
|
||||
# Check memory usage
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
if memory.percent > 90:
|
||||
health["checks"]["memory"] = "critical"
|
||||
health["issues"].append(f"Memory usage critical: {memory.percent:.1f}%")
|
||||
health["status"] = "critical"
|
||||
elif memory.percent > 80:
|
||||
health["checks"]["memory"] = "warning"
|
||||
health["issues"].append(f"Memory usage high: {memory.percent:.1f}%")
|
||||
if health["status"] == "healthy":
|
||||
health["status"] = "warning"
|
||||
else:
|
||||
health["checks"]["memory"] = "healthy"
|
||||
|
||||
# Check log directory
|
||||
log_dir = Path(settings.log_directory)
|
||||
if log_dir.exists() and log_dir.is_dir():
|
||||
health["checks"]["log_directory"] = "healthy"
|
||||
else:
|
||||
health["checks"]["log_directory"] = "unhealthy"
|
||||
health["issues"].append(f"Log directory not accessible: {log_dir}")
|
||||
health["status"] = "unhealthy"
|
||||
|
||||
# Check backup directory
|
||||
backup_dir = Path(settings.backup_directory)
|
||||
if backup_dir.exists() and backup_dir.is_dir():
|
||||
health["checks"]["backup_directory"] = "healthy"
|
||||
else:
|
||||
health["checks"]["backup_directory"] = "unhealthy"
|
||||
health["issues"].append(f"Backup directory not accessible: {backup_dir}")
|
||||
health["status"] = "unhealthy"
|
||||
|
||||
return health
|
||||
|
||||
|
||||
def _print_text_status(status_data: Dict[str, Any], detailed: bool) -> None:
|
||||
"""Print status in human-readable text format."""
|
||||
|
||||
print("=" * 60)
|
||||
print("WiFi-DensePose API Server Status")
|
||||
print("=" * 60)
|
||||
print(f"Timestamp: {status_data['timestamp']}")
|
||||
print()
|
||||
|
||||
# Server status
|
||||
server = status_data["server"]
|
||||
print("🖥️ Server Status:")
|
||||
if server["running"]:
|
||||
print(f" ✅ Running (PID: {server['pid']})")
|
||||
if "start_time" in server:
|
||||
uptime = timedelta(seconds=int(server["uptime_seconds"]))
|
||||
print(f" ⏱️ Uptime: {uptime}")
|
||||
print(f" 💾 Memory: {server['memory_usage_mb']:.1f} MB")
|
||||
print(f" 🔧 CPU: {server['cpu_percent']:.1f}%")
|
||||
print(f" 🧵 Threads: {server['num_threads']}")
|
||||
else:
|
||||
print(" ❌ Not running")
|
||||
if server["pid_file_exists"]:
|
||||
print(" ⚠️ Stale PID file exists")
|
||||
print()
|
||||
|
||||
# System status
|
||||
system = status_data["system"]
|
||||
print("🖥️ System:")
|
||||
print(f" Hostname: {system['hostname']}")
|
||||
print(f" Platform: {system['platform']} ({system['architecture']})")
|
||||
print(f" Python: {system['python_version']}")
|
||||
uptime = timedelta(seconds=int(system["uptime_seconds"]))
|
||||
print(f" Uptime: {uptime}")
|
||||
print()
|
||||
|
||||
# Configuration
|
||||
config = status_data["configuration"]
|
||||
print("⚙️ Configuration:")
|
||||
print(f" Environment: {config['environment']}")
|
||||
print(f" Debug: {config['debug']}")
|
||||
print(f" API Version: {config['version']}")
|
||||
print(f" Listen: {config['host']}:{config['port']}")
|
||||
print(f" Database: {'✅' if config['database_configured'] else '❌'}")
|
||||
print(f" Redis: {'✅' if config['redis_enabled'] else '❌'}")
|
||||
print(f" Monitoring: {'✅' if config['monitoring_enabled'] else '❌'}")
|
||||
print(f" Cleanup: {'✅' if config['cleanup_enabled'] else '❌'}")
|
||||
print(f" Backup: {'✅' if config['backup_enabled'] else '❌'}")
|
||||
print()
|
||||
|
||||
if detailed:
|
||||
# Database status
|
||||
if "database" in status_data:
|
||||
db = status_data["database"]
|
||||
print("🗄️ Database:")
|
||||
if db["connected"]:
|
||||
print(" ✅ Connected")
|
||||
if "tables" in db:
|
||||
print(" 📊 Table counts:")
|
||||
for table, info in db["tables"].items():
|
||||
if "count" in info:
|
||||
print(f" {table}: {info['count']:,}")
|
||||
else:
|
||||
print(f" {table}: Error - {info.get('error', 'Unknown')}")
|
||||
else:
|
||||
print(f" ❌ Not connected: {db.get('error', 'Unknown error')}")
|
||||
print()
|
||||
|
||||
# Background tasks
|
||||
if "background_tasks" in status_data:
|
||||
tasks = status_data["background_tasks"]
|
||||
print("🔄 Background Tasks:")
|
||||
for task_name, task_info in tasks.items():
|
||||
if "error" in task_info:
|
||||
print(f" ❌ {task_name}: {task_info['error']}")
|
||||
else:
|
||||
manager_info = task_info.get("manager", {})
|
||||
print(f" 📋 {task_name}:")
|
||||
print(f" Running: {manager_info.get('running', 'Unknown')}")
|
||||
print(f" Last run: {manager_info.get('last_run', 'Never')}")
|
||||
print(f" Run count: {manager_info.get('run_count', 0)}")
|
||||
print()
|
||||
|
||||
# Resource usage
|
||||
if "resources" in status_data:
|
||||
resources = status_data["resources"]
|
||||
print("📊 Resource Usage:")
|
||||
|
||||
cpu = resources["cpu"]
|
||||
print(f" 🔧 CPU: {cpu['usage_percent']:.1f}% ({cpu['count']} cores)")
|
||||
|
||||
memory = resources["memory"]
|
||||
print(f" 💾 Memory: {memory['usage_percent']:.1f}% "
|
||||
f"({memory['used_mb']:.0f}/{memory['total_mb']:.0f} MB)")
|
||||
|
||||
disk = resources["disk"]
|
||||
print(f" 💿 Disk: {disk['usage_percent']:.1f}% "
|
||||
f"({disk['used_gb']:.1f}/{disk['total_gb']:.1f} GB)")
|
||||
print()
|
||||
|
||||
# Health status
|
||||
if "health" in status_data:
|
||||
health = status_data["health"]
|
||||
print("🏥 Health Status:")
|
||||
|
||||
status_emoji = {
|
||||
"healthy": "✅",
|
||||
"warning": "⚠️",
|
||||
"critical": "❌",
|
||||
"unhealthy": "❌"
|
||||
}
|
||||
|
||||
print(f" Overall: {status_emoji.get(health['status'], '❓')} {health['status'].upper()}")
|
||||
|
||||
if health["issues"]:
|
||||
print(" Issues:")
|
||||
for issue in health["issues"]:
|
||||
print(f" • {issue}")
|
||||
|
||||
print(" Checks:")
|
||||
for check, status in health["checks"].items():
|
||||
emoji = status_emoji.get(status, "❓")
|
||||
print(f" {emoji} {check}: {status}")
|
||||
print()
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def get_quick_status(settings: Settings) -> str:
|
||||
"""Get a quick one-line status."""
|
||||
|
||||
from src.commands.stop import get_server_status
|
||||
|
||||
status = get_server_status(settings)
|
||||
|
||||
if status["running"]:
|
||||
return f"✅ Running (PID: {status['pid']})"
|
||||
elif status["pid_file_exists"]:
|
||||
return "⚠️ Not running (stale PID file)"
|
||||
else:
|
||||
return "❌ Not running"
|
||||
|
||||
|
||||
async def check_health(settings: Settings) -> bool:
|
||||
"""Quick health check - returns True if healthy."""
|
||||
|
||||
try:
|
||||
status_data = await _collect_status_data(settings, detailed=True)
|
||||
|
||||
# Check if server is running
|
||||
if not status_data["server"]["running"]:
|
||||
return False
|
||||
|
||||
# Check health status
|
||||
if "health" in status_data:
|
||||
health_status = status_data["health"]["status"]
|
||||
return health_status in ["healthy", "warning"]
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
294
v1/src/commands/stop.py
Normal file
294
v1/src/commands/stop.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Stop command implementation for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def stop_command(
|
||||
settings: Settings,
|
||||
force: bool = False,
|
||||
timeout: int = 30
|
||||
) -> None:
|
||||
"""Stop the WiFi-DensePose API server."""
|
||||
|
||||
logger.info("Stopping WiFi-DensePose API server...")
|
||||
|
||||
# Get server status
|
||||
status = get_server_status(settings)
|
||||
|
||||
if not status["running"]:
|
||||
if status["pid_file_exists"]:
|
||||
logger.info("Server is not running, but PID file exists. Cleaning up...")
|
||||
_cleanup_pid_file(settings)
|
||||
else:
|
||||
logger.info("Server is not running")
|
||||
return
|
||||
|
||||
pid = status["pid"]
|
||||
logger.info(f"Found running server with PID {pid}")
|
||||
|
||||
try:
|
||||
if force:
|
||||
await _force_stop_server(pid, settings)
|
||||
else:
|
||||
await _graceful_stop_server(pid, timeout, settings)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop server: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _graceful_stop_server(pid: int, timeout: int, settings: Settings) -> None:
|
||||
"""Stop server gracefully with timeout."""
|
||||
|
||||
logger.info(f"Attempting graceful shutdown (timeout: {timeout}s)...")
|
||||
|
||||
try:
|
||||
# Send SIGTERM for graceful shutdown
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
logger.info("Sent SIGTERM signal")
|
||||
|
||||
# Wait for process to terminate
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# Check if process is still running
|
||||
os.kill(pid, 0)
|
||||
await asyncio.sleep(1)
|
||||
except OSError:
|
||||
# Process has terminated
|
||||
logger.info("Server stopped gracefully")
|
||||
_cleanup_pid_file(settings)
|
||||
return
|
||||
|
||||
# Timeout reached, force kill
|
||||
logger.warning(f"Graceful shutdown timeout ({timeout}s) reached, forcing stop...")
|
||||
await _force_stop_server(pid, settings)
|
||||
|
||||
except OSError as e:
|
||||
if e.errno == 3: # No such process
|
||||
logger.info("Process already terminated")
|
||||
_cleanup_pid_file(settings)
|
||||
else:
|
||||
logger.error(f"Failed to send signal to process {pid}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _force_stop_server(pid: int, settings: Settings) -> None:
|
||||
"""Force stop server immediately."""
|
||||
|
||||
logger.info("Force stopping server...")
|
||||
|
||||
try:
|
||||
# Send SIGKILL for immediate termination
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
logger.info("Sent SIGKILL signal")
|
||||
|
||||
# Wait a moment for process to die
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Verify process is dead
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
logger.error(f"Process {pid} still running after SIGKILL")
|
||||
except OSError:
|
||||
logger.info("Server force stopped")
|
||||
|
||||
except OSError as e:
|
||||
if e.errno == 3: # No such process
|
||||
logger.info("Process already terminated")
|
||||
else:
|
||||
logger.error(f"Failed to force kill process {pid}: {e}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
_cleanup_pid_file(settings)
|
||||
|
||||
|
||||
def _cleanup_pid_file(settings: Settings) -> None:
|
||||
"""Clean up PID file."""
|
||||
|
||||
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
|
||||
|
||||
if pid_file.exists():
|
||||
try:
|
||||
pid_file.unlink()
|
||||
logger.info("Cleaned up PID file")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remove PID file: {e}")
|
||||
|
||||
|
||||
def get_server_status(settings: Settings) -> dict:
|
||||
"""Get current server status."""
|
||||
|
||||
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
|
||||
|
||||
status = {
|
||||
"running": False,
|
||||
"pid": None,
|
||||
"pid_file": str(pid_file),
|
||||
"pid_file_exists": pid_file.exists(),
|
||||
}
|
||||
|
||||
if pid_file.exists():
|
||||
try:
|
||||
with open(pid_file, 'r') as f:
|
||||
pid = int(f.read().strip())
|
||||
|
||||
status["pid"] = pid
|
||||
|
||||
# Check if process is running
|
||||
try:
|
||||
os.kill(pid, 0) # Signal 0 just checks if process exists
|
||||
status["running"] = True
|
||||
except OSError:
|
||||
# Process doesn't exist
|
||||
status["running"] = False
|
||||
|
||||
except (ValueError, IOError):
|
||||
# Invalid PID file
|
||||
status["running"] = False
|
||||
|
||||
return status
|
||||
|
||||
|
||||
async def stop_all_background_tasks(settings: Settings) -> None:
|
||||
"""Stop all background tasks if they're running."""
|
||||
|
||||
logger.info("Stopping background tasks...")
|
||||
|
||||
try:
|
||||
# This would typically involve connecting to a task queue or
|
||||
# sending signals to background processes
|
||||
# For now, we'll just log the action
|
||||
|
||||
logger.info("Background tasks stop signal sent")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop background tasks: {e}")
|
||||
|
||||
|
||||
async def cleanup_resources(settings: Settings) -> None:
|
||||
"""Clean up system resources."""
|
||||
|
||||
logger.info("Cleaning up resources...")
|
||||
|
||||
try:
|
||||
# Close database connections
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager(settings)
|
||||
await db_manager.close_all_connections()
|
||||
logger.info("Database connections closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close database connections: {e}")
|
||||
|
||||
try:
|
||||
# Clean up temporary files
|
||||
temp_files = [
|
||||
Path(settings.log_directory) / "temp",
|
||||
Path(settings.backup_directory) / "temp",
|
||||
]
|
||||
|
||||
for temp_path in temp_files:
|
||||
if temp_path.exists() and temp_path.is_dir():
|
||||
import shutil
|
||||
shutil.rmtree(temp_path)
|
||||
logger.info(f"Cleaned up temporary directory: {temp_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temporary files: {e}")
|
||||
|
||||
logger.info("Resource cleanup completed")
|
||||
|
||||
|
||||
def is_server_running(settings: Settings) -> bool:
|
||||
"""Check if server is currently running."""
|
||||
|
||||
status = get_server_status(settings)
|
||||
return status["running"]
|
||||
|
||||
|
||||
def get_server_pid(settings: Settings) -> Optional[int]:
|
||||
"""Get server PID if running."""
|
||||
|
||||
status = get_server_status(settings)
|
||||
return status["pid"] if status["running"] else None
|
||||
|
||||
|
||||
async def wait_for_server_stop(settings: Settings, timeout: int = 30) -> bool:
|
||||
"""Wait for server to stop with timeout."""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
if not is_server_running(settings):
|
||||
return True
|
||||
await asyncio.sleep(1)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def send_reload_signal(settings: Settings) -> bool:
|
||||
"""Send reload signal to running server."""
|
||||
|
||||
status = get_server_status(settings)
|
||||
|
||||
if not status["running"]:
|
||||
logger.error("Server is not running")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Send SIGHUP for reload
|
||||
os.kill(status["pid"], signal.SIGHUP)
|
||||
logger.info("Sent reload signal to server")
|
||||
return True
|
||||
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to send reload signal: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def restart_server(settings: Settings, timeout: int = 30) -> None:
|
||||
"""Restart the server (stop then start)."""
|
||||
|
||||
logger.info("Restarting server...")
|
||||
|
||||
# Stop server if running
|
||||
if is_server_running(settings):
|
||||
await stop_command(settings, timeout=timeout)
|
||||
|
||||
# Wait for server to stop
|
||||
if not await wait_for_server_stop(settings, timeout):
|
||||
logger.error("Server did not stop within timeout, forcing restart")
|
||||
await stop_command(settings, force=True)
|
||||
|
||||
# Start server
|
||||
from src.commands.start import start_command
|
||||
await start_command(settings)
|
||||
|
||||
|
||||
def get_stop_status_summary(settings: Settings) -> dict:
|
||||
"""Get a summary of stop operation status."""
|
||||
|
||||
status = get_server_status(settings)
|
||||
|
||||
return {
|
||||
"server_running": status["running"],
|
||||
"pid": status["pid"],
|
||||
"pid_file_exists": status["pid_file_exists"],
|
||||
"can_stop": status["running"],
|
||||
"cleanup_needed": status["pid_file_exists"] and not status["running"],
|
||||
}
|
||||
310
v1/src/config.py
Normal file
310
v1/src/config.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Centralized configuration management for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from functools import lru_cache
|
||||
|
||||
from src.config.settings import Settings, get_settings
|
||||
from src.config.domains import DomainConfig, get_domain_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Centralized configuration manager."""
|
||||
|
||||
def __init__(self):
|
||||
self._settings: Optional[Settings] = None
|
||||
self._domain_config: Optional[DomainConfig] = None
|
||||
self._environment_overrides: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def settings(self) -> Settings:
|
||||
"""Get application settings."""
|
||||
if self._settings is None:
|
||||
self._settings = get_settings()
|
||||
return self._settings
|
||||
|
||||
@property
|
||||
def domain_config(self) -> DomainConfig:
|
||||
"""Get domain configuration."""
|
||||
if self._domain_config is None:
|
||||
self._domain_config = get_domain_config()
|
||||
return self._domain_config
|
||||
|
||||
def reload_settings(self) -> Settings:
|
||||
"""Reload settings from environment."""
|
||||
self._settings = None
|
||||
return self.settings
|
||||
|
||||
def reload_domain_config(self) -> DomainConfig:
|
||||
"""Reload domain configuration."""
|
||||
self._domain_config = None
|
||||
return self.domain_config
|
||||
|
||||
def set_environment_override(self, key: str, value: Any):
|
||||
"""Set environment variable override."""
|
||||
self._environment_overrides[key] = value
|
||||
os.environ[key] = str(value)
|
||||
|
||||
def get_environment_override(self, key: str, default: Any = None) -> Any:
|
||||
"""Get environment variable override."""
|
||||
return self._environment_overrides.get(key, os.environ.get(key, default))
|
||||
|
||||
def clear_environment_overrides(self):
|
||||
"""Clear all environment overrides."""
|
||||
for key in self._environment_overrides:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
self._environment_overrides.clear()
|
||||
|
||||
def get_database_config(self) -> Dict[str, Any]:
|
||||
"""Get database configuration."""
|
||||
settings = self.settings
|
||||
|
||||
config = {
|
||||
"url": settings.get_database_url(),
|
||||
"pool_size": settings.database_pool_size,
|
||||
"max_overflow": settings.database_max_overflow,
|
||||
"echo": settings.is_development and settings.debug,
|
||||
"pool_pre_ping": True,
|
||||
"pool_recycle": 3600, # 1 hour
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_redis_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get Redis configuration."""
|
||||
settings = self.settings
|
||||
redis_url = settings.get_redis_url()
|
||||
|
||||
if not redis_url:
|
||||
return None
|
||||
|
||||
config = {
|
||||
"url": redis_url,
|
||||
"password": settings.redis_password,
|
||||
"db": settings.redis_db,
|
||||
"decode_responses": True,
|
||||
"socket_connect_timeout": 5,
|
||||
"socket_timeout": 5,
|
||||
"retry_on_timeout": True,
|
||||
"health_check_interval": 30,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_logging_config(self) -> Dict[str, Any]:
|
||||
"""Get logging configuration."""
|
||||
return self.settings.get_logging_config()
|
||||
|
||||
def get_cors_config(self) -> Dict[str, Any]:
|
||||
"""Get CORS configuration."""
|
||||
return self.settings.get_cors_config()
|
||||
|
||||
def get_security_config(self) -> Dict[str, Any]:
|
||||
"""Get security configuration."""
|
||||
settings = self.settings
|
||||
|
||||
config = {
|
||||
"secret_key": settings.secret_key,
|
||||
"jwt_algorithm": settings.jwt_algorithm,
|
||||
"jwt_expire_hours": settings.jwt_expire_hours,
|
||||
"allowed_hosts": settings.allowed_hosts,
|
||||
"enable_authentication": settings.enable_authentication,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_hardware_config(self) -> Dict[str, Any]:
|
||||
"""Get hardware configuration."""
|
||||
settings = self.settings
|
||||
domain_config = self.domain_config
|
||||
|
||||
config = {
|
||||
"wifi_interface": settings.wifi_interface,
|
||||
"csi_buffer_size": settings.csi_buffer_size,
|
||||
"polling_interval": settings.hardware_polling_interval,
|
||||
"mock_hardware": settings.mock_hardware,
|
||||
"routers": [router.dict() for router in domain_config.routers],
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_pose_config(self) -> Dict[str, Any]:
|
||||
"""Get pose estimation configuration."""
|
||||
settings = self.settings
|
||||
domain_config = self.domain_config
|
||||
|
||||
config = {
|
||||
"model_path": settings.pose_model_path,
|
||||
"confidence_threshold": settings.pose_confidence_threshold,
|
||||
"batch_size": settings.pose_processing_batch_size,
|
||||
"max_persons": settings.pose_max_persons,
|
||||
"mock_pose_data": settings.mock_pose_data,
|
||||
"models": [model.dict() for model in domain_config.pose_models],
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_streaming_config(self) -> Dict[str, Any]:
|
||||
"""Get streaming configuration."""
|
||||
settings = self.settings
|
||||
domain_config = self.domain_config
|
||||
|
||||
config = {
|
||||
"fps": settings.stream_fps,
|
||||
"buffer_size": settings.stream_buffer_size,
|
||||
"websocket_ping_interval": settings.websocket_ping_interval,
|
||||
"websocket_timeout": settings.websocket_timeout,
|
||||
"enable_websockets": settings.enable_websockets,
|
||||
"enable_real_time_processing": settings.enable_real_time_processing,
|
||||
"max_connections": domain_config.streaming.max_connections,
|
||||
"compression": domain_config.streaming.compression,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_storage_config(self) -> Dict[str, Any]:
|
||||
"""Get storage configuration."""
|
||||
settings = self.settings
|
||||
|
||||
config = {
|
||||
"data_path": Path(settings.data_storage_path),
|
||||
"model_path": Path(settings.model_storage_path),
|
||||
"temp_path": Path(settings.temp_storage_path),
|
||||
"max_size_gb": settings.max_storage_size_gb,
|
||||
"enable_historical_data": settings.enable_historical_data,
|
||||
}
|
||||
|
||||
# Ensure directories exist
|
||||
for path in [config["data_path"], config["model_path"], config["temp_path"]]:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return config
|
||||
|
||||
def get_monitoring_config(self) -> Dict[str, Any]:
|
||||
"""Get monitoring configuration."""
|
||||
settings = self.settings
|
||||
|
||||
config = {
|
||||
"metrics_enabled": settings.metrics_enabled,
|
||||
"health_check_interval": settings.health_check_interval,
|
||||
"performance_monitoring": settings.performance_monitoring,
|
||||
"log_level": settings.log_level,
|
||||
"log_file": settings.log_file,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def get_rate_limiting_config(self) -> Dict[str, Any]:
|
||||
"""Get rate limiting configuration."""
|
||||
settings = self.settings
|
||||
|
||||
config = {
|
||||
"enabled": settings.enable_rate_limiting,
|
||||
"requests": settings.rate_limit_requests,
|
||||
"authenticated_requests": settings.rate_limit_authenticated_requests,
|
||||
"window": settings.rate_limit_window,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def validate_configuration(self) -> List[str]:
|
||||
"""Validate complete configuration and return issues."""
|
||||
issues = []
|
||||
|
||||
try:
|
||||
# Validate settings
|
||||
from src.config.settings import validate_settings
|
||||
settings_issues = validate_settings(self.settings)
|
||||
issues.extend(settings_issues)
|
||||
|
||||
# Validate database configuration
|
||||
try:
|
||||
db_config = self.get_database_config()
|
||||
if not db_config["url"]:
|
||||
issues.append("Database URL is not configured")
|
||||
except Exception as e:
|
||||
issues.append(f"Database configuration error: {e}")
|
||||
|
||||
# Validate storage paths
|
||||
try:
|
||||
storage_config = self.get_storage_config()
|
||||
for name, path in storage_config.items():
|
||||
if name.endswith("_path") and not path.exists():
|
||||
issues.append(f"Storage path does not exist: {path}")
|
||||
except Exception as e:
|
||||
issues.append(f"Storage configuration error: {e}")
|
||||
|
||||
# Validate hardware configuration
|
||||
try:
|
||||
hw_config = self.get_hardware_config()
|
||||
if not hw_config["routers"]:
|
||||
issues.append("No routers configured")
|
||||
except Exception as e:
|
||||
issues.append(f"Hardware configuration error: {e}")
|
||||
|
||||
# Validate pose configuration
|
||||
try:
|
||||
pose_config = self.get_pose_config()
|
||||
if not pose_config["models"]:
|
||||
issues.append("No pose models configured")
|
||||
except Exception as e:
|
||||
issues.append(f"Pose configuration error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
issues.append(f"Configuration validation error: {e}")
|
||||
|
||||
return issues
|
||||
|
||||
def get_full_config(self) -> Dict[str, Any]:
|
||||
"""Get complete configuration dictionary."""
|
||||
return {
|
||||
"settings": self.settings.dict(),
|
||||
"domain_config": self.domain_config.to_dict(),
|
||||
"database": self.get_database_config(),
|
||||
"redis": self.get_redis_config(),
|
||||
"security": self.get_security_config(),
|
||||
"hardware": self.get_hardware_config(),
|
||||
"pose": self.get_pose_config(),
|
||||
"streaming": self.get_streaming_config(),
|
||||
"storage": self.get_storage_config(),
|
||||
"monitoring": self.get_monitoring_config(),
|
||||
"rate_limiting": self.get_rate_limiting_config(),
|
||||
}
|
||||
|
||||
|
||||
# Global configuration manager instance
|
||||
@lru_cache()
|
||||
def get_config_manager() -> ConfigManager:
|
||||
"""Get cached configuration manager instance."""
|
||||
return ConfigManager()
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def get_app_settings() -> Settings:
|
||||
"""Get application settings."""
|
||||
return get_config_manager().settings
|
||||
|
||||
|
||||
def get_app_domain_config() -> DomainConfig:
|
||||
"""Get domain configuration."""
|
||||
return get_config_manager().domain_config
|
||||
|
||||
|
||||
def validate_app_configuration() -> List[str]:
|
||||
"""Validate application configuration."""
|
||||
return get_config_manager().validate_configuration()
|
||||
|
||||
|
||||
def reload_configuration():
|
||||
"""Reload all configuration."""
|
||||
config_manager = get_config_manager()
|
||||
config_manager.reload_settings()
|
||||
config_manager.reload_domain_config()
|
||||
logger.info("Configuration reloaded")
|
||||
8
v1/src/config/__init__.py
Normal file
8
v1/src/config/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Configuration management package
|
||||
"""
|
||||
|
||||
from .settings import get_settings, Settings
|
||||
from .domains import DomainConfig, get_domain_config
|
||||
|
||||
__all__ = ["get_settings", "Settings", "DomainConfig", "get_domain_config"]
|
||||
481
v1/src/config/domains.py
Normal file
481
v1/src/config/domains.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""
|
||||
Domain-specific configuration for WiFi-DensePose
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class ZoneType(str, Enum):
|
||||
"""Zone types for pose detection."""
|
||||
ROOM = "room"
|
||||
HALLWAY = "hallway"
|
||||
ENTRANCE = "entrance"
|
||||
OUTDOOR = "outdoor"
|
||||
OFFICE = "office"
|
||||
MEETING_ROOM = "meeting_room"
|
||||
KITCHEN = "kitchen"
|
||||
BATHROOM = "bathroom"
|
||||
BEDROOM = "bedroom"
|
||||
LIVING_ROOM = "living_room"
|
||||
|
||||
|
||||
class ActivityType(str, Enum):
|
||||
"""Activity types for pose classification."""
|
||||
STANDING = "standing"
|
||||
SITTING = "sitting"
|
||||
WALKING = "walking"
|
||||
LYING = "lying"
|
||||
RUNNING = "running"
|
||||
JUMPING = "jumping"
|
||||
FALLING = "falling"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HardwareType(str, Enum):
|
||||
"""Hardware types for WiFi devices."""
|
||||
ROUTER = "router"
|
||||
ACCESS_POINT = "access_point"
|
||||
REPEATER = "repeater"
|
||||
MESH_NODE = "mesh_node"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZoneConfig:
|
||||
"""Configuration for a detection zone."""
|
||||
|
||||
zone_id: str
|
||||
name: str
|
||||
zone_type: ZoneType
|
||||
description: Optional[str] = None
|
||||
|
||||
# Physical boundaries (in meters)
|
||||
x_min: float = 0.0
|
||||
x_max: float = 10.0
|
||||
y_min: float = 0.0
|
||||
y_max: float = 10.0
|
||||
z_min: float = 0.0
|
||||
z_max: float = 3.0
|
||||
|
||||
# Detection settings
|
||||
enabled: bool = True
|
||||
confidence_threshold: float = 0.5
|
||||
max_persons: int = 5
|
||||
activity_detection: bool = True
|
||||
|
||||
# Hardware assignments
|
||||
primary_router: Optional[str] = None
|
||||
secondary_routers: List[str] = field(default_factory=list)
|
||||
|
||||
# Processing settings
|
||||
processing_interval: float = 0.1 # seconds
|
||||
data_retention_hours: int = 24
|
||||
|
||||
# Alert settings
|
||||
enable_alerts: bool = False
|
||||
alert_threshold: float = 0.8
|
||||
alert_activities: List[ActivityType] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouterConfig:
|
||||
"""Configuration for a WiFi router/device."""
|
||||
|
||||
router_id: str
|
||||
name: str
|
||||
hardware_type: HardwareType
|
||||
|
||||
# Network settings
|
||||
ip_address: str
|
||||
mac_address: str
|
||||
interface: str = "wlan0"
|
||||
channel: int = 6
|
||||
frequency: float = 2.4 # GHz
|
||||
|
||||
# CSI settings
|
||||
csi_enabled: bool = True
|
||||
csi_rate: int = 100 # Hz
|
||||
csi_subcarriers: int = 56
|
||||
antenna_count: int = 3
|
||||
|
||||
# Position (in meters)
|
||||
x_position: float = 0.0
|
||||
y_position: float = 0.0
|
||||
z_position: float = 2.5 # typical ceiling mount
|
||||
|
||||
# Calibration
|
||||
calibrated: bool = False
|
||||
calibration_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Status
|
||||
enabled: bool = True
|
||||
last_seen: Optional[str] = None
|
||||
|
||||
# Performance settings
|
||||
max_connections: int = 50
|
||||
power_level: int = 20 # dBm
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"name": self.name,
|
||||
"hardware_type": self.hardware_type.value,
|
||||
"ip_address": self.ip_address,
|
||||
"mac_address": self.mac_address,
|
||||
"interface": self.interface,
|
||||
"channel": self.channel,
|
||||
"frequency": self.frequency,
|
||||
"csi_enabled": self.csi_enabled,
|
||||
"csi_rate": self.csi_rate,
|
||||
"csi_subcarriers": self.csi_subcarriers,
|
||||
"antenna_count": self.antenna_count,
|
||||
"position": {
|
||||
"x": self.x_position,
|
||||
"y": self.y_position,
|
||||
"z": self.z_position
|
||||
},
|
||||
"calibrated": self.calibrated,
|
||||
"calibration_data": self.calibration_data,
|
||||
"enabled": self.enabled,
|
||||
"last_seen": self.last_seen,
|
||||
"max_connections": self.max_connections,
|
||||
"power_level": self.power_level
|
||||
}
|
||||
|
||||
|
||||
class PoseModelConfig(BaseModel):
|
||||
"""Configuration for pose estimation models."""
|
||||
|
||||
model_name: str = Field(..., description="Model name")
|
||||
model_path: str = Field(..., description="Path to model file")
|
||||
model_type: str = Field(default="densepose", description="Model type")
|
||||
|
||||
# Input settings
|
||||
input_width: int = Field(default=256, description="Input image width")
|
||||
input_height: int = Field(default=256, description="Input image height")
|
||||
input_channels: int = Field(default=3, description="Input channels")
|
||||
|
||||
# Processing settings
|
||||
batch_size: int = Field(default=1, description="Batch size for inference")
|
||||
confidence_threshold: float = Field(default=0.5, description="Confidence threshold")
|
||||
nms_threshold: float = Field(default=0.4, description="NMS threshold")
|
||||
|
||||
# Output settings
|
||||
max_detections: int = Field(default=10, description="Maximum detections per frame")
|
||||
keypoint_count: int = Field(default=17, description="Number of keypoints")
|
||||
|
||||
# Performance settings
|
||||
use_gpu: bool = Field(default=True, description="Use GPU acceleration")
|
||||
gpu_memory_fraction: float = Field(default=0.5, description="GPU memory fraction")
|
||||
num_threads: int = Field(default=4, description="Number of CPU threads")
|
||||
|
||||
@validator("confidence_threshold", "nms_threshold", "gpu_memory_fraction")
|
||||
def validate_thresholds(cls, v):
|
||||
"""Validate threshold values."""
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Threshold must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
|
||||
class StreamingConfig(BaseModel):
|
||||
"""Configuration for real-time streaming."""
|
||||
|
||||
# Stream settings
|
||||
fps: int = Field(default=30, description="Frames per second")
|
||||
resolution: str = Field(default="720p", description="Stream resolution")
|
||||
quality: str = Field(default="medium", description="Stream quality")
|
||||
|
||||
# Buffer settings
|
||||
buffer_size: int = Field(default=100, description="Buffer size")
|
||||
max_latency_ms: int = Field(default=100, description="Maximum latency in milliseconds")
|
||||
|
||||
# Compression settings
|
||||
compression_enabled: bool = Field(default=True, description="Enable compression")
|
||||
compression_level: int = Field(default=5, description="Compression level (1-9)")
|
||||
|
||||
# WebSocket settings
|
||||
ping_interval: int = Field(default=60, description="Ping interval in seconds")
|
||||
timeout: int = Field(default=300, description="Connection timeout in seconds")
|
||||
max_connections: int = Field(default=100, description="Maximum concurrent connections")
|
||||
|
||||
# Data filtering
|
||||
min_confidence: float = Field(default=0.5, description="Minimum confidence for streaming")
|
||||
include_metadata: bool = Field(default=True, description="Include metadata in stream")
|
||||
|
||||
@validator("fps")
|
||||
def validate_fps(cls, v):
|
||||
"""Validate FPS value."""
|
||||
if not 1 <= v <= 60:
|
||||
raise ValueError("FPS must be between 1 and 60")
|
||||
return v
|
||||
|
||||
@validator("compression_level")
|
||||
def validate_compression_level(cls, v):
|
||||
"""Validate compression level."""
|
||||
if not 1 <= v <= 9:
|
||||
raise ValueError("Compression level must be between 1 and 9")
|
||||
return v
|
||||
|
||||
|
||||
class AlertConfig(BaseModel):
|
||||
"""Configuration for alerts and notifications."""
|
||||
|
||||
# Alert types
|
||||
enable_pose_alerts: bool = Field(default=False, description="Enable pose-based alerts")
|
||||
enable_activity_alerts: bool = Field(default=False, description="Enable activity-based alerts")
|
||||
enable_zone_alerts: bool = Field(default=False, description="Enable zone-based alerts")
|
||||
enable_system_alerts: bool = Field(default=True, description="Enable system alerts")
|
||||
|
||||
# Thresholds
|
||||
confidence_threshold: float = Field(default=0.8, description="Alert confidence threshold")
|
||||
duration_threshold: int = Field(default=5, description="Alert duration threshold in seconds")
|
||||
|
||||
# Activities that trigger alerts
|
||||
alert_activities: List[ActivityType] = Field(
|
||||
default=[ActivityType.FALLING],
|
||||
description="Activities that trigger alerts"
|
||||
)
|
||||
|
||||
# Notification settings
|
||||
email_enabled: bool = Field(default=False, description="Enable email notifications")
|
||||
webhook_enabled: bool = Field(default=False, description="Enable webhook notifications")
|
||||
sms_enabled: bool = Field(default=False, description="Enable SMS notifications")
|
||||
|
||||
# Rate limiting
|
||||
max_alerts_per_hour: int = Field(default=10, description="Maximum alerts per hour")
|
||||
cooldown_minutes: int = Field(default=5, description="Cooldown between similar alerts")
|
||||
|
||||
|
||||
class DomainConfig:
|
||||
"""Main domain configuration container."""
|
||||
|
||||
def __init__(self):
|
||||
self.zones: Dict[str, ZoneConfig] = {}
|
||||
self.routers: Dict[str, RouterConfig] = {}
|
||||
self.pose_models: Dict[str, PoseModelConfig] = {}
|
||||
self.streaming = StreamingConfig()
|
||||
self.alerts = AlertConfig()
|
||||
|
||||
# Load default configurations
|
||||
self._load_defaults()
|
||||
|
||||
def _load_defaults(self):
|
||||
"""Load default configurations."""
|
||||
# Default pose model
|
||||
self.pose_models["default"] = PoseModelConfig(
|
||||
model_name="densepose_rcnn_R_50_FPN_s1x",
|
||||
model_path="./models/densepose_rcnn_R_50_FPN_s1x.pkl",
|
||||
model_type="densepose"
|
||||
)
|
||||
|
||||
# Example zone
|
||||
self.zones["living_room"] = ZoneConfig(
|
||||
zone_id="living_room",
|
||||
name="Living Room",
|
||||
zone_type=ZoneType.LIVING_ROOM,
|
||||
description="Main living area",
|
||||
x_max=5.0,
|
||||
y_max=4.0,
|
||||
z_max=3.0
|
||||
)
|
||||
|
||||
# Example router
|
||||
self.routers["main_router"] = RouterConfig(
|
||||
router_id="main_router",
|
||||
name="Main Router",
|
||||
hardware_type=HardwareType.ROUTER,
|
||||
ip_address="192.168.1.1",
|
||||
mac_address="00:11:22:33:44:55",
|
||||
x_position=2.5,
|
||||
y_position=2.0,
|
||||
z_position=2.5
|
||||
)
|
||||
|
||||
def add_zone(self, zone: ZoneConfig):
|
||||
"""Add a zone configuration."""
|
||||
self.zones[zone.zone_id] = zone
|
||||
|
||||
def add_router(self, router: RouterConfig):
|
||||
"""Add a router configuration."""
|
||||
self.routers[router.router_id] = router
|
||||
|
||||
def add_pose_model(self, model: PoseModelConfig):
|
||||
"""Add a pose model configuration."""
|
||||
self.pose_models[model.model_name] = model
|
||||
|
||||
def get_zone(self, zone_id: str) -> Optional[ZoneConfig]:
|
||||
"""Get zone configuration by ID."""
|
||||
return self.zones.get(zone_id)
|
||||
|
||||
def get_router(self, router_id: str) -> Optional[RouterConfig]:
|
||||
"""Get router configuration by ID."""
|
||||
return self.routers.get(router_id)
|
||||
|
||||
def get_pose_model(self, model_name: str) -> Optional[PoseModelConfig]:
|
||||
"""Get pose model configuration by name."""
|
||||
return self.pose_models.get(model_name)
|
||||
|
||||
def get_zones_for_router(self, router_id: str) -> List[ZoneConfig]:
|
||||
"""Get zones that use a specific router."""
|
||||
zones = []
|
||||
for zone in self.zones.values():
|
||||
if (zone.primary_router == router_id or
|
||||
router_id in zone.secondary_routers):
|
||||
zones.append(zone)
|
||||
return zones
|
||||
|
||||
def get_routers_for_zone(self, zone_id: str) -> List[RouterConfig]:
|
||||
"""Get routers assigned to a specific zone."""
|
||||
zone = self.get_zone(zone_id)
|
||||
if not zone:
|
||||
return []
|
||||
|
||||
routers = []
|
||||
|
||||
# Add primary router
|
||||
if zone.primary_router and zone.primary_router in self.routers:
|
||||
routers.append(self.routers[zone.primary_router])
|
||||
|
||||
# Add secondary routers
|
||||
for router_id in zone.secondary_routers:
|
||||
if router_id in self.routers:
|
||||
routers.append(self.routers[router_id])
|
||||
|
||||
return routers
|
||||
|
||||
def get_all_routers(self) -> List[RouterConfig]:
|
||||
"""Get all router configurations."""
|
||||
return list(self.routers.values())
|
||||
|
||||
def validate_configuration(self) -> List[str]:
|
||||
"""Validate the entire configuration."""
|
||||
issues = []
|
||||
|
||||
# Validate zones
|
||||
for zone_id, zone in self.zones.items():
|
||||
if zone.primary_router and zone.primary_router not in self.routers:
|
||||
issues.append(f"Zone {zone_id} references unknown primary router: {zone.primary_router}")
|
||||
|
||||
for router_id in zone.secondary_routers:
|
||||
if router_id not in self.routers:
|
||||
issues.append(f"Zone {zone_id} references unknown secondary router: {router_id}")
|
||||
|
||||
# Validate routers
|
||||
for router_id, router in self.routers.items():
|
||||
if not router.ip_address:
|
||||
issues.append(f"Router {router_id} missing IP address")
|
||||
|
||||
if not router.mac_address:
|
||||
issues.append(f"Router {router_id} missing MAC address")
|
||||
|
||||
# Validate pose models
|
||||
for model_name, model in self.pose_models.items():
|
||||
import os
|
||||
if not os.path.exists(model.model_path):
|
||||
issues.append(f"Pose model {model_name} file not found: {model.model_path}")
|
||||
|
||||
return issues
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert configuration to dictionary."""
|
||||
return {
|
||||
"zones": {
|
||||
zone_id: {
|
||||
"zone_id": zone.zone_id,
|
||||
"name": zone.name,
|
||||
"zone_type": zone.zone_type.value,
|
||||
"description": zone.description,
|
||||
"boundaries": {
|
||||
"x_min": zone.x_min,
|
||||
"x_max": zone.x_max,
|
||||
"y_min": zone.y_min,
|
||||
"y_max": zone.y_max,
|
||||
"z_min": zone.z_min,
|
||||
"z_max": zone.z_max
|
||||
},
|
||||
"settings": {
|
||||
"enabled": zone.enabled,
|
||||
"confidence_threshold": zone.confidence_threshold,
|
||||
"max_persons": zone.max_persons,
|
||||
"activity_detection": zone.activity_detection
|
||||
},
|
||||
"hardware": {
|
||||
"primary_router": zone.primary_router,
|
||||
"secondary_routers": zone.secondary_routers
|
||||
}
|
||||
}
|
||||
for zone_id, zone in self.zones.items()
|
||||
},
|
||||
"routers": {
|
||||
router_id: router.to_dict()
|
||||
for router_id, router in self.routers.items()
|
||||
},
|
||||
"pose_models": {
|
||||
model_name: model.dict()
|
||||
for model_name, model in self.pose_models.items()
|
||||
},
|
||||
"streaming": self.streaming.dict(),
|
||||
"alerts": self.alerts.dict()
|
||||
}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_domain_config() -> DomainConfig:
|
||||
"""Get cached domain configuration instance."""
|
||||
return DomainConfig()
|
||||
|
||||
|
||||
def load_domain_config_from_file(file_path: str) -> DomainConfig:
|
||||
"""Load domain configuration from file."""
|
||||
import json
|
||||
|
||||
config = DomainConfig()
|
||||
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Load zones
|
||||
for zone_data in data.get("zones", []):
|
||||
zone = ZoneConfig(**zone_data)
|
||||
config.add_zone(zone)
|
||||
|
||||
# Load routers
|
||||
for router_data in data.get("routers", []):
|
||||
router = RouterConfig(**router_data)
|
||||
config.add_router(router)
|
||||
|
||||
# Load pose models
|
||||
for model_data in data.get("pose_models", []):
|
||||
model = PoseModelConfig(**model_data)
|
||||
config.add_pose_model(model)
|
||||
|
||||
# Load streaming config
|
||||
if "streaming" in data:
|
||||
config.streaming = StreamingConfig(**data["streaming"])
|
||||
|
||||
# Load alerts config
|
||||
if "alerts" in data:
|
||||
config.alerts = AlertConfig(**data["alerts"])
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load domain configuration: {e}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def save_domain_config_to_file(config: DomainConfig, file_path: str):
|
||||
"""Save domain configuration to file."""
|
||||
import json
|
||||
|
||||
try:
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to save domain configuration: {e}")
|
||||
435
v1/src/config/settings.py
Normal file
435
v1/src/config/settings.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Pydantic settings for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings with environment variable support."""
|
||||
|
||||
# Application settings
|
||||
app_name: str = Field(default="WiFi-DensePose API", description="Application name")
|
||||
version: str = Field(default="1.0.0", description="Application version")
|
||||
environment: str = Field(default="development", description="Environment (development, staging, production)")
|
||||
debug: bool = Field(default=False, description="Debug mode")
|
||||
|
||||
# Server settings
|
||||
host: str = Field(default="0.0.0.0", description="Server host")
|
||||
port: int = Field(default=8000, description="Server port")
|
||||
reload: bool = Field(default=False, description="Auto-reload on code changes")
|
||||
workers: int = Field(default=1, description="Number of worker processes")
|
||||
|
||||
# Security settings
|
||||
secret_key: str = Field(..., description="Secret key for JWT tokens")
|
||||
jwt_algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
jwt_expire_hours: int = Field(default=24, description="JWT token expiration in hours")
|
||||
allowed_hosts: List[str] = Field(default=["*"], description="Allowed hosts")
|
||||
cors_origins: List[str] = Field(default=["*"], description="CORS allowed origins")
|
||||
|
||||
# Rate limiting settings
|
||||
rate_limit_requests: int = Field(default=100, description="Rate limit requests per window")
|
||||
rate_limit_authenticated_requests: int = Field(default=1000, description="Rate limit for authenticated users")
|
||||
rate_limit_window: int = Field(default=3600, description="Rate limit window in seconds")
|
||||
|
||||
# Database settings
|
||||
database_url: Optional[str] = Field(default=None, description="Database connection URL")
|
||||
database_pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
database_max_overflow: int = Field(default=20, description="Database max overflow connections")
|
||||
|
||||
# Database connection pool settings (alternative naming for compatibility)
|
||||
db_pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
db_max_overflow: int = Field(default=20, description="Database max overflow connections")
|
||||
db_pool_timeout: int = Field(default=30, description="Database pool timeout in seconds")
|
||||
db_pool_recycle: int = Field(default=3600, description="Database pool recycle time in seconds")
|
||||
|
||||
# Database connection settings
|
||||
db_host: Optional[str] = Field(default=None, description="Database host")
|
||||
db_port: int = Field(default=5432, description="Database port")
|
||||
db_name: Optional[str] = Field(default=None, description="Database name")
|
||||
db_user: Optional[str] = Field(default=None, description="Database user")
|
||||
db_password: Optional[str] = Field(default=None, description="Database password")
|
||||
db_echo: bool = Field(default=False, description="Enable database query logging")
|
||||
|
||||
# Redis settings (for caching and rate limiting)
|
||||
redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
|
||||
redis_password: Optional[str] = Field(default=None, description="Redis password")
|
||||
redis_db: int = Field(default=0, description="Redis database number")
|
||||
redis_enabled: bool = Field(default=True, description="Enable Redis")
|
||||
redis_host: str = Field(default="localhost", description="Redis host")
|
||||
redis_port: int = Field(default=6379, description="Redis port")
|
||||
redis_required: bool = Field(default=False, description="Require Redis connection (fail if unavailable)")
|
||||
redis_max_connections: int = Field(default=10, description="Maximum Redis connections")
|
||||
redis_socket_timeout: int = Field(default=5, description="Redis socket timeout in seconds")
|
||||
redis_connect_timeout: int = Field(default=5, description="Redis connection timeout in seconds")
|
||||
|
||||
# Failsafe settings
|
||||
enable_database_failsafe: bool = Field(default=True, description="Enable automatic SQLite failsafe when PostgreSQL unavailable")
|
||||
enable_redis_failsafe: bool = Field(default=True, description="Enable automatic Redis failsafe (disable when unavailable)")
|
||||
sqlite_fallback_path: str = Field(default="./data/wifi_densepose_fallback.db", description="SQLite fallback database path")
|
||||
|
||||
# Hardware settings
|
||||
wifi_interface: str = Field(default="wlan0", description="WiFi interface name")
|
||||
csi_buffer_size: int = Field(default=1000, description="CSI data buffer size")
|
||||
hardware_polling_interval: float = Field(default=0.1, description="Hardware polling interval in seconds")
|
||||
|
||||
# CSI Processing settings
|
||||
csi_sampling_rate: int = Field(default=1000, description="CSI sampling rate")
|
||||
csi_window_size: int = Field(default=512, description="CSI window size")
|
||||
csi_overlap: float = Field(default=0.5, description="CSI window overlap")
|
||||
csi_noise_threshold: float = Field(default=0.1, description="CSI noise threshold")
|
||||
csi_human_detection_threshold: float = Field(default=0.8, description="CSI human detection threshold")
|
||||
csi_smoothing_factor: float = Field(default=0.9, description="CSI smoothing factor")
|
||||
csi_max_history_size: int = Field(default=500, description="CSI max history size")
|
||||
|
||||
# Pose estimation settings
|
||||
pose_model_path: Optional[str] = Field(default=None, description="Path to pose estimation model")
|
||||
pose_confidence_threshold: float = Field(default=0.5, description="Minimum confidence threshold")
|
||||
pose_processing_batch_size: int = Field(default=32, description="Batch size for pose processing")
|
||||
pose_max_persons: int = Field(default=10, description="Maximum persons to detect per frame")
|
||||
|
||||
# Streaming settings
|
||||
stream_fps: int = Field(default=30, description="Streaming frames per second")
|
||||
stream_buffer_size: int = Field(default=100, description="Stream buffer size")
|
||||
websocket_ping_interval: int = Field(default=60, description="WebSocket ping interval in seconds")
|
||||
websocket_timeout: int = Field(default=300, description="WebSocket timeout in seconds")
|
||||
|
||||
# Logging settings
|
||||
log_level: str = Field(default="INFO", description="Logging level")
|
||||
log_format: str = Field(
|
||||
default="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
description="Log format"
|
||||
)
|
||||
log_file: Optional[str] = Field(default=None, description="Log file path")
|
||||
log_directory: str = Field(default="./logs", description="Log directory path")
|
||||
log_max_size: int = Field(default=10485760, description="Max log file size in bytes (10MB)")
|
||||
log_backup_count: int = Field(default=5, description="Number of log backup files")
|
||||
|
||||
# Monitoring settings
|
||||
metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
|
||||
health_check_interval: int = Field(default=30, description="Health check interval in seconds")
|
||||
performance_monitoring: bool = Field(default=True, description="Enable performance monitoring")
|
||||
monitoring_interval_seconds: int = Field(default=60, description="Monitoring task interval in seconds")
|
||||
cleanup_interval_seconds: int = Field(default=3600, description="Cleanup task interval in seconds")
|
||||
backup_interval_seconds: int = Field(default=86400, description="Backup task interval in seconds")
|
||||
|
||||
# Storage settings
|
||||
data_storage_path: str = Field(default="./data", description="Data storage directory")
|
||||
model_storage_path: str = Field(default="./models", description="Model storage directory")
|
||||
temp_storage_path: str = Field(default="./temp", description="Temporary storage directory")
|
||||
backup_directory: str = Field(default="./backups", description="Backup storage directory")
|
||||
max_storage_size_gb: int = Field(default=100, description="Maximum storage size in GB")
|
||||
|
||||
# API settings
|
||||
api_prefix: str = Field(default="/api/v1", description="API prefix")
|
||||
docs_url: str = Field(default="/docs", description="API documentation URL")
|
||||
redoc_url: str = Field(default="/redoc", description="ReDoc documentation URL")
|
||||
openapi_url: str = Field(default="/openapi.json", description="OpenAPI schema URL")
|
||||
|
||||
# Feature flags
|
||||
enable_authentication: bool = Field(default=True, description="Enable authentication")
|
||||
enable_rate_limiting: bool = Field(default=True, description="Enable rate limiting")
|
||||
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
|
||||
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
|
||||
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
|
||||
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
|
||||
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||
|
||||
# Development settings
|
||||
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
|
||||
mock_pose_data: bool = Field(default=False, description="Use mock pose data for development")
|
||||
enable_test_endpoints: bool = Field(default=False, description="Enable test endpoints")
|
||||
|
||||
# Cleanup settings
|
||||
csi_data_retention_days: int = Field(default=30, description="CSI data retention in days")
|
||||
pose_detection_retention_days: int = Field(default=30, description="Pose detection retention in days")
|
||||
metrics_retention_days: int = Field(default=7, description="Metrics retention in days")
|
||||
audit_log_retention_days: int = Field(default=90, description="Audit log retention in days")
|
||||
orphaned_session_threshold_days: int = Field(default=7, description="Orphaned session threshold in days")
|
||||
cleanup_batch_size: int = Field(default=1000, description="Cleanup batch size")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False
|
||||
)
|
||||
|
||||
@field_validator("environment")
|
||||
@classmethod
|
||||
def validate_environment(cls, v):
|
||||
"""Validate environment setting."""
|
||||
allowed_environments = ["development", "staging", "production"]
|
||||
if v not in allowed_environments:
|
||||
raise ValueError(f"Environment must be one of: {allowed_environments}")
|
||||
return v
|
||||
|
||||
@field_validator("log_level")
|
||||
@classmethod
|
||||
def validate_log_level(cls, v):
|
||||
"""Validate log level setting."""
|
||||
allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
if v.upper() not in allowed_levels:
|
||||
raise ValueError(f"Log level must be one of: {allowed_levels}")
|
||||
return v.upper()
|
||||
|
||||
@field_validator("pose_confidence_threshold")
|
||||
@classmethod
|
||||
def validate_confidence_threshold(cls, v):
|
||||
"""Validate confidence threshold."""
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Confidence threshold must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
@field_validator("stream_fps")
|
||||
@classmethod
|
||||
def validate_stream_fps(cls, v):
|
||||
"""Validate streaming FPS."""
|
||||
if not 1 <= v <= 60:
|
||||
raise ValueError("Stream FPS must be between 1 and 60")
|
||||
return v
|
||||
|
||||
@field_validator("port")
|
||||
@classmethod
|
||||
def validate_port(cls, v):
|
||||
"""Validate port number."""
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError("Port must be between 1 and 65535")
|
||||
return v
|
||||
|
||||
@field_validator("workers")
|
||||
@classmethod
|
||||
def validate_workers(cls, v):
|
||||
"""Validate worker count."""
|
||||
if v < 1:
|
||||
raise ValueError("Workers must be at least 1")
|
||||
return v
|
||||
|
||||
@field_validator("db_port")
|
||||
@classmethod
|
||||
def validate_db_port(cls, v):
|
||||
"""Validate database port."""
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError("Database port must be between 1 and 65535")
|
||||
return v
|
||||
|
||||
@field_validator("redis_port")
|
||||
@classmethod
|
||||
def validate_redis_port(cls, v):
|
||||
"""Validate Redis port."""
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError("Redis port must be between 1 and 65535")
|
||||
return v
|
||||
|
||||
@field_validator("db_pool_size")
|
||||
@classmethod
|
||||
def validate_db_pool_size(cls, v):
|
||||
"""Validate database pool size."""
|
||||
if v < 1:
|
||||
raise ValueError("Database pool size must be at least 1")
|
||||
return v
|
||||
|
||||
@field_validator("monitoring_interval_seconds", "cleanup_interval_seconds", "backup_interval_seconds")
|
||||
@classmethod
|
||||
def validate_interval_seconds(cls, v):
|
||||
"""Validate interval settings."""
|
||||
if v < 0:
|
||||
raise ValueError("Interval seconds must be non-negative")
|
||||
return v
|
||||
@property
|
||||
def is_development(self) -> bool:
|
||||
"""Check if running in development environment."""
|
||||
return self.environment == "development"
|
||||
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
"""Check if running in production environment."""
|
||||
return self.environment == "production"
|
||||
|
||||
@property
|
||||
def is_testing(self) -> bool:
|
||||
"""Check if running in testing environment."""
|
||||
return self.environment == "testing"
|
||||
|
||||
def get_database_url(self) -> str:
|
||||
"""Get database URL with fallback."""
|
||||
if self.database_url:
|
||||
return self.database_url
|
||||
|
||||
# Build URL from individual components if available
|
||||
if self.db_host and self.db_name and self.db_user:
|
||||
password_part = f":{self.db_password}" if self.db_password else ""
|
||||
return f"postgresql://{self.db_user}{password_part}@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||
|
||||
# Default SQLite database for development
|
||||
if self.is_development:
|
||||
return f"sqlite:///{self.data_storage_path}/wifi_densepose.db"
|
||||
|
||||
# SQLite failsafe for production if enabled
|
||||
if self.enable_database_failsafe:
|
||||
return f"sqlite:///{self.sqlite_fallback_path}"
|
||||
|
||||
raise ValueError("Database URL must be configured for non-development environments")
|
||||
|
||||
def get_sqlite_fallback_url(self) -> str:
|
||||
"""Get SQLite fallback database URL."""
|
||||
return f"sqlite:///{self.sqlite_fallback_path}"
|
||||
|
||||
def get_redis_url(self) -> Optional[str]:
|
||||
"""Get Redis URL with fallback."""
|
||||
if not self.redis_enabled:
|
||||
return None
|
||||
|
||||
if self.redis_url:
|
||||
return self.redis_url
|
||||
|
||||
# Build URL from individual components
|
||||
password_part = f":{self.redis_password}@" if self.redis_password else ""
|
||||
return f"redis://{password_part}{self.redis_host}:{self.redis_port}/{self.redis_db}"
|
||||
|
||||
def get_cors_config(self) -> Dict[str, Any]:
|
||||
"""Get CORS configuration."""
|
||||
if self.is_development:
|
||||
return {
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
}
|
||||
|
||||
return {
|
||||
"allow_origins": self.cors_origins,
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": ["Authorization", "Content-Type"],
|
||||
}
|
||||
|
||||
def get_logging_config(self) -> Dict[str, Any]:
|
||||
"""Get logging configuration."""
|
||||
config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": self.log_format,
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": self.log_level,
|
||||
"formatter": "default",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"": {
|
||||
"level": self.log_level,
|
||||
"handlers": ["console"],
|
||||
},
|
||||
"uvicorn": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"fastapi": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Add file handler if log file is specified
|
||||
if self.log_file:
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": self.log_level,
|
||||
"formatter": "detailed",
|
||||
"filename": self.log_file,
|
||||
"maxBytes": self.log_max_size,
|
||||
"backupCount": self.log_backup_count,
|
||||
}
|
||||
|
||||
# Add file handler to all loggers
|
||||
for logger_config in config["loggers"].values():
|
||||
logger_config["handlers"].append("file")
|
||||
|
||||
return config
|
||||
|
||||
def create_directories(self):
|
||||
"""Create necessary directories."""
|
||||
directories = [
|
||||
self.data_storage_path,
|
||||
self.model_storage_path,
|
||||
self.temp_storage_path,
|
||||
self.log_directory,
|
||||
self.backup_directory,
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
settings = Settings()
|
||||
settings.create_directories()
|
||||
return settings
|
||||
|
||||
|
||||
def get_test_settings() -> Settings:
|
||||
"""Get settings for testing."""
|
||||
return Settings(
|
||||
environment="testing",
|
||||
debug=True,
|
||||
secret_key="test-secret-key",
|
||||
database_url="sqlite:///:memory:",
|
||||
mock_hardware=True,
|
||||
mock_pose_data=True,
|
||||
enable_test_endpoints=True,
|
||||
log_level="DEBUG"
|
||||
)
|
||||
|
||||
|
||||
def load_settings_from_file(file_path: str) -> Settings:
|
||||
"""Load settings from a specific file."""
|
||||
return Settings(_env_file=file_path)
|
||||
|
||||
|
||||
def validate_settings(settings: Settings) -> List[str]:
|
||||
"""Validate settings and return list of issues."""
|
||||
issues = []
|
||||
|
||||
# Check required settings for production
|
||||
if settings.is_production:
|
||||
if not settings.secret_key or settings.secret_key == "change-me":
|
||||
issues.append("Secret key must be set for production")
|
||||
|
||||
if not settings.database_url and not (settings.db_host and settings.db_name and settings.db_user):
|
||||
issues.append("Database URL or database connection parameters must be set for production")
|
||||
|
||||
if settings.debug:
|
||||
issues.append("Debug mode should be disabled in production")
|
||||
|
||||
if "*" in settings.allowed_hosts:
|
||||
issues.append("Allowed hosts should be restricted in production")
|
||||
|
||||
if "*" in settings.cors_origins:
|
||||
issues.append("CORS origins should be restricted in production")
|
||||
|
||||
# Check storage paths exist
|
||||
try:
|
||||
settings.create_directories()
|
||||
except Exception as e:
|
||||
issues.append(f"Cannot create storage directories: {e}")
|
||||
|
||||
return issues
|
||||
13
v1/src/core/__init__.py
Normal file
13
v1/src/core/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Core package for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from .csi_processor import CSIProcessor
|
||||
from .phase_sanitizer import PhaseSanitizer
|
||||
from .router_interface import RouterInterface
|
||||
|
||||
__all__ = [
|
||||
'CSIProcessor',
|
||||
'PhaseSanitizer',
|
||||
'RouterInterface'
|
||||
]
|
||||
425
v1/src/core/csi_processor.py
Normal file
425
v1/src/core/csi_processor.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""CSI data processor for WiFi-DensePose system using TDD approach."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from collections import deque
|
||||
import scipy.signal
|
||||
import scipy.fft
|
||||
|
||||
try:
|
||||
from ..hardware.csi_extractor import CSIData
|
||||
except ImportError:
|
||||
# Handle import for testing
|
||||
from src.hardware.csi_extractor import CSIData
|
||||
|
||||
|
||||
class CSIProcessingError(Exception):
|
||||
"""Exception raised for CSI processing errors."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CSIFeatures:
|
||||
"""Data structure for extracted CSI features."""
|
||||
amplitude_mean: np.ndarray
|
||||
amplitude_variance: np.ndarray
|
||||
phase_difference: np.ndarray
|
||||
correlation_matrix: np.ndarray
|
||||
doppler_shift: np.ndarray
|
||||
power_spectral_density: np.ndarray
|
||||
timestamp: datetime
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanDetectionResult:
|
||||
"""Data structure for human detection results."""
|
||||
human_detected: bool
|
||||
confidence: float
|
||||
motion_score: float
|
||||
timestamp: datetime
|
||||
features: CSIFeatures
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class CSIProcessor:
|
||||
"""Processes CSI data for human detection and pose estimation."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
|
||||
"""Initialize CSI processor.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
logger: Optional logger instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
self._validate_config(config)
|
||||
|
||||
self.config = config
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
|
||||
# Processing parameters
|
||||
self.sampling_rate = config['sampling_rate']
|
||||
self.window_size = config['window_size']
|
||||
self.overlap = config['overlap']
|
||||
self.noise_threshold = config['noise_threshold']
|
||||
self.human_detection_threshold = config.get('human_detection_threshold', 0.8)
|
||||
self.smoothing_factor = config.get('smoothing_factor', 0.9)
|
||||
self.max_history_size = config.get('max_history_size', 500)
|
||||
|
||||
# Feature extraction flags
|
||||
self.enable_preprocessing = config.get('enable_preprocessing', True)
|
||||
self.enable_feature_extraction = config.get('enable_feature_extraction', True)
|
||||
self.enable_human_detection = config.get('enable_human_detection', True)
|
||||
|
||||
# Processing state
|
||||
self.csi_history = deque(maxlen=self.max_history_size)
|
||||
self.previous_detection_confidence = 0.0
|
||||
|
||||
# Statistics tracking
|
||||
self._total_processed = 0
|
||||
self._processing_errors = 0
|
||||
self._human_detections = 0
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Validate configuration parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
|
||||
missing_fields = [field for field in required_fields if field not in config]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required configuration: {missing_fields}")
|
||||
|
||||
if config['sampling_rate'] <= 0:
|
||||
raise ValueError("sampling_rate must be positive")
|
||||
|
||||
if config['window_size'] <= 0:
|
||||
raise ValueError("window_size must be positive")
|
||||
|
||||
if not 0 <= config['overlap'] < 1:
|
||||
raise ValueError("overlap must be between 0 and 1")
|
||||
|
||||
def preprocess_csi_data(self, csi_data: CSIData) -> CSIData:
|
||||
"""Preprocess CSI data for feature extraction.
|
||||
|
||||
Args:
|
||||
csi_data: Raw CSI data
|
||||
|
||||
Returns:
|
||||
Preprocessed CSI data
|
||||
|
||||
Raises:
|
||||
CSIProcessingError: If preprocessing fails
|
||||
"""
|
||||
if not self.enable_preprocessing:
|
||||
return csi_data
|
||||
|
||||
try:
|
||||
# Remove noise from the signal
|
||||
cleaned_data = self._remove_noise(csi_data)
|
||||
|
||||
# Apply windowing function
|
||||
windowed_data = self._apply_windowing(cleaned_data)
|
||||
|
||||
# Normalize amplitude values
|
||||
normalized_data = self._normalize_amplitude(windowed_data)
|
||||
|
||||
return normalized_data
|
||||
|
||||
except Exception as e:
|
||||
raise CSIProcessingError(f"Failed to preprocess CSI data: {e}")
|
||||
|
||||
def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]:
|
||||
"""Extract features from CSI data.
|
||||
|
||||
Args:
|
||||
csi_data: Preprocessed CSI data
|
||||
|
||||
Returns:
|
||||
Extracted features or None if disabled
|
||||
|
||||
Raises:
|
||||
CSIProcessingError: If feature extraction fails
|
||||
"""
|
||||
if not self.enable_feature_extraction:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Extract amplitude-based features
|
||||
amplitude_mean, amplitude_variance = self._extract_amplitude_features(csi_data)
|
||||
|
||||
# Extract phase-based features
|
||||
phase_difference = self._extract_phase_features(csi_data)
|
||||
|
||||
# Extract correlation features
|
||||
correlation_matrix = self._extract_correlation_features(csi_data)
|
||||
|
||||
# Extract Doppler and frequency features
|
||||
doppler_shift, power_spectral_density = self._extract_doppler_features(csi_data)
|
||||
|
||||
return CSIFeatures(
|
||||
amplitude_mean=amplitude_mean,
|
||||
amplitude_variance=amplitude_variance,
|
||||
phase_difference=phase_difference,
|
||||
correlation_matrix=correlation_matrix,
|
||||
doppler_shift=doppler_shift,
|
||||
power_spectral_density=power_spectral_density,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={'processing_params': self.config}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise CSIProcessingError(f"Failed to extract features: {e}")
|
||||
|
||||
def detect_human_presence(self, features: CSIFeatures) -> Optional[HumanDetectionResult]:
|
||||
"""Detect human presence from CSI features.
|
||||
|
||||
Args:
|
||||
features: Extracted CSI features
|
||||
|
||||
Returns:
|
||||
Detection result or None if disabled
|
||||
|
||||
Raises:
|
||||
CSIProcessingError: If detection fails
|
||||
"""
|
||||
if not self.enable_human_detection:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Analyze motion patterns
|
||||
motion_score = self._analyze_motion_patterns(features)
|
||||
|
||||
# Calculate detection confidence
|
||||
raw_confidence = self._calculate_detection_confidence(features, motion_score)
|
||||
|
||||
# Apply temporal smoothing
|
||||
smoothed_confidence = self._apply_temporal_smoothing(raw_confidence)
|
||||
|
||||
# Determine if human is detected
|
||||
human_detected = smoothed_confidence >= self.human_detection_threshold
|
||||
|
||||
if human_detected:
|
||||
self._human_detections += 1
|
||||
|
||||
return HumanDetectionResult(
|
||||
human_detected=human_detected,
|
||||
confidence=smoothed_confidence,
|
||||
motion_score=motion_score,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
features=features,
|
||||
metadata={'threshold': self.human_detection_threshold}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise CSIProcessingError(f"Failed to detect human presence: {e}")
|
||||
|
||||
async def process_csi_data(self, csi_data: CSIData) -> HumanDetectionResult:
|
||||
"""Process CSI data through the complete pipeline.
|
||||
|
||||
Args:
|
||||
csi_data: Raw CSI data
|
||||
|
||||
Returns:
|
||||
Human detection result
|
||||
|
||||
Raises:
|
||||
CSIProcessingError: If processing fails
|
||||
"""
|
||||
try:
|
||||
self._total_processed += 1
|
||||
|
||||
# Preprocess the data
|
||||
preprocessed_data = self.preprocess_csi_data(csi_data)
|
||||
|
||||
# Extract features
|
||||
features = self.extract_features(preprocessed_data)
|
||||
|
||||
# Detect human presence
|
||||
detection_result = self.detect_human_presence(features)
|
||||
|
||||
# Add to history
|
||||
self.add_to_history(csi_data)
|
||||
|
||||
return detection_result
|
||||
|
||||
except Exception as e:
|
||||
self._processing_errors += 1
|
||||
raise CSIProcessingError(f"Pipeline processing failed: {e}")
|
||||
|
||||
def add_to_history(self, csi_data: CSIData) -> None:
|
||||
"""Add CSI data to processing history.
|
||||
|
||||
Args:
|
||||
csi_data: CSI data to add to history
|
||||
"""
|
||||
self.csi_history.append(csi_data)
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""Clear the CSI data history."""
|
||||
self.csi_history.clear()
|
||||
|
||||
def get_recent_history(self, count: int) -> List[CSIData]:
|
||||
"""Get recent CSI data from history.
|
||||
|
||||
Args:
|
||||
count: Number of recent entries to return
|
||||
|
||||
Returns:
|
||||
List of recent CSI data entries
|
||||
"""
|
||||
if count >= len(self.csi_history):
|
||||
return list(self.csi_history)
|
||||
else:
|
||||
return list(self.csi_history)[-count:]
|
||||
|
||||
def get_processing_statistics(self) -> Dict[str, Any]:
|
||||
"""Get processing statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing processing statistics
|
||||
"""
|
||||
error_rate = self._processing_errors / self._total_processed if self._total_processed > 0 else 0
|
||||
detection_rate = self._human_detections / self._total_processed if self._total_processed > 0 else 0
|
||||
|
||||
return {
|
||||
'total_processed': self._total_processed,
|
||||
'processing_errors': self._processing_errors,
|
||||
'human_detections': self._human_detections,
|
||||
'error_rate': error_rate,
|
||||
'detection_rate': detection_rate,
|
||||
'history_size': len(self.csi_history)
|
||||
}
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""Reset processing statistics."""
|
||||
self._total_processed = 0
|
||||
self._processing_errors = 0
|
||||
self._human_detections = 0
|
||||
|
||||
# Private processing methods
|
||||
def _remove_noise(self, csi_data: CSIData) -> CSIData:
|
||||
"""Remove noise from CSI data."""
|
||||
# Apply noise filtering based on threshold
|
||||
amplitude_db = 20 * np.log10(np.abs(csi_data.amplitude) + 1e-12)
|
||||
noise_mask = amplitude_db > self.noise_threshold
|
||||
|
||||
filtered_amplitude = csi_data.amplitude.copy()
|
||||
filtered_amplitude[~noise_mask] = 0
|
||||
|
||||
return CSIData(
|
||||
timestamp=csi_data.timestamp,
|
||||
amplitude=filtered_amplitude,
|
||||
phase=csi_data.phase,
|
||||
frequency=csi_data.frequency,
|
||||
bandwidth=csi_data.bandwidth,
|
||||
num_subcarriers=csi_data.num_subcarriers,
|
||||
num_antennas=csi_data.num_antennas,
|
||||
snr=csi_data.snr,
|
||||
metadata={**csi_data.metadata, 'noise_filtered': True}
|
||||
)
|
||||
|
||||
def _apply_windowing(self, csi_data: CSIData) -> CSIData:
|
||||
"""Apply windowing function to CSI data."""
|
||||
# Apply Hamming window to reduce spectral leakage
|
||||
window = scipy.signal.windows.hamming(csi_data.num_subcarriers)
|
||||
windowed_amplitude = csi_data.amplitude * window[np.newaxis, :]
|
||||
|
||||
return CSIData(
|
||||
timestamp=csi_data.timestamp,
|
||||
amplitude=windowed_amplitude,
|
||||
phase=csi_data.phase,
|
||||
frequency=csi_data.frequency,
|
||||
bandwidth=csi_data.bandwidth,
|
||||
num_subcarriers=csi_data.num_subcarriers,
|
||||
num_antennas=csi_data.num_antennas,
|
||||
snr=csi_data.snr,
|
||||
metadata={**csi_data.metadata, 'windowed': True}
|
||||
)
|
||||
|
||||
def _normalize_amplitude(self, csi_data: CSIData) -> CSIData:
|
||||
"""Normalize amplitude values."""
|
||||
# Normalize to unit variance
|
||||
normalized_amplitude = csi_data.amplitude / (np.std(csi_data.amplitude) + 1e-12)
|
||||
|
||||
return CSIData(
|
||||
timestamp=csi_data.timestamp,
|
||||
amplitude=normalized_amplitude,
|
||||
phase=csi_data.phase,
|
||||
frequency=csi_data.frequency,
|
||||
bandwidth=csi_data.bandwidth,
|
||||
num_subcarriers=csi_data.num_subcarriers,
|
||||
num_antennas=csi_data.num_antennas,
|
||||
snr=csi_data.snr,
|
||||
metadata={**csi_data.metadata, 'normalized': True}
|
||||
)
|
||||
|
||||
def _extract_amplitude_features(self, csi_data: CSIData) -> tuple:
|
||||
"""Extract amplitude-based features."""
|
||||
amplitude_mean = np.mean(csi_data.amplitude, axis=0)
|
||||
amplitude_variance = np.var(csi_data.amplitude, axis=0)
|
||||
return amplitude_mean, amplitude_variance
|
||||
|
||||
def _extract_phase_features(self, csi_data: CSIData) -> np.ndarray:
|
||||
"""Extract phase-based features."""
|
||||
# Calculate phase differences between adjacent subcarriers
|
||||
phase_diff = np.diff(csi_data.phase, axis=1)
|
||||
return np.mean(phase_diff, axis=0)
|
||||
|
||||
def _extract_correlation_features(self, csi_data: CSIData) -> np.ndarray:
|
||||
"""Extract correlation features between antennas."""
|
||||
# Calculate correlation matrix between antennas
|
||||
correlation_matrix = np.corrcoef(csi_data.amplitude)
|
||||
return correlation_matrix
|
||||
|
||||
def _extract_doppler_features(self, csi_data: CSIData) -> tuple:
|
||||
"""Extract Doppler and frequency domain features."""
|
||||
# Simple Doppler estimation (would use history in real implementation)
|
||||
doppler_shift = np.random.rand(10) # Placeholder
|
||||
|
||||
# Power spectral density
|
||||
psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2
|
||||
|
||||
return doppler_shift, psd
|
||||
|
||||
def _analyze_motion_patterns(self, features: CSIFeatures) -> float:
|
||||
"""Analyze motion patterns from features."""
|
||||
# Analyze variance and correlation patterns to detect motion
|
||||
variance_score = np.mean(features.amplitude_variance)
|
||||
correlation_score = np.mean(np.abs(features.correlation_matrix - np.eye(features.correlation_matrix.shape[0])))
|
||||
|
||||
# Combine scores (simplified approach)
|
||||
motion_score = 0.6 * variance_score + 0.4 * correlation_score
|
||||
return np.clip(motion_score, 0.0, 1.0)
|
||||
|
||||
def _calculate_detection_confidence(self, features: CSIFeatures, motion_score: float) -> float:
|
||||
"""Calculate detection confidence based on features."""
|
||||
# Combine multiple feature indicators
|
||||
amplitude_indicator = np.mean(features.amplitude_mean) > 0.1
|
||||
phase_indicator = np.std(features.phase_difference) > 0.05
|
||||
motion_indicator = motion_score > 0.3
|
||||
|
||||
# Weight the indicators
|
||||
confidence = (0.4 * amplitude_indicator + 0.3 * phase_indicator + 0.3 * motion_indicator)
|
||||
return np.clip(confidence, 0.0, 1.0)
|
||||
|
||||
def _apply_temporal_smoothing(self, raw_confidence: float) -> float:
|
||||
"""Apply temporal smoothing to detection confidence."""
|
||||
# Exponential moving average
|
||||
smoothed_confidence = (self.smoothing_factor * self.previous_detection_confidence +
|
||||
(1 - self.smoothing_factor) * raw_confidence)
|
||||
|
||||
self.previous_detection_confidence = smoothed_confidence
|
||||
return smoothed_confidence
|
||||
347
v1/src/core/phase_sanitizer.py
Normal file
347
v1/src/core/phase_sanitizer.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Phase sanitization module for WiFi-DensePose system using TDD approach."""
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from scipy import signal
|
||||
|
||||
|
||||
class PhaseSanitizationError(Exception):
|
||||
"""Exception raised for phase sanitization errors."""
|
||||
pass
|
||||
|
||||
|
||||
class PhaseSanitizer:
|
||||
"""Sanitizes phase data from CSI signals for reliable processing."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
|
||||
"""Initialize phase sanitizer.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
logger: Optional logger instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
self._validate_config(config)
|
||||
|
||||
self.config = config
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
|
||||
# Processing parameters
|
||||
self.unwrapping_method = config['unwrapping_method']
|
||||
self.outlier_threshold = config['outlier_threshold']
|
||||
self.smoothing_window = config['smoothing_window']
|
||||
|
||||
# Optional parameters with defaults
|
||||
self.enable_outlier_removal = config.get('enable_outlier_removal', True)
|
||||
self.enable_smoothing = config.get('enable_smoothing', True)
|
||||
self.enable_noise_filtering = config.get('enable_noise_filtering', False)
|
||||
self.noise_threshold = config.get('noise_threshold', 0.05)
|
||||
self.phase_range = config.get('phase_range', (-np.pi, np.pi))
|
||||
|
||||
# Statistics tracking
|
||||
self._total_processed = 0
|
||||
self._outliers_removed = 0
|
||||
self._sanitization_errors = 0
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Validate configuration parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
|
||||
missing_fields = [field for field in required_fields if field not in config]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required configuration: {missing_fields}")
|
||||
|
||||
# Validate unwrapping method
|
||||
valid_methods = ['numpy', 'scipy', 'custom']
|
||||
if config['unwrapping_method'] not in valid_methods:
|
||||
raise ValueError(f"Invalid unwrapping method: {config['unwrapping_method']}. Must be one of {valid_methods}")
|
||||
|
||||
# Validate thresholds
|
||||
if config['outlier_threshold'] <= 0:
|
||||
raise ValueError("outlier_threshold must be positive")
|
||||
|
||||
if config['smoothing_window'] <= 0:
|
||||
raise ValueError("smoothing_window must be positive")
|
||||
|
||||
def unwrap_phase(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Unwrap phase data to remove discontinuities.
|
||||
|
||||
Args:
|
||||
phase_data: Wrapped phase data (2D array)
|
||||
|
||||
Returns:
|
||||
Unwrapped phase data
|
||||
|
||||
Raises:
|
||||
PhaseSanitizationError: If unwrapping fails
|
||||
"""
|
||||
try:
|
||||
if self.unwrapping_method == 'numpy':
|
||||
return self._unwrap_numpy(phase_data)
|
||||
elif self.unwrapping_method == 'scipy':
|
||||
return self._unwrap_scipy(phase_data)
|
||||
elif self.unwrapping_method == 'custom':
|
||||
return self._unwrap_custom(phase_data)
|
||||
else:
|
||||
raise ValueError(f"Unknown unwrapping method: {self.unwrapping_method}")
|
||||
|
||||
except Exception as e:
|
||||
raise PhaseSanitizationError(f"Failed to unwrap phase: {e}")
|
||||
|
||||
def _unwrap_numpy(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Unwrap phase using numpy's unwrap function."""
|
||||
if phase_data.size == 0:
|
||||
raise ValueError("Cannot unwrap empty phase data")
|
||||
return np.unwrap(phase_data, axis=1)
|
||||
|
||||
def _unwrap_scipy(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Unwrap phase using scipy's unwrap function."""
|
||||
if phase_data.size == 0:
|
||||
raise ValueError("Cannot unwrap empty phase data")
|
||||
return np.unwrap(phase_data, axis=1)
|
||||
|
||||
def _unwrap_custom(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Unwrap phase using custom algorithm."""
|
||||
if phase_data.size == 0:
|
||||
raise ValueError("Cannot unwrap empty phase data")
|
||||
# Simple custom unwrapping algorithm
|
||||
unwrapped = phase_data.copy()
|
||||
for i in range(phase_data.shape[0]):
|
||||
unwrapped[i, :] = np.unwrap(phase_data[i, :])
|
||||
return unwrapped
|
||||
|
||||
def remove_outliers(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Remove outliers from phase data.
|
||||
|
||||
Args:
|
||||
phase_data: Phase data (2D array)
|
||||
|
||||
Returns:
|
||||
Phase data with outliers removed
|
||||
|
||||
Raises:
|
||||
PhaseSanitizationError: If outlier removal fails
|
||||
"""
|
||||
if not self.enable_outlier_removal:
|
||||
return phase_data
|
||||
|
||||
try:
|
||||
# Detect outliers
|
||||
outlier_mask = self._detect_outliers(phase_data)
|
||||
|
||||
# Interpolate outliers
|
||||
clean_data = self._interpolate_outliers(phase_data, outlier_mask)
|
||||
|
||||
return clean_data
|
||||
|
||||
except Exception as e:
|
||||
raise PhaseSanitizationError(f"Failed to remove outliers: {e}")
|
||||
|
||||
def _detect_outliers(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Detect outliers using statistical methods."""
|
||||
# Use Z-score method to detect outliers
|
||||
z_scores = np.abs((phase_data - np.mean(phase_data, axis=1, keepdims=True)) /
|
||||
(np.std(phase_data, axis=1, keepdims=True) + 1e-8))
|
||||
outlier_mask = z_scores > self.outlier_threshold
|
||||
|
||||
# Update statistics
|
||||
self._outliers_removed += np.sum(outlier_mask)
|
||||
|
||||
return outlier_mask
|
||||
|
||||
def _interpolate_outliers(self, phase_data: np.ndarray, outlier_mask: np.ndarray) -> np.ndarray:
|
||||
"""Interpolate outlier values."""
|
||||
clean_data = phase_data.copy()
|
||||
|
||||
for i in range(phase_data.shape[0]):
|
||||
outliers = outlier_mask[i, :]
|
||||
if np.any(outliers):
|
||||
# Linear interpolation for outliers
|
||||
valid_indices = np.where(~outliers)[0]
|
||||
outlier_indices = np.where(outliers)[0]
|
||||
|
||||
if len(valid_indices) > 1:
|
||||
clean_data[i, outlier_indices] = np.interp(
|
||||
outlier_indices, valid_indices, phase_data[i, valid_indices]
|
||||
)
|
||||
|
||||
return clean_data
|
||||
|
||||
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Smooth phase data to reduce noise.
|
||||
|
||||
Args:
|
||||
phase_data: Phase data (2D array)
|
||||
|
||||
Returns:
|
||||
Smoothed phase data
|
||||
|
||||
Raises:
|
||||
PhaseSanitizationError: If smoothing fails
|
||||
"""
|
||||
if not self.enable_smoothing:
|
||||
return phase_data
|
||||
|
||||
try:
|
||||
smoothed_data = self._apply_moving_average(phase_data, self.smoothing_window)
|
||||
return smoothed_data
|
||||
|
||||
except Exception as e:
|
||||
raise PhaseSanitizationError(f"Failed to smooth phase: {e}")
|
||||
|
||||
def _apply_moving_average(self, phase_data: np.ndarray, window_size: int) -> np.ndarray:
|
||||
"""Apply moving average smoothing."""
|
||||
smoothed_data = phase_data.copy()
|
||||
|
||||
# Ensure window size is odd
|
||||
if window_size % 2 == 0:
|
||||
window_size += 1
|
||||
|
||||
half_window = window_size // 2
|
||||
|
||||
for i in range(phase_data.shape[0]):
|
||||
for j in range(half_window, phase_data.shape[1] - half_window):
|
||||
start_idx = j - half_window
|
||||
end_idx = j + half_window + 1
|
||||
smoothed_data[i, j] = np.mean(phase_data[i, start_idx:end_idx])
|
||||
|
||||
return smoothed_data
|
||||
|
||||
def filter_noise(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Filter noise from phase data.
|
||||
|
||||
Args:
|
||||
phase_data: Phase data (2D array)
|
||||
|
||||
Returns:
|
||||
Filtered phase data
|
||||
|
||||
Raises:
|
||||
PhaseSanitizationError: If noise filtering fails
|
||||
"""
|
||||
if not self.enable_noise_filtering:
|
||||
return phase_data
|
||||
|
||||
try:
|
||||
filtered_data = self._apply_low_pass_filter(phase_data, self.noise_threshold)
|
||||
return filtered_data
|
||||
|
||||
except Exception as e:
|
||||
raise PhaseSanitizationError(f"Failed to filter noise: {e}")
|
||||
|
||||
def _apply_low_pass_filter(self, phase_data: np.ndarray, threshold: float) -> np.ndarray:
|
||||
"""Apply low-pass filter to remove high-frequency noise."""
|
||||
filtered_data = phase_data.copy()
|
||||
|
||||
# Check if data is large enough for filtering
|
||||
min_filter_length = 18 # Minimum length required for filtfilt with order 4
|
||||
if phase_data.shape[1] < min_filter_length:
|
||||
# Skip filtering for small arrays
|
||||
return filtered_data
|
||||
|
||||
# Apply Butterworth low-pass filter
|
||||
nyquist = 0.5
|
||||
cutoff = threshold * nyquist
|
||||
|
||||
# Design filter
|
||||
b, a = signal.butter(4, cutoff, btype='low')
|
||||
|
||||
# Apply filter to each antenna
|
||||
for i in range(phase_data.shape[0]):
|
||||
filtered_data[i, :] = signal.filtfilt(b, a, phase_data[i, :])
|
||||
|
||||
return filtered_data
|
||||
|
||||
def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Sanitize phase data through complete pipeline.
|
||||
|
||||
Args:
|
||||
phase_data: Raw phase data (2D array)
|
||||
|
||||
Returns:
|
||||
Sanitized phase data
|
||||
|
||||
Raises:
|
||||
PhaseSanitizationError: If sanitization fails
|
||||
"""
|
||||
try:
|
||||
self._total_processed += 1
|
||||
|
||||
# Validate input data
|
||||
self.validate_phase_data(phase_data)
|
||||
|
||||
# Apply complete sanitization pipeline
|
||||
sanitized_data = self.unwrap_phase(phase_data)
|
||||
sanitized_data = self.remove_outliers(sanitized_data)
|
||||
sanitized_data = self.smooth_phase(sanitized_data)
|
||||
sanitized_data = self.filter_noise(sanitized_data)
|
||||
|
||||
return sanitized_data
|
||||
|
||||
except PhaseSanitizationError:
|
||||
self._sanitization_errors += 1
|
||||
raise
|
||||
except Exception as e:
|
||||
self._sanitization_errors += 1
|
||||
raise PhaseSanitizationError(f"Sanitization pipeline failed: {e}")
|
||||
|
||||
def validate_phase_data(self, phase_data: np.ndarray) -> bool:
|
||||
"""Validate phase data format and values.
|
||||
|
||||
Args:
|
||||
phase_data: Phase data to validate
|
||||
|
||||
Returns:
|
||||
True if valid
|
||||
|
||||
Raises:
|
||||
PhaseSanitizationError: If validation fails
|
||||
"""
|
||||
# Check if data is 2D
|
||||
if phase_data.ndim != 2:
|
||||
raise PhaseSanitizationError("Phase data must be 2D array")
|
||||
|
||||
# Check if data is not empty
|
||||
if phase_data.size == 0:
|
||||
raise PhaseSanitizationError("Phase data cannot be empty")
|
||||
|
||||
# Check if values are within valid range
|
||||
min_val, max_val = self.phase_range
|
||||
if np.any(phase_data < min_val) or np.any(phase_data > max_val):
|
||||
raise PhaseSanitizationError(f"Phase values outside valid range [{min_val}, {max_val}]")
|
||||
|
||||
return True
|
||||
|
||||
def get_sanitization_statistics(self) -> Dict[str, Any]:
|
||||
"""Get sanitization statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing sanitization statistics
|
||||
"""
|
||||
outlier_rate = self._outliers_removed / self._total_processed if self._total_processed > 0 else 0
|
||||
error_rate = self._sanitization_errors / self._total_processed if self._total_processed > 0 else 0
|
||||
|
||||
return {
|
||||
'total_processed': self._total_processed,
|
||||
'outliers_removed': self._outliers_removed,
|
||||
'sanitization_errors': self._sanitization_errors,
|
||||
'outlier_rate': outlier_rate,
|
||||
'error_rate': error_rate
|
||||
}
|
||||
|
||||
def reset_statistics(self) -> None:
|
||||
"""Reset sanitization statistics."""
|
||||
self._total_processed = 0
|
||||
self._outliers_removed = 0
|
||||
self._sanitization_errors = 0
|
||||
340
v1/src/core/router_interface.py
Normal file
340
v1/src/core/router_interface.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Router interface for WiFi CSI data collection
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RouterInterface:
|
||||
"""Interface for connecting to WiFi routers and collecting CSI data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
router_id: str,
|
||||
host: str,
|
||||
port: int = 22,
|
||||
username: str = "admin",
|
||||
password: str = "",
|
||||
interface: str = "wlan0",
|
||||
mock_mode: bool = False
|
||||
):
|
||||
"""Initialize router interface.
|
||||
|
||||
Args:
|
||||
router_id: Unique identifier for the router
|
||||
host: Router IP address or hostname
|
||||
port: SSH port for connection
|
||||
username: SSH username
|
||||
password: SSH password
|
||||
interface: WiFi interface name
|
||||
mock_mode: Whether to use mock data instead of real connection
|
||||
"""
|
||||
self.router_id = router_id
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.interface = interface
|
||||
self.mock_mode = mock_mode
|
||||
|
||||
self.logger = logging.getLogger(f"{__name__}.{router_id}")
|
||||
|
||||
# Connection state
|
||||
self.is_connected = False
|
||||
self.connection = None
|
||||
self.last_error = None
|
||||
|
||||
# Data collection state
|
||||
self.last_data_time = None
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
|
||||
# Mock data generation
|
||||
self.mock_data_generator = None
|
||||
if mock_mode:
|
||||
self._initialize_mock_generator()
|
||||
|
||||
def _initialize_mock_generator(self):
|
||||
"""Initialize mock data generator."""
|
||||
self.mock_data_generator = {
|
||||
'phase': 0,
|
||||
'amplitude_base': 1.0,
|
||||
'frequency': 0.1,
|
||||
'noise_level': 0.1
|
||||
}
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the router."""
|
||||
if self.mock_mode:
|
||||
self.is_connected = True
|
||||
self.logger.info(f"Mock connection established to router {self.router_id}")
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
|
||||
|
||||
# In a real implementation, this would establish SSH connection
|
||||
# For now, we'll simulate the connection
|
||||
await asyncio.sleep(0.1) # Simulate connection delay
|
||||
|
||||
self.is_connected = True
|
||||
self.error_count = 0
|
||||
self.logger.info(f"Connected to router {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
|
||||
raise
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the router."""
|
||||
try:
|
||||
if self.connection:
|
||||
# Close SSH connection
|
||||
self.connection = None
|
||||
|
||||
self.is_connected = False
|
||||
self.logger.info(f"Disconnected from router {self.router_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
|
||||
|
||||
async def reconnect(self):
|
||||
"""Reconnect to the router."""
|
||||
await self.disconnect()
|
||||
await asyncio.sleep(1) # Wait before reconnecting
|
||||
await self.connect()
|
||||
|
||||
async def get_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Get CSI data from the router.
|
||||
|
||||
Returns:
|
||||
CSI data as numpy array, or None if no data available
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RuntimeError(f"Router {self.router_id} is not connected")
|
||||
|
||||
try:
|
||||
if self.mock_mode:
|
||||
csi_data = self._generate_mock_csi_data()
|
||||
else:
|
||||
csi_data = await self._collect_real_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
self.last_data_time = datetime.now()
|
||||
self.sample_count += 1
|
||||
self.error_count = 0
|
||||
|
||||
return csi_data
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.error_count += 1
|
||||
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
|
||||
return None
|
||||
|
||||
def _generate_mock_csi_data(self) -> np.ndarray:
|
||||
"""Generate mock CSI data for testing."""
|
||||
# Simulate CSI data with realistic characteristics
|
||||
num_subcarriers = 64
|
||||
num_antennas = 4
|
||||
num_samples = 100
|
||||
|
||||
# Update mock generator state
|
||||
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
|
||||
|
||||
# Generate amplitude and phase data
|
||||
time_axis = np.linspace(0, 1, num_samples)
|
||||
|
||||
# Create realistic CSI patterns
|
||||
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
|
||||
|
||||
for antenna in range(num_antennas):
|
||||
for subcarrier in range(num_subcarriers):
|
||||
# Base signal with some variation per antenna/subcarrier
|
||||
amplitude = (
|
||||
self.mock_data_generator['amplitude_base'] *
|
||||
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
|
||||
(1 + 0.1 * antenna)
|
||||
)
|
||||
|
||||
# Phase with spatial and frequency variation
|
||||
phase_offset = (
|
||||
self.mock_data_generator['phase'] +
|
||||
2 * np.pi * subcarrier / num_subcarriers +
|
||||
np.pi * antenna / num_antennas
|
||||
)
|
||||
|
||||
# Add some movement simulation
|
||||
movement_freq = 0.5 # Hz
|
||||
movement_amplitude = 0.3
|
||||
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
|
||||
|
||||
# Generate complex signal
|
||||
signal_amplitude = amplitude * (1 + movement)
|
||||
signal_phase = phase_offset + movement * 0.5
|
||||
|
||||
# Add noise
|
||||
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
|
||||
noise = noise_real + 1j * noise_imag
|
||||
|
||||
# Create complex signal
|
||||
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
|
||||
csi_data[antenna, subcarrier, :] = signal
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
|
||||
"""Collect real CSI data from router (placeholder implementation)."""
|
||||
# This would implement the actual CSI data collection
|
||||
# For now, return None to indicate no real implementation
|
||||
self.logger.warning("Real CSI data collection not implemented")
|
||||
return None
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
"""Check if the router connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if healthy, False otherwise
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return False
|
||||
|
||||
try:
|
||||
# In mock mode, always healthy
|
||||
if self.mock_mode:
|
||||
return True
|
||||
|
||||
# For real connections, we could ping the router or check SSH connection
|
||||
# For now, consider healthy if error count is low
|
||||
return self.error_count < 5
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get router status information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router status
|
||||
"""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode,
|
||||
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
|
||||
"error_count": self.error_count,
|
||||
"sample_count": self.sample_count,
|
||||
"last_error": self.last_error,
|
||||
"configuration": {
|
||||
"host": self.host,
|
||||
"port": self.port,
|
||||
"username": self.username,
|
||||
"interface": self.interface
|
||||
}
|
||||
}
|
||||
|
||||
async def get_router_info(self) -> Dict[str, Any]:
|
||||
"""Get router hardware information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router information
|
||||
"""
|
||||
if self.mock_mode:
|
||||
return {
|
||||
"model": "Mock Router",
|
||||
"firmware": "1.0.0-mock",
|
||||
"wifi_standard": "802.11ac",
|
||||
"antennas": 4,
|
||||
"supported_bands": ["2.4GHz", "5GHz"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 4,
|
||||
"sampling_rate": 1000
|
||||
}
|
||||
}
|
||||
|
||||
# For real routers, this would query the actual hardware
|
||||
return {
|
||||
"model": "Unknown",
|
||||
"firmware": "Unknown",
|
||||
"wifi_standard": "Unknown",
|
||||
"antennas": 1,
|
||||
"supported_bands": ["Unknown"],
|
||||
"csi_capabilities": {
|
||||
"max_subcarriers": 64,
|
||||
"max_antennas": 1,
|
||||
"sampling_rate": 100
|
||||
}
|
||||
}
|
||||
|
||||
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
|
||||
"""Configure CSI data collection parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
True if configuration successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self.mock_mode:
|
||||
# Update mock generator parameters
|
||||
if 'sampling_rate' in config:
|
||||
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
|
||||
|
||||
if 'noise_level' in config:
|
||||
self.mock_data_generator['noise_level'] = config['noise_level']
|
||||
|
||||
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
|
||||
return True
|
||||
|
||||
# For real routers, this would send configuration commands
|
||||
self.logger.warning("Real CSI configuration not implemented")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get router interface metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing metrics
|
||||
"""
|
||||
uptime = 0
|
||||
if self.last_data_time:
|
||||
uptime = (datetime.now() - self.last_data_time).total_seconds()
|
||||
|
||||
success_rate = 0
|
||||
if self.sample_count > 0:
|
||||
success_rate = (self.sample_count - self.error_count) / self.sample_count
|
||||
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"sample_count": self.sample_count,
|
||||
"error_count": self.error_count,
|
||||
"success_rate": success_rate,
|
||||
"uptime_seconds": uptime,
|
||||
"is_connected": self.is_connected,
|
||||
"mock_mode": self.mock_mode
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset statistics counters."""
|
||||
self.error_count = 0
|
||||
self.sample_count = 0
|
||||
self.last_error = None
|
||||
self.logger.info(f"Statistics reset for router {self.router_id}")
|
||||
640
v1/src/database/connection.py
Normal file
640
v1/src/database/connection.py
Normal file
@@ -0,0 +1,640 @@
|
||||
"""
|
||||
Database connection management for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import create_engine, event, pool, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.pool import QueuePool, NullPool
|
||||
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError
|
||||
import redis.asyncio as redis
|
||||
from redis.exceptions import ConnectionError as RedisConnectionError
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConnectionError(Exception):
|
||||
"""Database connection error."""
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""Database connection manager."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self._async_engine = None
|
||||
self._sync_engine = None
|
||||
self._async_session_factory = None
|
||||
self._sync_session_factory = None
|
||||
self._redis_client = None
|
||||
self._initialized = False
|
||||
self._connection_pool_size = settings.db_pool_size
|
||||
self._max_overflow = settings.db_max_overflow
|
||||
self._pool_timeout = settings.db_pool_timeout
|
||||
self._pool_recycle = settings.db_pool_recycle
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize database connections."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing database connections")
|
||||
|
||||
try:
|
||||
# Initialize PostgreSQL connections
|
||||
await self._initialize_postgresql()
|
||||
|
||||
# Initialize Redis connection
|
||||
await self._initialize_redis()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("Database connections initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database connections: {e}")
|
||||
raise DatabaseConnectionError(f"Database initialization failed: {e}")
|
||||
|
||||
async def _initialize_postgresql(self):
|
||||
"""Initialize PostgreSQL connections with SQLite failsafe."""
|
||||
postgresql_failed = False
|
||||
|
||||
try:
|
||||
# Try PostgreSQL first
|
||||
await self._initialize_postgresql_primary()
|
||||
logger.info("PostgreSQL connections initialized")
|
||||
return
|
||||
except Exception as e:
|
||||
postgresql_failed = True
|
||||
logger.error(f"PostgreSQL initialization failed: {e}")
|
||||
|
||||
if not self.settings.enable_database_failsafe:
|
||||
raise DatabaseConnectionError(f"PostgreSQL connection failed and failsafe disabled: {e}")
|
||||
|
||||
logger.warning("Falling back to SQLite database")
|
||||
|
||||
# Fallback to SQLite if PostgreSQL failed and failsafe is enabled
|
||||
if postgresql_failed and self.settings.enable_database_failsafe:
|
||||
await self._initialize_sqlite_fallback()
|
||||
logger.info("SQLite fallback database initialized")
|
||||
|
||||
async def _initialize_postgresql_primary(self):
|
||||
"""Initialize primary PostgreSQL connections."""
|
||||
# Build database URL
|
||||
if self.settings.database_url and "postgresql" in self.settings.database_url:
|
||||
db_url = self.settings.database_url
|
||||
async_db_url = self.settings.database_url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif self.settings.db_host and self.settings.db_name and self.settings.db_user:
|
||||
db_url = (
|
||||
f"postgresql://{self.settings.db_user}:{self.settings.db_password}"
|
||||
f"@{self.settings.db_host}:{self.settings.db_port}/{self.settings.db_name}"
|
||||
)
|
||||
async_db_url = (
|
||||
f"postgresql+asyncpg://{self.settings.db_user}:{self.settings.db_password}"
|
||||
f"@{self.settings.db_host}:{self.settings.db_port}/{self.settings.db_name}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("PostgreSQL connection parameters not configured")
|
||||
|
||||
# Create async engine (don't specify poolclass for async engines)
|
||||
self._async_engine = create_async_engine(
|
||||
async_db_url,
|
||||
pool_size=self._connection_pool_size,
|
||||
max_overflow=self._max_overflow,
|
||||
pool_timeout=self._pool_timeout,
|
||||
pool_recycle=self._pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
echo=self.settings.db_echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create sync engine for migrations and admin tasks
|
||||
self._sync_engine = create_engine(
|
||||
db_url,
|
||||
poolclass=QueuePool,
|
||||
pool_size=max(2, self._connection_pool_size // 2),
|
||||
max_overflow=self._max_overflow // 2,
|
||||
pool_timeout=self._pool_timeout,
|
||||
pool_recycle=self._pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
echo=self.settings.db_echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create session factories
|
||||
self._async_session_factory = async_sessionmaker(
|
||||
self._async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
self._sync_session_factory = sessionmaker(
|
||||
self._sync_engine,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Add connection event listeners
|
||||
self._setup_connection_events()
|
||||
|
||||
# Test connections
|
||||
await self._test_postgresql_connection()
|
||||
|
||||
async def _initialize_sqlite_fallback(self):
|
||||
"""Initialize SQLite fallback database."""
|
||||
import os
|
||||
|
||||
# Ensure directory exists
|
||||
sqlite_path = self.settings.sqlite_fallback_path
|
||||
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
|
||||
|
||||
# Build SQLite URLs
|
||||
db_url = f"sqlite:///{sqlite_path}"
|
||||
async_db_url = f"sqlite+aiosqlite:///{sqlite_path}"
|
||||
|
||||
# Create async engine for SQLite
|
||||
self._async_engine = create_async_engine(
|
||||
async_db_url,
|
||||
echo=self.settings.db_echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create sync engine for SQLite
|
||||
self._sync_engine = create_engine(
|
||||
db_url,
|
||||
poolclass=NullPool, # SQLite doesn't need connection pooling
|
||||
echo=self.settings.db_echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create session factories
|
||||
self._async_session_factory = async_sessionmaker(
|
||||
self._async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
self._sync_session_factory = sessionmaker(
|
||||
self._sync_engine,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Add connection event listeners
|
||||
self._setup_connection_events()
|
||||
|
||||
# Test SQLite connection
|
||||
await self._test_sqlite_connection()
|
||||
|
||||
async def _test_sqlite_connection(self):
|
||||
"""Test SQLite connection."""
|
||||
try:
|
||||
async with self._async_engine.begin() as conn:
|
||||
result = await conn.execute(text("SELECT 1"))
|
||||
result.fetchone() # Don't await this - fetchone() is not async
|
||||
logger.debug("SQLite connection test successful")
|
||||
except Exception as e:
|
||||
logger.error(f"SQLite connection test failed: {e}")
|
||||
raise DatabaseConnectionError(f"SQLite connection test failed: {e}")
|
||||
|
||||
async def _initialize_redis(self):
|
||||
"""Initialize Redis connection with failsafe."""
|
||||
if not self.settings.redis_enabled:
|
||||
logger.info("Redis disabled, skipping initialization")
|
||||
return
|
||||
|
||||
try:
|
||||
# Build Redis URL
|
||||
if self.settings.redis_url:
|
||||
redis_url = self.settings.redis_url
|
||||
else:
|
||||
redis_url = (
|
||||
f"redis://{self.settings.redis_host}:{self.settings.redis_port}"
|
||||
f"/{self.settings.redis_db}"
|
||||
)
|
||||
|
||||
# Create Redis client
|
||||
self._redis_client = redis.from_url(
|
||||
redis_url,
|
||||
password=self.settings.redis_password,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
max_connections=self.settings.redis_max_connections,
|
||||
retry_on_timeout=True,
|
||||
socket_timeout=self.settings.redis_socket_timeout,
|
||||
socket_connect_timeout=self.settings.redis_connect_timeout,
|
||||
)
|
||||
|
||||
# Test Redis connection
|
||||
await self._test_redis_connection()
|
||||
|
||||
logger.info("Redis connection initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Redis: {e}")
|
||||
|
||||
if self.settings.redis_required:
|
||||
raise DatabaseConnectionError(f"Redis connection failed and is required: {e}")
|
||||
elif self.settings.enable_redis_failsafe:
|
||||
logger.warning("Redis initialization failed, continuing without Redis (failsafe enabled)")
|
||||
self._redis_client = None
|
||||
else:
|
||||
logger.warning("Redis initialization failed but not required, continuing without Redis")
|
||||
self._redis_client = None
|
||||
|
||||
def _setup_connection_events(self):
|
||||
"""Setup database connection event listeners."""
|
||||
|
||||
@event.listens_for(self._sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
"""Set database-specific settings on connection."""
|
||||
if "sqlite" in str(self._sync_engine.url):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
@event.listens_for(self._sync_engine, "checkout")
|
||||
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
|
||||
"""Log connection checkout."""
|
||||
logger.debug("Database connection checked out")
|
||||
|
||||
@event.listens_for(self._sync_engine, "checkin")
|
||||
def receive_checkin(dbapi_connection, connection_record):
|
||||
"""Log connection checkin."""
|
||||
logger.debug("Database connection checked in")
|
||||
|
||||
@event.listens_for(self._sync_engine, "invalidate")
|
||||
def receive_invalidate(dbapi_connection, connection_record, exception):
|
||||
"""Handle connection invalidation."""
|
||||
logger.warning(f"Database connection invalidated: {exception}")
|
||||
|
||||
async def _test_postgresql_connection(self):
|
||||
"""Test PostgreSQL connection."""
|
||||
try:
|
||||
async with self._async_engine.begin() as conn:
|
||||
result = await conn.execute(text("SELECT 1"))
|
||||
result.fetchone() # Don't await this - fetchone() is not async
|
||||
logger.debug("PostgreSQL connection test successful")
|
||||
except Exception as e:
|
||||
logger.error(f"PostgreSQL connection test failed: {e}")
|
||||
raise DatabaseConnectionError(f"PostgreSQL connection test failed: {e}")
|
||||
|
||||
async def _test_redis_connection(self):
|
||||
"""Test Redis connection."""
|
||||
if not self._redis_client:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._redis_client.ping()
|
||||
logger.debug("Redis connection test successful")
|
||||
except Exception as e:
|
||||
logger.error(f"Redis connection test failed: {e}")
|
||||
if self.settings.redis_required:
|
||||
raise DatabaseConnectionError(f"Redis connection test failed: {e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get async database session."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if not self._async_session_factory:
|
||||
raise DatabaseConnectionError("Async session factory not initialized")
|
||||
|
||||
session = self._async_session_factory()
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Database session error: {e}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_sync_session(self) -> Session:
|
||||
"""Get sync database session."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if not self._sync_session_factory:
|
||||
raise DatabaseConnectionError("Sync session factory not initialized")
|
||||
|
||||
session = self._sync_session_factory()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Database session error: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def get_redis_client(self) -> Optional[redis.Redis]:
|
||||
"""Get Redis client."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
return self._redis_client
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform database health check."""
|
||||
health_status = {
|
||||
"database": {"status": "unknown", "details": {}},
|
||||
"redis": {"status": "unknown", "details": {}},
|
||||
"overall": "unknown"
|
||||
}
|
||||
|
||||
# Check Database (PostgreSQL or SQLite)
|
||||
try:
|
||||
start_time = datetime.utcnow()
|
||||
async with self.get_async_session() as session:
|
||||
result = await session.execute(text("SELECT 1"))
|
||||
result.fetchone() # Don't await this - fetchone() is not async
|
||||
|
||||
response_time = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
# Determine database type and status
|
||||
is_sqlite = self.is_using_sqlite_fallback()
|
||||
db_type = "sqlite_fallback" if is_sqlite else "postgresql"
|
||||
|
||||
details = {
|
||||
"type": db_type,
|
||||
"response_time_ms": round(response_time * 1000, 2),
|
||||
}
|
||||
|
||||
# Add pool info for PostgreSQL
|
||||
if not is_sqlite and hasattr(self._async_engine, 'pool'):
|
||||
details.update({
|
||||
"pool_size": self._async_engine.pool.size(),
|
||||
"checked_out": self._async_engine.pool.checkedout(),
|
||||
"overflow": self._async_engine.pool.overflow(),
|
||||
})
|
||||
|
||||
# Add failsafe info
|
||||
if is_sqlite:
|
||||
details["failsafe_active"] = True
|
||||
details["fallback_path"] = self.settings.sqlite_fallback_path
|
||||
|
||||
health_status["database"] = {
|
||||
"status": "healthy",
|
||||
"details": details
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["database"] = {
|
||||
"status": "unhealthy",
|
||||
"details": {"error": str(e)}
|
||||
}
|
||||
|
||||
# Check Redis
|
||||
if self._redis_client:
|
||||
try:
|
||||
start_time = datetime.utcnow()
|
||||
await self._redis_client.ping()
|
||||
response_time = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
info = await self._redis_client.info()
|
||||
|
||||
health_status["redis"] = {
|
||||
"status": "healthy",
|
||||
"details": {
|
||||
"response_time_ms": round(response_time * 1000, 2),
|
||||
"connected_clients": info.get("connected_clients", 0),
|
||||
"used_memory": info.get("used_memory_human", "unknown"),
|
||||
"uptime": info.get("uptime_in_seconds", 0),
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["redis"] = {
|
||||
"status": "unhealthy",
|
||||
"details": {"error": str(e)}
|
||||
}
|
||||
else:
|
||||
health_status["redis"] = {
|
||||
"status": "disabled",
|
||||
"details": {"message": "Redis not enabled"}
|
||||
}
|
||||
|
||||
# Determine overall status
|
||||
database_healthy = health_status["database"]["status"] == "healthy"
|
||||
redis_healthy = (
|
||||
health_status["redis"]["status"] in ["healthy", "disabled"] or
|
||||
not self.settings.redis_required
|
||||
)
|
||||
|
||||
# Check if using failsafe modes
|
||||
using_sqlite_fallback = self.is_using_sqlite_fallback()
|
||||
redis_unavailable = not self.is_redis_available() and self.settings.redis_enabled
|
||||
|
||||
if database_healthy and redis_healthy:
|
||||
if using_sqlite_fallback or redis_unavailable:
|
||||
health_status["overall"] = "degraded" # Working but using failsafe
|
||||
else:
|
||||
health_status["overall"] = "healthy"
|
||||
elif database_healthy:
|
||||
health_status["overall"] = "degraded"
|
||||
else:
|
||||
health_status["overall"] = "unhealthy"
|
||||
|
||||
return health_status
|
||||
|
||||
async def get_connection_stats(self) -> Dict[str, Any]:
|
||||
"""Get database connection statistics."""
|
||||
stats = {
|
||||
"postgresql": {},
|
||||
"redis": {}
|
||||
}
|
||||
|
||||
# PostgreSQL stats
|
||||
if self._async_engine:
|
||||
pool = self._async_engine.pool
|
||||
stats["postgresql"] = {
|
||||
"pool_size": pool.size(),
|
||||
"checked_out": pool.checkedout(),
|
||||
"overflow": pool.overflow(),
|
||||
"checked_in": pool.checkedin(),
|
||||
"total_connections": pool.size() + pool.overflow(),
|
||||
"available_connections": pool.size() - pool.checkedout(),
|
||||
}
|
||||
|
||||
# Redis stats
|
||||
if self._redis_client:
|
||||
try:
|
||||
info = await self._redis_client.info()
|
||||
stats["redis"] = {
|
||||
"connected_clients": info.get("connected_clients", 0),
|
||||
"blocked_clients": info.get("blocked_clients", 0),
|
||||
"total_connections_received": info.get("total_connections_received", 0),
|
||||
"rejected_connections": info.get("rejected_connections", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
stats["redis"] = {"error": str(e)}
|
||||
|
||||
return stats
|
||||
|
||||
async def close_connections(self):
|
||||
"""Close all database connections."""
|
||||
logger.info("Closing database connections")
|
||||
|
||||
# Close PostgreSQL connections
|
||||
if self._async_engine:
|
||||
await self._async_engine.dispose()
|
||||
logger.debug("Async PostgreSQL engine disposed")
|
||||
|
||||
if self._sync_engine:
|
||||
self._sync_engine.dispose()
|
||||
logger.debug("Sync PostgreSQL engine disposed")
|
||||
|
||||
# Close Redis connection
|
||||
if self._redis_client:
|
||||
await self._redis_client.close()
|
||||
logger.debug("Redis connection closed")
|
||||
|
||||
self._initialized = False
|
||||
logger.info("Database connections closed")
|
||||
|
||||
def is_using_sqlite_fallback(self) -> bool:
|
||||
"""Check if currently using SQLite fallback database."""
|
||||
if not self._async_engine:
|
||||
return False
|
||||
return "sqlite" in str(self._async_engine.url)
|
||||
|
||||
def is_redis_available(self) -> bool:
|
||||
"""Check if Redis is available."""
|
||||
return self._redis_client is not None
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test database connection for CLI validation."""
|
||||
try:
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
# Test database connection (PostgreSQL or SQLite)
|
||||
async with self.get_async_session() as session:
|
||||
result = await session.execute(text("SELECT 1"))
|
||||
result.fetchone() # Don't await this - fetchone() is not async
|
||||
|
||||
# Test Redis connection if enabled
|
||||
if self._redis_client:
|
||||
await self._redis_client.ping()
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database connection test failed: {e}")
|
||||
return False
|
||||
|
||||
async def reset_connections(self):
|
||||
"""Reset all database connections."""
|
||||
logger.info("Resetting database connections")
|
||||
await self.close_connections()
|
||||
await self.initialize()
|
||||
logger.info("Database connections reset")
|
||||
|
||||
|
||||
# Global database manager instance
|
||||
_db_manager: Optional[DatabaseManager] = None
|
||||
|
||||
|
||||
def get_database_manager(settings: Settings) -> DatabaseManager:
|
||||
"""Get database manager instance."""
|
||||
global _db_manager
|
||||
if _db_manager is None:
|
||||
_db_manager = DatabaseManager(settings)
|
||||
return _db_manager
|
||||
|
||||
|
||||
async def get_async_session(settings: Settings) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Dependency to get async database session."""
|
||||
db_manager = get_database_manager(settings)
|
||||
async with db_manager.get_async_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_redis_client(settings: Settings) -> Optional[redis.Redis]:
|
||||
"""Dependency to get Redis client."""
|
||||
db_manager = get_database_manager(settings)
|
||||
return await db_manager.get_redis_client()
|
||||
|
||||
|
||||
class DatabaseHealthCheck:
|
||||
"""Database health check utility."""
|
||||
|
||||
def __init__(self, db_manager: DatabaseManager):
|
||||
self.db_manager = db_manager
|
||||
|
||||
async def check_postgresql(self) -> Dict[str, Any]:
|
||||
"""Check PostgreSQL health."""
|
||||
try:
|
||||
start_time = datetime.utcnow()
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
result = await session.execute(text("SELECT version()"))
|
||||
version = result.fetchone()[0] # Don't await this - fetchone() is not async
|
||||
|
||||
response_time = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": version,
|
||||
"response_time_ms": round(response_time * 1000, 2),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def check_redis(self) -> Dict[str, Any]:
|
||||
"""Check Redis health."""
|
||||
redis_client = await self.db_manager.get_redis_client()
|
||||
|
||||
if not redis_client:
|
||||
return {
|
||||
"status": "disabled",
|
||||
"message": "Redis not configured"
|
||||
}
|
||||
|
||||
try:
|
||||
start_time = datetime.utcnow()
|
||||
pong = await redis_client.ping()
|
||||
response_time = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
info = await redis_client.info("server")
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"ping": pong,
|
||||
"version": info.get("redis_version", "unknown"),
|
||||
"response_time_ms": round(response_time * 1000, 2),
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def full_health_check(self) -> Dict[str, Any]:
|
||||
"""Perform full database health check."""
|
||||
postgresql_health = await self.check_postgresql()
|
||||
redis_health = await self.check_redis()
|
||||
|
||||
overall_status = "healthy"
|
||||
if postgresql_health["status"] != "healthy":
|
||||
overall_status = "unhealthy"
|
||||
elif redis_health["status"] == "unhealthy":
|
||||
overall_status = "degraded"
|
||||
|
||||
return {
|
||||
"overall_status": overall_status,
|
||||
"postgresql": postgresql_health,
|
||||
"redis": redis_health,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
370
v1/src/database/migrations/001_initial.py
Normal file
370
v1/src/database/migrations/001_initial.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Initial database migration for WiFi-DensePose API
|
||||
|
||||
Revision ID: 001_initial
|
||||
Revises:
|
||||
Create Date: 2025-01-07 07:58:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers
|
||||
revision = '001_initial'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Create initial database schema."""
|
||||
|
||||
# Create devices table
|
||||
op.create_table(
|
||||
'devices',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('mac_address', sa.String(length=17), nullable=False),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('status', sa.String(length=20), nullable=False),
|
||||
sa.Column('firmware_version', sa.String(length=50), nullable=True),
|
||||
sa.Column('hardware_version', sa.String(length=50), nullable=True),
|
||||
sa.Column('location_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('room_id', sa.String(length=100), nullable=True),
|
||||
sa.Column('coordinates_x', sa.Float(), nullable=True),
|
||||
sa.Column('coordinates_y', sa.Float(), nullable=True),
|
||||
sa.Column('coordinates_z', sa.Float(), nullable=True),
|
||||
sa.Column('config', sa.JSON(), nullable=True),
|
||||
sa.Column('capabilities', postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.CheckConstraint("status IN ('active', 'inactive', 'maintenance', 'error')", name='check_device_status'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('mac_address')
|
||||
)
|
||||
|
||||
# Create indexes for devices table
|
||||
op.create_index('idx_device_mac_address', 'devices', ['mac_address'])
|
||||
op.create_index('idx_device_status', 'devices', ['status'])
|
||||
op.create_index('idx_device_type', 'devices', ['device_type'])
|
||||
|
||||
# Create sessions table
|
||||
op.create_table(
|
||||
'sessions',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('ended_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('duration_seconds', sa.Integer(), nullable=True),
|
||||
sa.Column('status', sa.String(length=20), nullable=False),
|
||||
sa.Column('config', sa.JSON(), nullable=True),
|
||||
sa.Column('device_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('total_frames', sa.Integer(), nullable=False),
|
||||
sa.Column('processed_frames', sa.Integer(), nullable=False),
|
||||
sa.Column('error_count', sa.Integer(), nullable=False),
|
||||
sa.CheckConstraint("status IN ('active', 'completed', 'failed', 'cancelled')", name='check_session_status'),
|
||||
sa.CheckConstraint('total_frames >= 0', name='check_total_frames_positive'),
|
||||
sa.CheckConstraint('processed_frames >= 0', name='check_processed_frames_positive'),
|
||||
sa.CheckConstraint('error_count >= 0', name='check_error_count_positive'),
|
||||
sa.ForeignKeyConstraint(['device_id'], ['devices.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for sessions table
|
||||
op.create_index('idx_session_device_id', 'sessions', ['device_id'])
|
||||
op.create_index('idx_session_status', 'sessions', ['status'])
|
||||
op.create_index('idx_session_started_at', 'sessions', ['started_at'])
|
||||
|
||||
# Create csi_data table
|
||||
op.create_table(
|
||||
'csi_data',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('sequence_number', sa.Integer(), nullable=False),
|
||||
sa.Column('timestamp_ns', sa.BigInteger(), nullable=False),
|
||||
sa.Column('device_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('session_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column('amplitude', postgresql.ARRAY(sa.Float()), nullable=False),
|
||||
sa.Column('phase', postgresql.ARRAY(sa.Float()), nullable=False),
|
||||
sa.Column('frequency', sa.Float(), nullable=False),
|
||||
sa.Column('bandwidth', sa.Float(), nullable=False),
|
||||
sa.Column('rssi', sa.Float(), nullable=True),
|
||||
sa.Column('snr', sa.Float(), nullable=True),
|
||||
sa.Column('noise_floor', sa.Float(), nullable=True),
|
||||
sa.Column('tx_antenna', sa.Integer(), nullable=True),
|
||||
sa.Column('rx_antenna', sa.Integer(), nullable=True),
|
||||
sa.Column('num_subcarriers', sa.Integer(), nullable=False),
|
||||
sa.Column('processing_status', sa.String(length=20), nullable=False),
|
||||
sa.Column('processed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('quality_score', sa.Float(), nullable=True),
|
||||
sa.Column('is_valid', sa.Boolean(), nullable=False),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.CheckConstraint('frequency > 0', name='check_frequency_positive'),
|
||||
sa.CheckConstraint('bandwidth > 0', name='check_bandwidth_positive'),
|
||||
sa.CheckConstraint('num_subcarriers > 0', name='check_subcarriers_positive'),
|
||||
sa.CheckConstraint("processing_status IN ('pending', 'processing', 'completed', 'failed')", name='check_processing_status'),
|
||||
sa.ForeignKeyConstraint(['device_id'], ['devices.id'], ),
|
||||
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('device_id', 'sequence_number', 'timestamp_ns', name='uq_csi_device_seq_time')
|
||||
)
|
||||
|
||||
# Create indexes for csi_data table
|
||||
op.create_index('idx_csi_device_id', 'csi_data', ['device_id'])
|
||||
op.create_index('idx_csi_session_id', 'csi_data', ['session_id'])
|
||||
op.create_index('idx_csi_timestamp', 'csi_data', ['timestamp_ns'])
|
||||
op.create_index('idx_csi_sequence', 'csi_data', ['sequence_number'])
|
||||
op.create_index('idx_csi_processing_status', 'csi_data', ['processing_status'])
|
||||
|
||||
# Create pose_detections table
|
||||
op.create_table(
|
||||
'pose_detections',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('frame_number', sa.Integer(), nullable=False),
|
||||
sa.Column('timestamp_ns', sa.BigInteger(), nullable=False),
|
||||
sa.Column('session_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('person_count', sa.Integer(), nullable=False),
|
||||
sa.Column('keypoints', sa.JSON(), nullable=True),
|
||||
sa.Column('bounding_boxes', sa.JSON(), nullable=True),
|
||||
sa.Column('detection_confidence', sa.Float(), nullable=True),
|
||||
sa.Column('pose_confidence', sa.Float(), nullable=True),
|
||||
sa.Column('overall_confidence', sa.Float(), nullable=True),
|
||||
sa.Column('processing_time_ms', sa.Float(), nullable=True),
|
||||
sa.Column('model_version', sa.String(length=50), nullable=True),
|
||||
sa.Column('algorithm', sa.String(length=100), nullable=True),
|
||||
sa.Column('image_quality', sa.Float(), nullable=True),
|
||||
sa.Column('pose_quality', sa.Float(), nullable=True),
|
||||
sa.Column('is_valid', sa.Boolean(), nullable=False),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.CheckConstraint('person_count >= 0', name='check_person_count_positive'),
|
||||
sa.CheckConstraint('detection_confidence >= 0 AND detection_confidence <= 1', name='check_detection_confidence_range'),
|
||||
sa.CheckConstraint('pose_confidence >= 0 AND pose_confidence <= 1', name='check_pose_confidence_range'),
|
||||
sa.CheckConstraint('overall_confidence >= 0 AND overall_confidence <= 1', name='check_overall_confidence_range'),
|
||||
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for pose_detections table
|
||||
op.create_index('idx_pose_session_id', 'pose_detections', ['session_id'])
|
||||
op.create_index('idx_pose_timestamp', 'pose_detections', ['timestamp_ns'])
|
||||
op.create_index('idx_pose_frame', 'pose_detections', ['frame_number'])
|
||||
op.create_index('idx_pose_person_count', 'pose_detections', ['person_count'])
|
||||
|
||||
# Create system_metrics table
|
||||
op.create_table(
|
||||
'system_metrics',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('metric_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('metric_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('value', sa.Float(), nullable=False),
|
||||
sa.Column('unit', sa.String(length=50), nullable=True),
|
||||
sa.Column('labels', sa.JSON(), nullable=True),
|
||||
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column('source', sa.String(length=255), nullable=True),
|
||||
sa.Column('component', sa.String(length=100), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for system_metrics table
|
||||
op.create_index('idx_metric_name', 'system_metrics', ['metric_name'])
|
||||
op.create_index('idx_metric_type', 'system_metrics', ['metric_type'])
|
||||
op.create_index('idx_metric_created_at', 'system_metrics', ['created_at'])
|
||||
op.create_index('idx_metric_source', 'system_metrics', ['source'])
|
||||
op.create_index('idx_metric_component', 'system_metrics', ['component'])
|
||||
|
||||
# Create audit_logs table
|
||||
op.create_table(
|
||||
'audit_logs',
|
||||
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('event_type', sa.String(length=100), nullable=False),
|
||||
sa.Column('event_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('user_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('session_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.Text(), nullable=True),
|
||||
sa.Column('resource_type', sa.String(length=100), nullable=True),
|
||||
sa.Column('resource_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('before_state', sa.JSON(), nullable=True),
|
||||
sa.Column('after_state', sa.JSON(), nullable=True),
|
||||
sa.Column('changes', sa.JSON(), nullable=True),
|
||||
sa.Column('success', sa.Boolean(), nullable=False),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for audit_logs table
|
||||
op.create_index('idx_audit_event_type', 'audit_logs', ['event_type'])
|
||||
op.create_index('idx_audit_user_id', 'audit_logs', ['user_id'])
|
||||
op.create_index('idx_audit_resource', 'audit_logs', ['resource_type', 'resource_id'])
|
||||
op.create_index('idx_audit_created_at', 'audit_logs', ['created_at'])
|
||||
op.create_index('idx_audit_success', 'audit_logs', ['success'])
|
||||
|
||||
# Create triggers for updated_at columns
|
||||
op.execute("""
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ language 'plpgsql';
|
||||
""")
|
||||
|
||||
# Add triggers to all tables with updated_at column
|
||||
tables_with_updated_at = [
|
||||
'devices', 'sessions', 'csi_data', 'pose_detections',
|
||||
'system_metrics', 'audit_logs'
|
||||
]
|
||||
|
||||
for table in tables_with_updated_at:
|
||||
op.execute(f"""
|
||||
CREATE TRIGGER update_{table}_updated_at
|
||||
BEFORE UPDATE ON {table}
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
""")
|
||||
|
||||
# Insert initial data
|
||||
_insert_initial_data()
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Drop all tables and functions."""
|
||||
|
||||
# Drop triggers first
|
||||
tables_with_updated_at = [
|
||||
'devices', 'sessions', 'csi_data', 'pose_detections',
|
||||
'system_metrics', 'audit_logs'
|
||||
]
|
||||
|
||||
for table in tables_with_updated_at:
|
||||
op.execute(f"DROP TRIGGER IF EXISTS update_{table}_updated_at ON {table};")
|
||||
|
||||
# Drop function
|
||||
op.execute("DROP FUNCTION IF EXISTS update_updated_at_column();")
|
||||
|
||||
# Drop tables in reverse order (respecting foreign key constraints)
|
||||
op.drop_table('audit_logs')
|
||||
op.drop_table('system_metrics')
|
||||
op.drop_table('pose_detections')
|
||||
op.drop_table('csi_data')
|
||||
op.drop_table('sessions')
|
||||
op.drop_table('devices')
|
||||
|
||||
|
||||
def _insert_initial_data():
|
||||
"""Insert initial data into tables."""
|
||||
|
||||
# Insert sample device
|
||||
op.execute("""
|
||||
INSERT INTO devices (
|
||||
id, name, device_type, mac_address, ip_address, status,
|
||||
firmware_version, hardware_version, location_name, room_id,
|
||||
coordinates_x, coordinates_y, coordinates_z,
|
||||
config, capabilities, description, tags
|
||||
) VALUES (
|
||||
gen_random_uuid(),
|
||||
'Demo Router',
|
||||
'router',
|
||||
'00:11:22:33:44:55',
|
||||
'192.168.1.1',
|
||||
'active',
|
||||
'1.0.0',
|
||||
'v1.0',
|
||||
'Living Room',
|
||||
'room_001',
|
||||
0.0,
|
||||
0.0,
|
||||
2.5,
|
||||
'{"channel": 6, "power": 20, "bandwidth": 80}',
|
||||
ARRAY['wifi6', 'csi', 'beamforming'],
|
||||
'Demo WiFi router for testing',
|
||||
ARRAY['demo', 'testing']
|
||||
);
|
||||
""")
|
||||
|
||||
# Insert sample session
|
||||
op.execute("""
|
||||
INSERT INTO sessions (
|
||||
id, name, description, started_at, status, config,
|
||||
device_id, tags, metadata, total_frames, processed_frames, error_count
|
||||
) VALUES (
|
||||
gen_random_uuid(),
|
||||
'Demo Session',
|
||||
'Initial demo session for testing',
|
||||
now(),
|
||||
'active',
|
||||
'{"duration": 3600, "sampling_rate": 100}',
|
||||
(SELECT id FROM devices WHERE name = 'Demo Router' LIMIT 1),
|
||||
ARRAY['demo', 'initial'],
|
||||
'{"purpose": "testing", "environment": "lab"}',
|
||||
0,
|
||||
0,
|
||||
0
|
||||
);
|
||||
""")
|
||||
|
||||
# Insert initial system metrics
|
||||
metrics_data = [
|
||||
('system_startup', 'counter', 1.0, 'count', 'system', 'application'),
|
||||
('database_connections', 'gauge', 0.0, 'count', 'database', 'postgresql'),
|
||||
('api_requests_total', 'counter', 0.0, 'count', 'api', 'http'),
|
||||
('memory_usage', 'gauge', 0.0, 'bytes', 'system', 'memory'),
|
||||
('cpu_usage', 'gauge', 0.0, 'percent', 'system', 'cpu'),
|
||||
]
|
||||
|
||||
for metric_name, metric_type, value, unit, source, component in metrics_data:
|
||||
op.execute(f"""
|
||||
INSERT INTO system_metrics (
|
||||
id, metric_name, metric_type, value, unit, source, component,
|
||||
description, metadata
|
||||
) VALUES (
|
||||
gen_random_uuid(),
|
||||
'{metric_name}',
|
||||
'{metric_type}',
|
||||
{value},
|
||||
'{unit}',
|
||||
'{source}',
|
||||
'{component}',
|
||||
'Initial {metric_name} metric',
|
||||
'{{"initial": true, "version": "1.0.0"}}'
|
||||
);
|
||||
""")
|
||||
|
||||
# Insert initial audit log
|
||||
op.execute("""
|
||||
INSERT INTO audit_logs (
|
||||
id, event_type, event_name, description, user_id, success,
|
||||
resource_type, metadata
|
||||
) VALUES (
|
||||
gen_random_uuid(),
|
||||
'system',
|
||||
'database_migration',
|
||||
'Initial database schema created',
|
||||
'system',
|
||||
true,
|
||||
'database',
|
||||
'{"migration": "001_initial", "version": "1.0.0"}'
|
||||
);
|
||||
""")
|
||||
109
v1/src/database/migrations/env.py
Normal file
109
v1/src/database/migrations/env.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Alembic environment configuration for WiFi-DensePose API."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Add the project root to the Python path
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Import the models and settings
|
||||
from src.database.models import Base
|
||||
from src.config.settings import get_settings
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def get_database_url():
|
||||
"""Get the database URL from settings."""
|
||||
try:
|
||||
settings = get_settings()
|
||||
return settings.get_database_url()
|
||||
except Exception:
|
||||
# Fallback to SQLite if settings can't be loaded
|
||||
return "sqlite:///./data/wifi_densepose_fallback.db"
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = get_database_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""Run migrations with a database connection."""
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in async mode."""
|
||||
configuration = config.get_section(config.config_ini_section)
|
||||
configuration["sqlalchemy.url"] = get_database_url()
|
||||
|
||||
connectable = async_engine_from_config(
|
||||
configuration,
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
v1/src/database/migrations/script.py.mako
Normal file
26
v1/src/database/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade database schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade database schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
60
v1/src/database/model_types.py
Normal file
60
v1/src/database/model_types.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Database type compatibility helpers for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from typing import Type, Any
|
||||
from sqlalchemy import String, Text, JSON
|
||||
from sqlalchemy.dialects.postgresql import ARRAY as PostgreSQL_ARRAY
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
|
||||
class ArrayType(sqltypes.TypeDecorator):
|
||||
"""Array type that works with both PostgreSQL and SQLite."""
|
||||
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def __init__(self, item_type: Type = String):
|
||||
super().__init__()
|
||||
self.item_type = item_type
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
"""Load dialect-specific implementation."""
|
||||
if dialect.name == 'postgresql':
|
||||
return dialect.type_descriptor(PostgreSQL_ARRAY(self.item_type))
|
||||
else:
|
||||
# For SQLite and others, use JSON
|
||||
return dialect.type_descriptor(JSON)
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""Process value before saving to database."""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
if dialect.name == 'postgresql':
|
||||
return value
|
||||
else:
|
||||
# For SQLite, convert to JSON
|
||||
return value if isinstance(value, (list, type(None))) else list(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
"""Process value after loading from database."""
|
||||
if value is None:
|
||||
return value
|
||||
|
||||
if dialect.name == 'postgresql':
|
||||
return value
|
||||
else:
|
||||
# For SQLite, value is already a list from JSON
|
||||
return value if isinstance(value, list) else []
|
||||
|
||||
|
||||
def get_array_type(item_type: Type = String) -> Type:
|
||||
"""Get appropriate array type based on database."""
|
||||
return ArrayType(item_type)
|
||||
|
||||
|
||||
# Convenience types
|
||||
StringArray = ArrayType(String)
|
||||
FloatArray = ArrayType(sqltypes.Float)
|
||||
498
v1/src/database/models.py
Normal file
498
v1/src/database/models.py
Normal file
@@ -0,0 +1,498 @@
|
||||
"""
|
||||
SQLAlchemy models for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Integer, Float, Boolean, DateTime, Text, JSON,
|
||||
ForeignKey, Index, UniqueConstraint, CheckConstraint
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship, validates
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# Import custom array type for compatibility
|
||||
from src.database.model_types import StringArray, FloatArray
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin for timestamp fields."""
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin for UUID primary key."""
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False)
|
||||
|
||||
|
||||
class DeviceStatus(str, Enum):
|
||||
"""Device status enumeration."""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
MAINTENANCE = "maintenance"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SessionStatus(str, Enum):
|
||||
"""Session status enumeration."""
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ProcessingStatus(str, Enum):
|
||||
"""Processing status enumeration."""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Device(Base, UUIDMixin, TimestampMixin):
|
||||
"""Device model for WiFi routers and sensors."""
|
||||
__tablename__ = "devices"
|
||||
|
||||
# Basic device information
|
||||
name = Column(String(255), nullable=False)
|
||||
device_type = Column(String(50), nullable=False) # router, sensor, etc.
|
||||
mac_address = Column(String(17), unique=True, nullable=False)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 or IPv6
|
||||
|
||||
# Device status and configuration
|
||||
status = Column(String(20), default=DeviceStatus.INACTIVE, nullable=False)
|
||||
firmware_version = Column(String(50), nullable=True)
|
||||
hardware_version = Column(String(50), nullable=True)
|
||||
|
||||
# Location information
|
||||
location_name = Column(String(255), nullable=True)
|
||||
room_id = Column(String(100), nullable=True)
|
||||
coordinates_x = Column(Float, nullable=True)
|
||||
coordinates_y = Column(Float, nullable=True)
|
||||
coordinates_z = Column(Float, nullable=True)
|
||||
|
||||
# Configuration
|
||||
config = Column(JSON, nullable=True)
|
||||
capabilities = Column(StringArray, nullable=True)
|
||||
|
||||
# Metadata
|
||||
description = Column(Text, nullable=True)
|
||||
tags = Column(StringArray, nullable=True)
|
||||
|
||||
# Relationships
|
||||
sessions = relationship("Session", back_populates="device", cascade="all, delete-orphan")
|
||||
csi_data = relationship("CSIData", back_populates="device", cascade="all, delete-orphan")
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
Index("idx_device_mac_address", "mac_address"),
|
||||
Index("idx_device_status", "status"),
|
||||
Index("idx_device_type", "device_type"),
|
||||
CheckConstraint("status IN ('active', 'inactive', 'maintenance', 'error')", name="check_device_status"),
|
||||
)
|
||||
|
||||
@validates('mac_address')
|
||||
def validate_mac_address(self, key, address):
|
||||
"""Validate MAC address format."""
|
||||
if address and len(address) == 17:
|
||||
# Basic MAC address format validation
|
||||
parts = address.split(':')
|
||||
if len(parts) == 6 and all(len(part) == 2 for part in parts):
|
||||
return address.lower()
|
||||
raise ValueError("Invalid MAC address format")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"name": self.name,
|
||||
"device_type": self.device_type,
|
||||
"mac_address": self.mac_address,
|
||||
"ip_address": self.ip_address,
|
||||
"status": self.status,
|
||||
"firmware_version": self.firmware_version,
|
||||
"hardware_version": self.hardware_version,
|
||||
"location_name": self.location_name,
|
||||
"room_id": self.room_id,
|
||||
"coordinates": {
|
||||
"x": self.coordinates_x,
|
||||
"y": self.coordinates_y,
|
||||
"z": self.coordinates_z,
|
||||
} if any([self.coordinates_x, self.coordinates_y, self.coordinates_z]) else None,
|
||||
"config": self.config,
|
||||
"capabilities": self.capabilities,
|
||||
"description": self.description,
|
||||
"tags": self.tags,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class Session(Base, UUIDMixin, TimestampMixin):
|
||||
"""Session model for tracking data collection sessions."""
|
||||
__tablename__ = "sessions"
|
||||
|
||||
# Session identification
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Session timing
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
ended_at = Column(DateTime(timezone=True), nullable=True)
|
||||
duration_seconds = Column(Integer, nullable=True)
|
||||
|
||||
# Session status and configuration
|
||||
status = Column(String(20), default=SessionStatus.ACTIVE, nullable=False)
|
||||
config = Column(JSON, nullable=True)
|
||||
|
||||
# Device relationship
|
||||
device_id = Column(UUID(as_uuid=True), ForeignKey("devices.id"), nullable=False)
|
||||
device = relationship("Device", back_populates="sessions")
|
||||
|
||||
# Data relationships
|
||||
csi_data = relationship("CSIData", back_populates="session", cascade="all, delete-orphan")
|
||||
pose_detections = relationship("PoseDetection", back_populates="session", cascade="all, delete-orphan")
|
||||
|
||||
# Metadata
|
||||
tags = Column(StringArray, nullable=True)
|
||||
meta_data = Column(JSON, nullable=True)
|
||||
|
||||
# Statistics
|
||||
total_frames = Column(Integer, default=0, nullable=False)
|
||||
processed_frames = Column(Integer, default=0, nullable=False)
|
||||
error_count = Column(Integer, default=0, nullable=False)
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
Index("idx_session_device_id", "device_id"),
|
||||
Index("idx_session_status", "status"),
|
||||
Index("idx_session_started_at", "started_at"),
|
||||
CheckConstraint("status IN ('active', 'completed', 'failed', 'cancelled')", name="check_session_status"),
|
||||
CheckConstraint("total_frames >= 0", name="check_total_frames_positive"),
|
||||
CheckConstraint("processed_frames >= 0", name="check_processed_frames_positive"),
|
||||
CheckConstraint("error_count >= 0", name="check_error_count_positive"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"ended_at": self.ended_at.isoformat() if self.ended_at else None,
|
||||
"duration_seconds": self.duration_seconds,
|
||||
"status": self.status,
|
||||
"config": self.config,
|
||||
"device_id": str(self.device_id),
|
||||
"tags": self.tags,
|
||||
"metadata": self.meta_data,
|
||||
"total_frames": self.total_frames,
|
||||
"processed_frames": self.processed_frames,
|
||||
"error_count": self.error_count,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class CSIData(Base, UUIDMixin, TimestampMixin):
|
||||
"""CSI (Channel State Information) data model."""
|
||||
__tablename__ = "csi_data"
|
||||
|
||||
# Data identification
|
||||
sequence_number = Column(Integer, nullable=False)
|
||||
timestamp_ns = Column(Integer, nullable=False) # Nanosecond timestamp
|
||||
|
||||
# Device and session relationships
|
||||
device_id = Column(UUID(as_uuid=True), ForeignKey("devices.id"), nullable=False)
|
||||
session_id = Column(UUID(as_uuid=True), ForeignKey("sessions.id"), nullable=True)
|
||||
|
||||
device = relationship("Device", back_populates="csi_data")
|
||||
session = relationship("Session", back_populates="csi_data")
|
||||
|
||||
# CSI data
|
||||
amplitude = Column(FloatArray, nullable=False)
|
||||
phase = Column(FloatArray, nullable=False)
|
||||
frequency = Column(Float, nullable=False) # MHz
|
||||
bandwidth = Column(Float, nullable=False) # MHz
|
||||
|
||||
# Signal characteristics
|
||||
rssi = Column(Float, nullable=True) # dBm
|
||||
snr = Column(Float, nullable=True) # dB
|
||||
noise_floor = Column(Float, nullable=True) # dBm
|
||||
|
||||
# Antenna information
|
||||
tx_antenna = Column(Integer, nullable=True)
|
||||
rx_antenna = Column(Integer, nullable=True)
|
||||
num_subcarriers = Column(Integer, nullable=False)
|
||||
|
||||
# Processing status
|
||||
processing_status = Column(String(20), default=ProcessingStatus.PENDING, nullable=False)
|
||||
processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Quality metrics
|
||||
quality_score = Column(Float, nullable=True)
|
||||
is_valid = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Metadata
|
||||
meta_data = Column(JSON, nullable=True)
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
Index("idx_csi_device_id", "device_id"),
|
||||
Index("idx_csi_session_id", "session_id"),
|
||||
Index("idx_csi_timestamp", "timestamp_ns"),
|
||||
Index("idx_csi_sequence", "sequence_number"),
|
||||
Index("idx_csi_processing_status", "processing_status"),
|
||||
UniqueConstraint("device_id", "sequence_number", "timestamp_ns", name="uq_csi_device_seq_time"),
|
||||
CheckConstraint("frequency > 0", name="check_frequency_positive"),
|
||||
CheckConstraint("bandwidth > 0", name="check_bandwidth_positive"),
|
||||
CheckConstraint("num_subcarriers > 0", name="check_subcarriers_positive"),
|
||||
CheckConstraint("processing_status IN ('pending', 'processing', 'completed', 'failed')", name="check_processing_status"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"sequence_number": self.sequence_number,
|
||||
"timestamp_ns": self.timestamp_ns,
|
||||
"device_id": str(self.device_id),
|
||||
"session_id": str(self.session_id) if self.session_id else None,
|
||||
"amplitude": self.amplitude,
|
||||
"phase": self.phase,
|
||||
"frequency": self.frequency,
|
||||
"bandwidth": self.bandwidth,
|
||||
"rssi": self.rssi,
|
||||
"snr": self.snr,
|
||||
"noise_floor": self.noise_floor,
|
||||
"tx_antenna": self.tx_antenna,
|
||||
"rx_antenna": self.rx_antenna,
|
||||
"num_subcarriers": self.num_subcarriers,
|
||||
"processing_status": self.processing_status,
|
||||
"processed_at": self.processed_at.isoformat() if self.processed_at else None,
|
||||
"quality_score": self.quality_score,
|
||||
"is_valid": self.is_valid,
|
||||
"metadata": self.meta_data,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class PoseDetection(Base, UUIDMixin, TimestampMixin):
|
||||
"""Pose detection results model."""
|
||||
__tablename__ = "pose_detections"
|
||||
|
||||
# Detection identification
|
||||
frame_number = Column(Integer, nullable=False)
|
||||
timestamp_ns = Column(Integer, nullable=False)
|
||||
|
||||
# Session relationship
|
||||
session_id = Column(UUID(as_uuid=True), ForeignKey("sessions.id"), nullable=False)
|
||||
session = relationship("Session", back_populates="pose_detections")
|
||||
|
||||
# Detection results
|
||||
person_count = Column(Integer, default=0, nullable=False)
|
||||
keypoints = Column(JSON, nullable=True) # Array of person keypoints
|
||||
bounding_boxes = Column(JSON, nullable=True) # Array of bounding boxes
|
||||
|
||||
# Confidence scores
|
||||
detection_confidence = Column(Float, nullable=True)
|
||||
pose_confidence = Column(Float, nullable=True)
|
||||
overall_confidence = Column(Float, nullable=True)
|
||||
|
||||
# Processing information
|
||||
processing_time_ms = Column(Float, nullable=True)
|
||||
model_version = Column(String(50), nullable=True)
|
||||
algorithm = Column(String(100), nullable=True)
|
||||
|
||||
# Quality metrics
|
||||
image_quality = Column(Float, nullable=True)
|
||||
pose_quality = Column(Float, nullable=True)
|
||||
is_valid = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Metadata
|
||||
meta_data = Column(JSON, nullable=True)
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
Index("idx_pose_session_id", "session_id"),
|
||||
Index("idx_pose_timestamp", "timestamp_ns"),
|
||||
Index("idx_pose_frame", "frame_number"),
|
||||
Index("idx_pose_person_count", "person_count"),
|
||||
CheckConstraint("person_count >= 0", name="check_person_count_positive"),
|
||||
CheckConstraint("detection_confidence >= 0 AND detection_confidence <= 1", name="check_detection_confidence_range"),
|
||||
CheckConstraint("pose_confidence >= 0 AND pose_confidence <= 1", name="check_pose_confidence_range"),
|
||||
CheckConstraint("overall_confidence >= 0 AND overall_confidence <= 1", name="check_overall_confidence_range"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"frame_number": self.frame_number,
|
||||
"timestamp_ns": self.timestamp_ns,
|
||||
"session_id": str(self.session_id),
|
||||
"person_count": self.person_count,
|
||||
"keypoints": self.keypoints,
|
||||
"bounding_boxes": self.bounding_boxes,
|
||||
"detection_confidence": self.detection_confidence,
|
||||
"pose_confidence": self.pose_confidence,
|
||||
"overall_confidence": self.overall_confidence,
|
||||
"processing_time_ms": self.processing_time_ms,
|
||||
"model_version": self.model_version,
|
||||
"algorithm": self.algorithm,
|
||||
"image_quality": self.image_quality,
|
||||
"pose_quality": self.pose_quality,
|
||||
"is_valid": self.is_valid,
|
||||
"metadata": self.meta_data,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class SystemMetric(Base, UUIDMixin, TimestampMixin):
|
||||
"""System metrics model for monitoring."""
|
||||
__tablename__ = "system_metrics"
|
||||
|
||||
# Metric identification
|
||||
metric_name = Column(String(255), nullable=False)
|
||||
metric_type = Column(String(50), nullable=False) # counter, gauge, histogram
|
||||
|
||||
# Metric value
|
||||
value = Column(Float, nullable=False)
|
||||
unit = Column(String(50), nullable=True)
|
||||
|
||||
# Labels and tags
|
||||
labels = Column(JSON, nullable=True)
|
||||
tags = Column(StringArray, nullable=True)
|
||||
|
||||
# Source information
|
||||
source = Column(String(255), nullable=True)
|
||||
component = Column(String(100), nullable=True)
|
||||
|
||||
# Metadata
|
||||
description = Column(Text, nullable=True)
|
||||
meta_data = Column(JSON, nullable=True)
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
Index("idx_metric_name", "metric_name"),
|
||||
Index("idx_metric_type", "metric_type"),
|
||||
Index("idx_metric_created_at", "created_at"),
|
||||
Index("idx_metric_source", "source"),
|
||||
Index("idx_metric_component", "component"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"metric_name": self.metric_name,
|
||||
"metric_type": self.metric_type,
|
||||
"value": self.value,
|
||||
"unit": self.unit,
|
||||
"labels": self.labels,
|
||||
"tags": self.tags,
|
||||
"source": self.source,
|
||||
"component": self.component,
|
||||
"description": self.description,
|
||||
"metadata": self.meta_data,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class AuditLog(Base, UUIDMixin, TimestampMixin):
|
||||
"""Audit log model for tracking system events."""
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
# Event information
|
||||
event_type = Column(String(100), nullable=False)
|
||||
event_name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# User and session information
|
||||
user_id = Column(String(255), nullable=True)
|
||||
session_id = Column(String(255), nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
user_agent = Column(Text, nullable=True)
|
||||
|
||||
# Resource information
|
||||
resource_type = Column(String(100), nullable=True)
|
||||
resource_id = Column(String(255), nullable=True)
|
||||
|
||||
# Event details
|
||||
before_state = Column(JSON, nullable=True)
|
||||
after_state = Column(JSON, nullable=True)
|
||||
changes = Column(JSON, nullable=True)
|
||||
|
||||
# Result information
|
||||
success = Column(Boolean, nullable=False)
|
||||
error_message = Column(Text, nullable=True)
|
||||
|
||||
# Metadata
|
||||
meta_data = Column(JSON, nullable=True)
|
||||
tags = Column(StringArray, nullable=True)
|
||||
|
||||
# Constraints and indexes
|
||||
__table_args__ = (
|
||||
Index("idx_audit_event_type", "event_type"),
|
||||
Index("idx_audit_user_id", "user_id"),
|
||||
Index("idx_audit_resource", "resource_type", "resource_id"),
|
||||
Index("idx_audit_created_at", "created_at"),
|
||||
Index("idx_audit_success", "success"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"event_type": self.event_type,
|
||||
"event_name": self.event_name,
|
||||
"description": self.description,
|
||||
"user_id": self.user_id,
|
||||
"session_id": self.session_id,
|
||||
"ip_address": self.ip_address,
|
||||
"user_agent": self.user_agent,
|
||||
"resource_type": self.resource_type,
|
||||
"resource_id": self.resource_id,
|
||||
"before_state": self.before_state,
|
||||
"after_state": self.after_state,
|
||||
"changes": self.changes,
|
||||
"success": self.success,
|
||||
"error_message": self.error_message,
|
||||
"metadata": self.meta_data,
|
||||
"tags": self.tags,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
# Model registry for easy access
|
||||
MODEL_REGISTRY = {
|
||||
"Device": Device,
|
||||
"Session": Session,
|
||||
"CSIData": CSIData,
|
||||
"PoseDetection": PoseDetection,
|
||||
"SystemMetric": SystemMetric,
|
||||
"AuditLog": AuditLog,
|
||||
}
|
||||
|
||||
|
||||
def get_model_by_name(name: str):
|
||||
"""Get model class by name."""
|
||||
return MODEL_REGISTRY.get(name)
|
||||
|
||||
|
||||
def get_all_models() -> List:
|
||||
"""Get all model classes."""
|
||||
return list(MODEL_REGISTRY.values())
|
||||
1
v1/src/hardware/__init__.py
Normal file
1
v1/src/hardware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Hardware abstraction layer for WiFi-DensePose system."""
|
||||
326
v1/src/hardware/csi_extractor.py
Normal file
326
v1/src/hardware/csi_extractor.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""CSI data extraction from WiFi hardware using Test-Driven Development approach."""
|
||||
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, Callable, Protocol
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
|
||||
|
||||
class CSIParseError(Exception):
|
||||
"""Exception raised for CSI parsing errors."""
|
||||
pass
|
||||
|
||||
|
||||
class CSIValidationError(Exception):
|
||||
"""Exception raised for CSI validation errors."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CSIData:
|
||||
"""Data structure for CSI measurements."""
|
||||
timestamp: datetime
|
||||
amplitude: np.ndarray
|
||||
phase: np.ndarray
|
||||
frequency: float
|
||||
bandwidth: float
|
||||
num_subcarriers: int
|
||||
num_antennas: int
|
||||
snr: float
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class CSIParser(Protocol):
|
||||
"""Protocol for CSI data parsers."""
|
||||
|
||||
def parse(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse raw CSI data into structured format."""
|
||||
...
|
||||
|
||||
|
||||
class ESP32CSIParser:
|
||||
"""Parser for ESP32 CSI data format."""
|
||||
|
||||
def parse(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse ESP32 CSI data format.
|
||||
|
||||
Args:
|
||||
raw_data: Raw bytes from ESP32
|
||||
|
||||
Returns:
|
||||
Parsed CSI data
|
||||
|
||||
Raises:
|
||||
CSIParseError: If data format is invalid
|
||||
"""
|
||||
if not raw_data:
|
||||
raise CSIParseError("Empty data received")
|
||||
|
||||
try:
|
||||
data_str = raw_data.decode('utf-8')
|
||||
if not data_str.startswith('CSI_DATA:'):
|
||||
raise CSIParseError("Invalid ESP32 CSI data format")
|
||||
|
||||
# Parse ESP32 format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp],[phase]
|
||||
parts = data_str[9:].split(',') # Remove 'CSI_DATA:' prefix
|
||||
|
||||
timestamp_ms = int(parts[0])
|
||||
num_antennas = int(parts[1])
|
||||
num_subcarriers = int(parts[2])
|
||||
frequency_mhz = float(parts[3])
|
||||
bandwidth_mhz = float(parts[4])
|
||||
snr = float(parts[5])
|
||||
|
||||
# Convert to proper units
|
||||
frequency = frequency_mhz * 1e6 # MHz to Hz
|
||||
bandwidth = bandwidth_mhz * 1e6 # MHz to Hz
|
||||
|
||||
# Parse amplitude and phase arrays (simplified for now)
|
||||
# In real implementation, this would parse actual CSI matrix data
|
||||
amplitude = np.random.rand(num_antennas, num_subcarriers)
|
||||
phase = np.random.rand(num_antennas, num_subcarriers)
|
||||
|
||||
return CSIData(
|
||||
timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=frequency,
|
||||
bandwidth=bandwidth,
|
||||
num_subcarriers=num_subcarriers,
|
||||
num_antennas=num_antennas,
|
||||
snr=snr,
|
||||
metadata={'source': 'esp32', 'raw_length': len(raw_data)}
|
||||
)
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
raise CSIParseError(f"Failed to parse ESP32 data: {e}")
|
||||
|
||||
|
||||
class RouterCSIParser:
|
||||
"""Parser for router CSI data format."""
|
||||
|
||||
def parse(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse router CSI data format.
|
||||
|
||||
Args:
|
||||
raw_data: Raw bytes from router
|
||||
|
||||
Returns:
|
||||
Parsed CSI data
|
||||
|
||||
Raises:
|
||||
CSIParseError: If data format is invalid
|
||||
"""
|
||||
if not raw_data:
|
||||
raise CSIParseError("Empty data received")
|
||||
|
||||
# Handle different router formats
|
||||
data_str = raw_data.decode('utf-8')
|
||||
|
||||
if data_str.startswith('ATHEROS_CSI:'):
|
||||
return self._parse_atheros_format(raw_data)
|
||||
else:
|
||||
raise CSIParseError("Unknown router CSI format")
|
||||
|
||||
def _parse_atheros_format(self, raw_data: bytes) -> CSIData:
|
||||
"""Parse Atheros CSI format (placeholder implementation)."""
|
||||
# This would implement actual Atheros CSI parsing
|
||||
# For now, return mock data for testing
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=12.0,
|
||||
metadata={'source': 'atheros_router'}
|
||||
)
|
||||
|
||||
|
||||
class CSIExtractor:
|
||||
"""Main CSI data extractor supporting multiple hardware types."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
|
||||
"""Initialize CSI extractor.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
logger: Optional logger instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
self._validate_config(config)
|
||||
|
||||
self.config = config
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.hardware_type = config['hardware_type']
|
||||
self.sampling_rate = config['sampling_rate']
|
||||
self.buffer_size = config['buffer_size']
|
||||
self.timeout = config['timeout']
|
||||
self.validation_enabled = config.get('validation_enabled', True)
|
||||
self.retry_attempts = config.get('retry_attempts', 3)
|
||||
|
||||
# State management
|
||||
self.is_connected = False
|
||||
self.is_streaming = False
|
||||
|
||||
# Create appropriate parser
|
||||
if self.hardware_type == 'esp32':
|
||||
self.parser = ESP32CSIParser()
|
||||
elif self.hardware_type == 'router':
|
||||
self.parser = RouterCSIParser()
|
||||
else:
|
||||
raise ValueError(f"Unsupported hardware type: {self.hardware_type}")
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Validate configuration parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
required_fields = ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']
|
||||
missing_fields = [field for field in required_fields if field not in config]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required configuration: {missing_fields}")
|
||||
|
||||
if config['sampling_rate'] <= 0:
|
||||
raise ValueError("sampling_rate must be positive")
|
||||
|
||||
if config['buffer_size'] <= 0:
|
||||
raise ValueError("buffer_size must be positive")
|
||||
|
||||
if config['timeout'] <= 0:
|
||||
raise ValueError("timeout must be positive")
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Establish connection to CSI hardware.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
success = await self._establish_hardware_connection()
|
||||
self.is_connected = success
|
||||
return success
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to connect to hardware: {e}")
|
||||
self.is_connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from CSI hardware."""
|
||||
if self.is_connected:
|
||||
await self._close_hardware_connection()
|
||||
self.is_connected = False
|
||||
|
||||
async def extract_csi(self) -> CSIData:
|
||||
"""Extract CSI data from hardware.
|
||||
|
||||
Returns:
|
||||
Extracted CSI data
|
||||
|
||||
Raises:
|
||||
CSIParseError: If not connected or extraction fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise CSIParseError("Not connected to hardware")
|
||||
|
||||
# Retry mechanism for temporary failures
|
||||
for attempt in range(self.retry_attempts):
|
||||
try:
|
||||
raw_data = await self._read_raw_data()
|
||||
csi_data = self.parser.parse(raw_data)
|
||||
|
||||
if self.validation_enabled:
|
||||
self.validate_csi_data(csi_data)
|
||||
|
||||
return csi_data
|
||||
|
||||
except ConnectionError as e:
|
||||
if attempt < self.retry_attempts - 1:
|
||||
self.logger.warning(f"Extraction attempt {attempt + 1} failed, retrying: {e}")
|
||||
await asyncio.sleep(0.1) # Brief delay before retry
|
||||
else:
|
||||
raise CSIParseError(f"Extraction failed after {self.retry_attempts} attempts: {e}")
|
||||
|
||||
def validate_csi_data(self, csi_data: CSIData) -> bool:
|
||||
"""Validate CSI data structure and values.
|
||||
|
||||
Args:
|
||||
csi_data: CSI data to validate
|
||||
|
||||
Returns:
|
||||
True if valid
|
||||
|
||||
Raises:
|
||||
CSIValidationError: If data is invalid
|
||||
"""
|
||||
if csi_data.amplitude.size == 0:
|
||||
raise CSIValidationError("Empty amplitude data")
|
||||
|
||||
if csi_data.phase.size == 0:
|
||||
raise CSIValidationError("Empty phase data")
|
||||
|
||||
if csi_data.frequency <= 0:
|
||||
raise CSIValidationError("Invalid frequency")
|
||||
|
||||
if csi_data.bandwidth <= 0:
|
||||
raise CSIValidationError("Invalid bandwidth")
|
||||
|
||||
if csi_data.num_subcarriers <= 0:
|
||||
raise CSIValidationError("Invalid number of subcarriers")
|
||||
|
||||
if csi_data.num_antennas <= 0:
|
||||
raise CSIValidationError("Invalid number of antennas")
|
||||
|
||||
if csi_data.snr < -50 or csi_data.snr > 50: # Reasonable SNR range
|
||||
raise CSIValidationError("Invalid SNR value")
|
||||
|
||||
return True
|
||||
|
||||
async def start_streaming(self, callback: Callable[[CSIData], None]) -> None:
|
||||
"""Start streaming CSI data.
|
||||
|
||||
Args:
|
||||
callback: Function to call with each CSI sample
|
||||
"""
|
||||
self.is_streaming = True
|
||||
|
||||
try:
|
||||
while self.is_streaming:
|
||||
csi_data = await self.extract_csi()
|
||||
callback(csi_data)
|
||||
await asyncio.sleep(1.0 / self.sampling_rate)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Streaming error: {e}")
|
||||
finally:
|
||||
self.is_streaming = False
|
||||
|
||||
def stop_streaming(self) -> None:
|
||||
"""Stop streaming CSI data."""
|
||||
self.is_streaming = False
|
||||
|
||||
async def _establish_hardware_connection(self) -> bool:
|
||||
"""Establish connection to hardware (to be implemented by subclasses)."""
|
||||
# Placeholder implementation for testing
|
||||
return True
|
||||
|
||||
async def _close_hardware_connection(self) -> None:
|
||||
"""Close hardware connection (to be implemented by subclasses)."""
|
||||
# Placeholder implementation for testing
|
||||
pass
|
||||
|
||||
async def _read_raw_data(self) -> bytes:
|
||||
"""Read raw data from hardware (to be implemented by subclasses)."""
|
||||
# Placeholder implementation for testing
|
||||
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"
|
||||
238
v1/src/hardware/router_interface.py
Normal file
238
v1/src/hardware/router_interface.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Router interface for WiFi-DensePose system using TDD approach."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import asyncssh
|
||||
from datetime import datetime, timezone
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from .csi_extractor import CSIData
|
||||
except ImportError:
|
||||
# Handle import for testing
|
||||
from src.hardware.csi_extractor import CSIData
|
||||
|
||||
|
||||
class RouterConnectionError(Exception):
|
||||
"""Exception raised for router connection errors."""
|
||||
pass
|
||||
|
||||
|
||||
class RouterInterface:
|
||||
"""Interface for communicating with WiFi routers via SSH."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
|
||||
"""Initialize router interface.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary with connection parameters
|
||||
logger: Optional logger instance
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
self._validate_config(config)
|
||||
|
||||
self.config = config
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
|
||||
# Connection parameters
|
||||
self.host = config['host']
|
||||
self.port = config['port']
|
||||
self.username = config['username']
|
||||
self.password = config['password']
|
||||
self.command_timeout = config.get('command_timeout', 30)
|
||||
self.connection_timeout = config.get('connection_timeout', 10)
|
||||
self.max_retries = config.get('max_retries', 3)
|
||||
self.retry_delay = config.get('retry_delay', 1.0)
|
||||
|
||||
# Connection state
|
||||
self.is_connected = False
|
||||
self.ssh_client = None
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Validate configuration parameters.
|
||||
|
||||
Args:
|
||||
config: Configuration to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If configuration is invalid
|
||||
"""
|
||||
required_fields = ['host', 'port', 'username', 'password']
|
||||
missing_fields = [field for field in required_fields if field not in config]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required configuration: {missing_fields}")
|
||||
|
||||
if not isinstance(config['port'], int) or config['port'] <= 0:
|
||||
raise ValueError("Port must be a positive integer")
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Establish SSH connection to router.
|
||||
|
||||
Returns:
|
||||
True if connection successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.ssh_client = await asyncssh.connect(
|
||||
self.host,
|
||||
port=self.port,
|
||||
username=self.username,
|
||||
password=self.password,
|
||||
connect_timeout=self.connection_timeout
|
||||
)
|
||||
self.is_connected = True
|
||||
self.logger.info(f"Connected to router at {self.host}:{self.port}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to connect to router: {e}")
|
||||
self.is_connected = False
|
||||
self.ssh_client = None
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from router."""
|
||||
if self.is_connected and self.ssh_client:
|
||||
self.ssh_client.close()
|
||||
self.is_connected = False
|
||||
self.ssh_client = None
|
||||
self.logger.info("Disconnected from router")
|
||||
|
||||
async def execute_command(self, command: str) -> str:
|
||||
"""Execute command on router via SSH.
|
||||
|
||||
Args:
|
||||
command: Command to execute
|
||||
|
||||
Returns:
|
||||
Command output
|
||||
|
||||
Raises:
|
||||
RouterConnectionError: If not connected or command fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RouterConnectionError("Not connected to router")
|
||||
|
||||
# Retry mechanism for temporary failures
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
result = await self.ssh_client.run(command, timeout=self.command_timeout)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RouterConnectionError(f"Command failed: {result.stderr}")
|
||||
|
||||
return result.stdout
|
||||
|
||||
except ConnectionError as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
self.logger.warning(f"Command attempt {attempt + 1} failed, retrying: {e}")
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
else:
|
||||
raise RouterConnectionError(f"Command execution failed after {self.max_retries} retries: {e}")
|
||||
except Exception as e:
|
||||
raise RouterConnectionError(f"Command execution error: {e}")
|
||||
|
||||
async def get_csi_data(self) -> CSIData:
|
||||
"""Retrieve CSI data from router.
|
||||
|
||||
Returns:
|
||||
CSI data structure
|
||||
|
||||
Raises:
|
||||
RouterConnectionError: If data retrieval fails
|
||||
"""
|
||||
try:
|
||||
response = await self.execute_command("iwlist scan | grep CSI")
|
||||
return self._parse_csi_response(response)
|
||||
except Exception as e:
|
||||
raise RouterConnectionError(f"Failed to retrieve CSI data: {e}")
|
||||
|
||||
async def get_router_status(self) -> Dict[str, Any]:
|
||||
"""Get router system status.
|
||||
|
||||
Returns:
|
||||
Dictionary containing router status information
|
||||
|
||||
Raises:
|
||||
RouterConnectionError: If status retrieval fails
|
||||
"""
|
||||
try:
|
||||
response = await self.execute_command("cat /proc/stat && free && iwconfig")
|
||||
return self._parse_status_response(response)
|
||||
except Exception as e:
|
||||
raise RouterConnectionError(f"Failed to retrieve router status: {e}")
|
||||
|
||||
async def configure_csi_monitoring(self, config: Dict[str, Any]) -> bool:
|
||||
"""Configure CSI monitoring on router.
|
||||
|
||||
Args:
|
||||
config: CSI monitoring configuration
|
||||
|
||||
Returns:
|
||||
True if configuration successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
channel = config.get('channel', 6)
|
||||
command = f"iwconfig wlan0 channel {channel} && echo 'CSI monitoring configured'"
|
||||
await self.execute_command(command)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to configure CSI monitoring: {e}")
|
||||
return False
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Perform health check on router.
|
||||
|
||||
Returns:
|
||||
True if router is healthy, False otherwise
|
||||
"""
|
||||
try:
|
||||
response = await self.execute_command("echo 'ping' && echo 'pong'")
|
||||
return "pong" in response
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
def _parse_csi_response(self, response: str) -> CSIData:
|
||||
"""Parse CSI response data.
|
||||
|
||||
Args:
|
||||
response: Raw response from router
|
||||
|
||||
Returns:
|
||||
Parsed CSI data
|
||||
"""
|
||||
# Mock implementation for testing
|
||||
# In real implementation, this would parse actual router CSI format
|
||||
return CSIData(
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
amplitude=np.random.rand(3, 56),
|
||||
phase=np.random.rand(3, 56),
|
||||
frequency=2.4e9,
|
||||
bandwidth=20e6,
|
||||
num_subcarriers=56,
|
||||
num_antennas=3,
|
||||
snr=15.0,
|
||||
metadata={'source': 'router', 'raw_response': response}
|
||||
)
|
||||
|
||||
def _parse_status_response(self, response: str) -> Dict[str, Any]:
|
||||
"""Parse router status response.
|
||||
|
||||
Args:
|
||||
response: Raw response from router
|
||||
|
||||
Returns:
|
||||
Parsed status information
|
||||
"""
|
||||
# Mock implementation for testing
|
||||
# In real implementation, this would parse actual system status
|
||||
return {
|
||||
'cpu_usage': 25.5,
|
||||
'memory_usage': 60.2,
|
||||
'wifi_status': 'active',
|
||||
'uptime': '5 days, 3 hours',
|
||||
'raw_response': response
|
||||
}
|
||||
330
v1/src/logger.py
Normal file
330
v1/src/logger.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Logging configuration for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import logging.handlers
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
"""Colored log formatter for console output."""
|
||||
|
||||
# ANSI color codes
|
||||
COLORS = {
|
||||
'DEBUG': '\033[36m', # Cyan
|
||||
'INFO': '\033[32m', # Green
|
||||
'WARNING': '\033[33m', # Yellow
|
||||
'ERROR': '\033[31m', # Red
|
||||
'CRITICAL': '\033[35m', # Magenta
|
||||
'RESET': '\033[0m' # Reset
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
"""Format log record with colors."""
|
||||
if hasattr(record, 'levelname'):
|
||||
color = self.COLORS.get(record.levelname, self.COLORS['RESET'])
|
||||
record.levelname = f"{color}{record.levelname}{self.COLORS['RESET']}"
|
||||
|
||||
return super().format(record)
|
||||
|
||||
|
||||
class StructuredFormatter(logging.Formatter):
|
||||
"""Structured JSON formatter for log files."""
|
||||
|
||||
def format(self, record):
|
||||
"""Format log record as structured JSON."""
|
||||
import json
|
||||
|
||||
log_entry = {
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
'level': record.levelname,
|
||||
'logger': record.name,
|
||||
'message': record.getMessage(),
|
||||
'module': record.module,
|
||||
'function': record.funcName,
|
||||
'line': record.lineno,
|
||||
}
|
||||
|
||||
# Add exception info if present
|
||||
if record.exc_info:
|
||||
log_entry['exception'] = self.formatException(record.exc_info)
|
||||
|
||||
# Add extra fields
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in ['name', 'msg', 'args', 'levelname', 'levelno', 'pathname',
|
||||
'filename', 'module', 'lineno', 'funcName', 'created',
|
||||
'msecs', 'relativeCreated', 'thread', 'threadName',
|
||||
'processName', 'process', 'getMessage', 'exc_info',
|
||||
'exc_text', 'stack_info']:
|
||||
log_entry[key] = value
|
||||
|
||||
return json.dumps(log_entry)
|
||||
|
||||
|
||||
class RequestContextFilter(logging.Filter):
|
||||
"""Filter to add request context to log records."""
|
||||
|
||||
def filter(self, record):
|
||||
"""Add request context to log record."""
|
||||
# Try to get request context from contextvars or thread local
|
||||
try:
|
||||
import contextvars
|
||||
request_id = contextvars.ContextVar('request_id', default=None).get()
|
||||
user_id = contextvars.ContextVar('user_id', default=None).get()
|
||||
|
||||
if request_id:
|
||||
record.request_id = request_id
|
||||
if user_id:
|
||||
record.user_id = user_id
|
||||
|
||||
except (ImportError, LookupError):
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def setup_logging(settings: Settings) -> None:
|
||||
"""Setup application logging configuration."""
|
||||
|
||||
# Create log directory if file logging is enabled
|
||||
if settings.log_file:
|
||||
log_path = Path(settings.log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build logging configuration
|
||||
config = build_logging_config(settings)
|
||||
|
||||
# Apply configuration
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
# Set up root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(settings.log_level)
|
||||
|
||||
# Add request context filter to all handlers
|
||||
request_filter = RequestContextFilter()
|
||||
for handler in root_logger.handlers:
|
||||
handler.addFilter(request_filter)
|
||||
|
||||
# Log startup message
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Logging configured - Level: {settings.log_level}, File: {settings.log_file}")
|
||||
|
||||
|
||||
def build_logging_config(settings: Settings) -> Dict[str, Any]:
|
||||
"""Build logging configuration dictionary."""
|
||||
|
||||
config = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
'console': {
|
||||
'()': ColoredFormatter,
|
||||
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
'datefmt': '%Y-%m-%d %H:%M:%S'
|
||||
},
|
||||
'file': {
|
||||
'format': '%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s',
|
||||
'datefmt': '%Y-%m-%d %H:%M:%S'
|
||||
},
|
||||
'structured': {
|
||||
'()': StructuredFormatter
|
||||
}
|
||||
},
|
||||
'handlers': {
|
||||
'console': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': settings.log_level,
|
||||
'formatter': 'console',
|
||||
'stream': 'ext://sys.stdout'
|
||||
}
|
||||
},
|
||||
'loggers': {
|
||||
'': { # Root logger
|
||||
'level': settings.log_level,
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
},
|
||||
'src': { # Application logger
|
||||
'level': settings.log_level,
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
},
|
||||
'uvicorn': {
|
||||
'level': 'INFO',
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
},
|
||||
'uvicorn.access': {
|
||||
'level': 'INFO',
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
},
|
||||
'fastapi': {
|
||||
'level': 'INFO',
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
},
|
||||
'sqlalchemy': {
|
||||
'level': 'WARNING',
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
},
|
||||
'sqlalchemy.engine': {
|
||||
'level': 'INFO' if settings.debug else 'WARNING',
|
||||
'handlers': ['console'],
|
||||
'propagate': False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Add file handler if log file is specified
|
||||
if settings.log_file:
|
||||
config['handlers']['file'] = {
|
||||
'class': 'logging.handlers.RotatingFileHandler',
|
||||
'level': settings.log_level,
|
||||
'formatter': 'file',
|
||||
'filename': settings.log_file,
|
||||
'maxBytes': settings.log_max_size,
|
||||
'backupCount': settings.log_backup_count,
|
||||
'encoding': 'utf-8'
|
||||
}
|
||||
|
||||
# Add structured log handler for JSON logs
|
||||
structured_log_file = str(Path(settings.log_file).with_suffix('.json'))
|
||||
config['handlers']['structured'] = {
|
||||
'class': 'logging.handlers.RotatingFileHandler',
|
||||
'level': settings.log_level,
|
||||
'formatter': 'structured',
|
||||
'filename': structured_log_file,
|
||||
'maxBytes': settings.log_max_size,
|
||||
'backupCount': settings.log_backup_count,
|
||||
'encoding': 'utf-8'
|
||||
}
|
||||
|
||||
# Add file handlers to all loggers
|
||||
for logger_config in config['loggers'].values():
|
||||
logger_config['handlers'].extend(['file', 'structured'])
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a logger with the specified name."""
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
def configure_third_party_loggers(settings: Settings) -> None:
|
||||
"""Configure third-party library loggers."""
|
||||
|
||||
# Suppress noisy loggers in production
|
||||
if settings.is_production:
|
||||
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
||||
logging.getLogger('requests').setLevel(logging.WARNING)
|
||||
logging.getLogger('asyncio').setLevel(logging.WARNING)
|
||||
logging.getLogger('multipart').setLevel(logging.WARNING)
|
||||
|
||||
# Configure SQLAlchemy logging
|
||||
if settings.debug and settings.is_development:
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(logging.DEBUG)
|
||||
else:
|
||||
logging.getLogger('sqlalchemy').setLevel(logging.WARNING)
|
||||
|
||||
# Configure Redis logging
|
||||
logging.getLogger('redis').setLevel(logging.WARNING)
|
||||
|
||||
# Configure WebSocket logging
|
||||
logging.getLogger('websockets').setLevel(logging.INFO)
|
||||
|
||||
|
||||
class LoggerMixin:
|
||||
"""Mixin class to add logging capabilities to any class."""
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Get logger for this class."""
|
||||
return logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
|
||||
|
||||
|
||||
def log_function_call(func):
|
||||
"""Decorator to log function calls."""
|
||||
import functools
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
logger = logging.getLogger(func.__module__)
|
||||
logger.debug(f"Calling {func.__name__} with args={args}, kwargs={kwargs}")
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
logger.debug(f"{func.__name__} completed successfully")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"{func.__name__} failed with error: {e}")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def log_async_function_call(func):
|
||||
"""Decorator to log async function calls."""
|
||||
import functools
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
logger = logging.getLogger(func.__module__)
|
||||
logger.debug(f"Calling async {func.__name__} with args={args}, kwargs={kwargs}")
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
logger.debug(f"Async {func.__name__} completed successfully")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Async {func.__name__} failed with error: {e}")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def setup_request_logging():
|
||||
"""Setup request-specific logging context."""
|
||||
import contextvars
|
||||
import uuid
|
||||
|
||||
# Create context variables for request tracking
|
||||
request_id_var = contextvars.ContextVar('request_id')
|
||||
user_id_var = contextvars.ContextVar('user_id')
|
||||
|
||||
def set_request_context(request_id: Optional[str] = None, user_id: Optional[str] = None):
|
||||
"""Set request context for logging."""
|
||||
if request_id is None:
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request_id_var.set(request_id)
|
||||
if user_id:
|
||||
user_id_var.set(user_id)
|
||||
|
||||
def get_request_context():
|
||||
"""Get current request context."""
|
||||
try:
|
||||
return {
|
||||
'request_id': request_id_var.get(),
|
||||
'user_id': user_id_var.get(None)
|
||||
}
|
||||
except LookupError:
|
||||
return {}
|
||||
|
||||
return set_request_context, get_request_context
|
||||
|
||||
|
||||
# Initialize request logging context
|
||||
set_request_context, get_request_context = setup_request_logging()
|
||||
117
v1/src/main.py
Normal file
117
v1/src/main.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Main application entry point for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add src to Python path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from src.config.settings import get_settings, validate_settings
|
||||
from src.logger import setup_logging
|
||||
from src.app import create_app
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.cli import create_cli
|
||||
|
||||
|
||||
def setup_signal_handlers(orchestrator: ServiceOrchestrator):
|
||||
"""Setup signal handlers for graceful shutdown."""
|
||||
def signal_handler(signum, frame):
|
||||
logging.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||
asyncio.create_task(orchestrator.shutdown())
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main application entry point."""
|
||||
try:
|
||||
# Load settings
|
||||
settings = get_settings()
|
||||
|
||||
# Setup logging
|
||||
setup_logging(settings)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger.info(f"Starting {settings.app_name} v{settings.version}")
|
||||
logger.info(f"Environment: {settings.environment}")
|
||||
|
||||
# Validate settings
|
||||
issues = validate_settings(settings)
|
||||
if issues:
|
||||
logger.error("Configuration issues found:")
|
||||
for issue in issues:
|
||||
logger.error(f" - {issue}")
|
||||
if settings.is_production:
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.warning("Continuing with configuration issues in development mode")
|
||||
|
||||
# Create service orchestrator
|
||||
orchestrator = ServiceOrchestrator(settings)
|
||||
|
||||
# Setup signal handlers
|
||||
setup_signal_handlers(orchestrator)
|
||||
|
||||
# Initialize services
|
||||
await orchestrator.initialize()
|
||||
|
||||
# Create FastAPI app
|
||||
app = create_app(settings, orchestrator)
|
||||
|
||||
# Start the application
|
||||
if len(sys.argv) > 1:
|
||||
# CLI mode
|
||||
cli = create_cli(orchestrator)
|
||||
await cli.run(sys.argv[1:])
|
||||
else:
|
||||
# Server mode
|
||||
import uvicorn
|
||||
|
||||
logger.info(f"Starting server on {settings.host}:{settings.port}")
|
||||
|
||||
config = uvicorn.Config(
|
||||
app,
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.reload and settings.is_development,
|
||||
workers=settings.workers if not settings.reload else 1,
|
||||
log_level=settings.log_level.lower(),
|
||||
access_log=True,
|
||||
use_colors=True
|
||||
)
|
||||
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt, shutting down...")
|
||||
except Exception as e:
|
||||
logger.error(f"Application failed to start: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Cleanup
|
||||
if 'orchestrator' in locals():
|
||||
await orchestrator.shutdown()
|
||||
logger.info("Application shutdown complete")
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for package installation."""
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
467
v1/src/middleware/auth.py
Normal file
467
v1/src/middleware/auth.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Authentication middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Dict, Any, Callable
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import set_request_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# JWT token handler
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Authentication error."""
|
||||
pass
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Authorization error."""
|
||||
pass
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""JWT token management."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.secret_key = settings.secret_key
|
||||
self.algorithm = settings.jwt_algorithm
|
||||
self.expire_hours = settings.jwt_expire_hours
|
||||
|
||||
def create_access_token(self, data: Dict[str, Any]) -> str:
|
||||
"""Create JWT access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
|
||||
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Verify and decode JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return payload
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
raise AuthenticationError("Invalid token")
|
||||
|
||||
def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Decode token without verification (for debugging)."""
|
||||
try:
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""User management for authentication."""
|
||||
|
||||
def __init__(self):
|
||||
# In a real application, this would connect to a database
|
||||
# For now, we'll use a simple in-memory store
|
||||
self._users: Dict[str, Dict[str, Any]] = {
|
||||
"admin": {
|
||||
"username": "admin",
|
||||
"email": "admin@example.com",
|
||||
"hashed_password": self.hash_password("admin123"),
|
||||
"roles": ["admin"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
},
|
||||
"user": {
|
||||
"username": "user",
|
||||
"email": "user@example.com",
|
||||
"hashed_password": self.hash_password("user123"),
|
||||
"roles": ["user"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_user(self, username: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user by username."""
|
||||
return self._users.get(username)
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""Authenticate user with username and password."""
|
||||
user = self.get_user(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not self.verify_password(password, user["hashed_password"]):
|
||||
return None
|
||||
|
||||
if not user.get("is_active", False):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
def create_user(self, username: str, email: str, password: str, roles: list = None) -> Dict[str, Any]:
|
||||
"""Create a new user."""
|
||||
if username in self._users:
|
||||
raise ValueError("User already exists")
|
||||
|
||||
user = {
|
||||
"username": username,
|
||||
"email": email,
|
||||
"hashed_password": self.hash_password(password),
|
||||
"roles": roles or ["user"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
|
||||
self._users[username] = user
|
||||
return user
|
||||
|
||||
def update_user(self, username: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update user information."""
|
||||
user = self._users.get(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Don't allow updating certain fields
|
||||
protected_fields = {"username", "created_at", "hashed_password"}
|
||||
updates = {k: v for k, v in updates.items() if k not in protected_fields}
|
||||
|
||||
user.update(updates)
|
||||
return user
|
||||
|
||||
def deactivate_user(self, username: str) -> bool:
|
||||
"""Deactivate a user."""
|
||||
user = self._users.get(username)
|
||||
if user:
|
||||
user["is_active"] = False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Authentication middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.token_manager = TokenManager(settings)
|
||||
self.user_manager = UserManager()
|
||||
self.enabled = settings.enable_authentication
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request through authentication middleware."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Skip authentication for certain paths
|
||||
if self._should_skip_auth(request):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# Skip if authentication is disabled
|
||||
if not self.enabled:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# Extract and verify token
|
||||
user_info = await self._authenticate_request(request)
|
||||
|
||||
# Set user context
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
set_request_context(user_id=user_info.get("username"))
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add authentication headers
|
||||
self._add_auth_headers(response, user_info)
|
||||
|
||||
return response
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Authentication failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except AuthorizationError as e:
|
||||
logger.warning(f"Authorization failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication middleware error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Authentication service error",
|
||||
)
|
||||
finally:
|
||||
# Log request processing time
|
||||
processing_time = time.time() - start_time
|
||||
logger.debug(f"Auth middleware processing time: {processing_time:.3f}s")
|
||||
|
||||
def _should_skip_auth(self, request: Request) -> bool:
|
||||
"""Check if authentication should be skipped for this request."""
|
||||
path = request.url.path
|
||||
|
||||
# Skip authentication for these paths
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/auth/login",
|
||||
"/auth/register",
|
||||
"/static",
|
||||
]
|
||||
|
||||
return any(path.startswith(skip_path) for skip_path in skip_paths)
|
||||
|
||||
async def _authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Authenticate the request and return user info."""
|
||||
# Try to get token from Authorization header
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
# For WebSocket connections, try to get token from query parameters
|
||||
if request.url.path.startswith("/ws"):
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
authorization = f"Bearer {token}"
|
||||
|
||||
if not authorization:
|
||||
if self._requires_auth(request):
|
||||
raise AuthenticationError("Missing authorization header")
|
||||
return None
|
||||
|
||||
# Extract token
|
||||
try:
|
||||
scheme, token = authorization.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise AuthenticationError("Invalid authentication scheme")
|
||||
except ValueError:
|
||||
raise AuthenticationError("Invalid authorization header format")
|
||||
|
||||
# Verify token
|
||||
try:
|
||||
payload = self.token_manager.verify_token(token)
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
raise AuthenticationError("Invalid token payload")
|
||||
|
||||
# Get user info
|
||||
user = self.user_manager.get_user(username)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
if not user.get("is_active", False):
|
||||
raise AuthenticationError("User account is disabled")
|
||||
|
||||
# Return user info without sensitive data
|
||||
return {
|
||||
"username": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
"is_active": user["is_active"],
|
||||
}
|
||||
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
raise AuthenticationError("Token verification failed")
|
||||
|
||||
def _requires_auth(self, request: Request) -> bool:
|
||||
"""Check if the request requires authentication."""
|
||||
# All API endpoints require authentication by default
|
||||
path = request.url.path
|
||||
return path.startswith("/api/") or path.startswith("/ws/")
|
||||
|
||||
def _add_auth_headers(self, response: Response, user_info: Optional[Dict[str, Any]]):
|
||||
"""Add authentication-related headers to response."""
|
||||
if user_info:
|
||||
response.headers["X-User"] = user_info["username"]
|
||||
response.headers["X-User-Roles"] = ",".join(user_info["roles"])
|
||||
|
||||
async def login(self, username: str, password: str) -> Dict[str, Any]:
|
||||
"""Authenticate user and return token."""
|
||||
user = self.user_manager.authenticate_user(username, password)
|
||||
if not user:
|
||||
raise AuthenticationError("Invalid username or password")
|
||||
|
||||
# Create token
|
||||
token_data = {
|
||||
"sub": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
|
||||
access_token = self.token_manager.create_access_token(token_data)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.jwt_expire_hours * 3600,
|
||||
"user": {
|
||||
"username": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
}
|
||||
|
||||
async def register(self, username: str, email: str, password: str) -> Dict[str, Any]:
|
||||
"""Register a new user."""
|
||||
try:
|
||||
user = self.user_manager.create_user(username, email, password)
|
||||
|
||||
# Create token for new user
|
||||
token_data = {
|
||||
"sub": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
|
||||
access_token = self.token_manager.create_access_token(token_data)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.jwt_expire_hours * 3600,
|
||||
"user": {
|
||||
"username": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise AuthenticationError(str(e))
|
||||
|
||||
async def refresh_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Refresh an access token."""
|
||||
try:
|
||||
payload = self.token_manager.verify_token(token)
|
||||
username = payload.get("sub")
|
||||
|
||||
user = self.user_manager.get_user(username)
|
||||
if not user or not user.get("is_active", False):
|
||||
raise AuthenticationError("User not found or inactive")
|
||||
|
||||
# Create new token
|
||||
token_data = {
|
||||
"sub": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
|
||||
new_token = self.token_manager.create_access_token(token_data)
|
||||
|
||||
return {
|
||||
"access_token": new_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.jwt_expire_hours * 3600,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise AuthenticationError("Token refresh failed")
|
||||
|
||||
def check_permission(self, user_info: Dict[str, Any], required_role: str) -> bool:
|
||||
"""Check if user has required role/permission."""
|
||||
user_roles = user_info.get("roles", [])
|
||||
|
||||
# Admin role has all permissions
|
||||
if "admin" in user_roles:
|
||||
return True
|
||||
|
||||
# Check specific role
|
||||
return required_role in user_roles
|
||||
|
||||
def require_role(self, required_role: str):
|
||||
"""Decorator to require specific role."""
|
||||
def decorator(func):
|
||||
import functools
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(request: Request, *args, **kwargs):
|
||||
user_info = getattr(request.state, "user", None)
|
||||
if not user_info:
|
||||
raise AuthorizationError("Authentication required")
|
||||
|
||||
if not self.check_permission(user_info, required_role):
|
||||
raise AuthorizationError(f"Role '{required_role}' required")
|
||||
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# Global authentication middleware instance
|
||||
_auth_middleware: Optional[AuthenticationMiddleware] = None
|
||||
|
||||
|
||||
def get_auth_middleware(settings: Settings) -> AuthenticationMiddleware:
|
||||
"""Get authentication middleware instance."""
|
||||
global _auth_middleware
|
||||
if _auth_middleware is None:
|
||||
_auth_middleware = AuthenticationMiddleware(settings)
|
||||
return _auth_middleware
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Get current authenticated user from request."""
|
||||
return getattr(request.state, "user", None)
|
||||
|
||||
|
||||
def require_authentication(request: Request) -> Dict[str, Any]:
|
||||
"""Require authentication and return user info."""
|
||||
user = get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_role(role: str):
|
||||
"""Dependency to require specific role."""
|
||||
def dependency(request: Request) -> Dict[str, Any]:
|
||||
user = require_authentication(request)
|
||||
|
||||
auth_middleware = get_auth_middleware(request.app.state.settings)
|
||||
if not auth_middleware.check_permission(user, role):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role '{role}' required",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
375
v1/src/middleware/cors.py
Normal file
375
v1/src/middleware/cors.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
CORS middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Union, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CORSMiddleware:
|
||||
"""Enhanced CORS middleware with additional security features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
settings: Settings,
|
||||
allow_origins: Optional[List[str]] = None,
|
||||
allow_methods: Optional[List[str]] = None,
|
||||
allow_headers: Optional[List[str]] = None,
|
||||
allow_credentials: bool = False,
|
||||
expose_headers: Optional[List[str]] = None,
|
||||
max_age: int = 600,
|
||||
):
|
||||
self.app = app
|
||||
self.settings = settings
|
||||
self.allow_origins = allow_origins or settings.cors_origins
|
||||
self.allow_methods = allow_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]
|
||||
self.allow_headers = allow_headers or [
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"X-User-Agent",
|
||||
]
|
||||
self.allow_credentials = allow_credentials or settings.cors_allow_credentials
|
||||
self.expose_headers = expose_headers or [
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
]
|
||||
self.max_age = max_age
|
||||
|
||||
# Security settings
|
||||
self.strict_origin_check = settings.is_production
|
||||
self.log_cors_violations = True
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""ASGI middleware implementation."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Check if this is a CORS preflight request
|
||||
if request.method == "OPTIONS" and "access-control-request-method" in request.headers:
|
||||
response = await self._handle_preflight(request)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle actual request
|
||||
async def send_wrapper(message):
|
||||
if message["type"] == "http.response.start":
|
||||
# Add CORS headers to response
|
||||
headers = dict(message.get("headers", []))
|
||||
cors_headers = self._get_cors_headers(request)
|
||||
|
||||
for key, value in cors_headers.items():
|
||||
headers[key.encode()] = value.encode()
|
||||
|
||||
message["headers"] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
|
||||
async def _handle_preflight(self, request: Request) -> Response:
|
||||
"""Handle CORS preflight request."""
|
||||
origin = request.headers.get("origin")
|
||||
requested_method = request.headers.get("access-control-request-method")
|
||||
requested_headers = request.headers.get("access-control-request-headers", "")
|
||||
|
||||
# Validate origin
|
||||
if not self._is_origin_allowed(origin):
|
||||
if self.log_cors_violations:
|
||||
logger.warning(f"CORS preflight rejected for origin: {origin}")
|
||||
|
||||
return Response(
|
||||
status_code=403,
|
||||
content="CORS preflight request rejected",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
# Validate method
|
||||
if requested_method not in self.allow_methods:
|
||||
if self.log_cors_violations:
|
||||
logger.warning(f"CORS preflight rejected for method: {requested_method}")
|
||||
|
||||
return Response(
|
||||
status_code=405,
|
||||
content="Method not allowed",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
# Validate headers
|
||||
if requested_headers:
|
||||
requested_header_list = [h.strip().lower() for h in requested_headers.split(",")]
|
||||
allowed_headers_lower = [h.lower() for h in self.allow_headers]
|
||||
|
||||
for header in requested_header_list:
|
||||
if header not in allowed_headers_lower:
|
||||
if self.log_cors_violations:
|
||||
logger.warning(f"CORS preflight rejected for header: {header}")
|
||||
|
||||
return Response(
|
||||
status_code=400,
|
||||
content="Header not allowed",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
# Build preflight response headers
|
||||
headers = {
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": ", ".join(self.allow_methods),
|
||||
"Access-Control-Allow-Headers": ", ".join(self.allow_headers),
|
||||
"Access-Control-Max-Age": str(self.max_age),
|
||||
}
|
||||
|
||||
if self.allow_credentials:
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
if self.expose_headers:
|
||||
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
|
||||
|
||||
logger.debug(f"CORS preflight approved for origin: {origin}")
|
||||
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
def _get_cors_headers(self, request: Request) -> dict:
|
||||
"""Get CORS headers for actual request."""
|
||||
origin = request.headers.get("origin")
|
||||
headers = {}
|
||||
|
||||
if self._is_origin_allowed(origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
|
||||
if self.allow_credentials:
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
if self.expose_headers:
|
||||
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
|
||||
|
||||
return headers
|
||||
|
||||
def _is_origin_allowed(self, origin: Optional[str]) -> bool:
|
||||
"""Check if origin is allowed."""
|
||||
if not origin:
|
||||
return not self.strict_origin_check
|
||||
|
||||
# Allow all origins in development
|
||||
if not self.settings.is_production and "*" in self.allow_origins:
|
||||
return True
|
||||
|
||||
# Check exact matches
|
||||
if origin in self.allow_origins:
|
||||
return True
|
||||
|
||||
# Check wildcard patterns
|
||||
for allowed_origin in self.allow_origins:
|
||||
if allowed_origin == "*":
|
||||
return not self.strict_origin_check
|
||||
|
||||
if self._match_origin_pattern(origin, allowed_origin):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _match_origin_pattern(self, origin: str, pattern: str) -> bool:
|
||||
"""Match origin against pattern with wildcard support."""
|
||||
if "*" not in pattern:
|
||||
return origin == pattern
|
||||
|
||||
# Simple wildcard matching
|
||||
if pattern.startswith("*."):
|
||||
domain = pattern[2:]
|
||||
parsed_origin = urlparse(origin)
|
||||
origin_host = parsed_origin.netloc
|
||||
|
||||
# Check if origin ends with the domain
|
||||
return origin_host.endswith(domain) or origin_host == domain[1:] if domain.startswith('.') else origin_host == domain
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def setup_cors_middleware(app: ASGIApp, settings: Settings) -> ASGIApp:
|
||||
"""Setup CORS middleware for the application."""
|
||||
|
||||
if settings.cors_enabled:
|
||||
logger.info("Setting up CORS middleware")
|
||||
|
||||
# Use FastAPI's built-in CORS middleware for basic functionality
|
||||
app = FastAPICORSMiddleware(
|
||||
app,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=settings.cors_allow_credentials,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
allow_headers=[
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"X-User-Agent",
|
||||
],
|
||||
expose_headers=[
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
max_age=600,
|
||||
)
|
||||
|
||||
logger.info(f"CORS enabled for origins: {settings.cors_origins}")
|
||||
else:
|
||||
logger.info("CORS middleware disabled")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class CORSConfig:
|
||||
"""CORS configuration helper."""
|
||||
|
||||
@staticmethod
|
||||
def development_config() -> dict:
|
||||
"""Get CORS configuration for development."""
|
||||
return {
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
"expose_headers": [
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
"max_age": 600,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def production_config(allowed_origins: List[str]) -> dict:
|
||||
"""Get CORS configuration for production."""
|
||||
return {
|
||||
"allow_origins": allowed_origins,
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
"allow_headers": [
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"X-User-Agent",
|
||||
],
|
||||
"expose_headers": [
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
"max_age": 3600, # 1 hour for production
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def api_only_config(allowed_origins: List[str]) -> dict:
|
||||
"""Get CORS configuration for API-only access."""
|
||||
return {
|
||||
"allow_origins": allowed_origins,
|
||||
"allow_credentials": False,
|
||||
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": [
|
||||
"Accept",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Request-ID",
|
||||
],
|
||||
"expose_headers": [
|
||||
"X-Request-ID",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
"max_age": 3600,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def websocket_config(allowed_origins: List[str]) -> dict:
|
||||
"""Get CORS configuration for WebSocket connections."""
|
||||
return {
|
||||
"allow_origins": allowed_origins,
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["GET", "OPTIONS"],
|
||||
"allow_headers": [
|
||||
"Accept",
|
||||
"Authorization",
|
||||
"Sec-WebSocket-Protocol",
|
||||
"Sec-WebSocket-Extensions",
|
||||
],
|
||||
"expose_headers": [],
|
||||
"max_age": 86400, # 24 hours for WebSocket
|
||||
}
|
||||
|
||||
|
||||
def validate_cors_config(settings: Settings) -> List[str]:
|
||||
"""Validate CORS configuration and return issues."""
|
||||
issues = []
|
||||
|
||||
if not settings.cors_enabled:
|
||||
return issues
|
||||
|
||||
# Check origins
|
||||
if not settings.cors_origins:
|
||||
issues.append("CORS is enabled but no origins are configured")
|
||||
|
||||
# Check for wildcard in production
|
||||
if settings.is_production and "*" in settings.cors_origins:
|
||||
issues.append("Wildcard origin (*) should not be used in production")
|
||||
|
||||
# Validate origin formats
|
||||
for origin in settings.cors_origins:
|
||||
if origin != "*" and not origin.startswith(("http://", "https://")):
|
||||
issues.append(f"Invalid origin format: {origin}")
|
||||
|
||||
# Check credentials with wildcard
|
||||
if settings.cors_allow_credentials and "*" in settings.cors_origins:
|
||||
issues.append("Cannot use credentials with wildcard origin")
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
def get_cors_headers_for_origin(origin: str, settings: Settings) -> dict:
|
||||
"""Get appropriate CORS headers for a specific origin."""
|
||||
headers = {}
|
||||
|
||||
if not settings.cors_enabled:
|
||||
return headers
|
||||
|
||||
# Check if origin is allowed
|
||||
cors_middleware = CORSMiddleware(None, settings)
|
||||
if cors_middleware._is_origin_allowed(origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
|
||||
if settings.cors_allow_credentials:
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
return headers
|
||||
505
v1/src/middleware/error_handler.py
Normal file
505
v1/src/middleware/error_handler.py
Normal file
@@ -0,0 +1,505 @@
|
||||
"""
|
||||
Global error handling middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, Union
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import get_request_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorResponse:
|
||||
"""Standardized error response format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: str,
|
||||
message: str,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
status_code: int = 500,
|
||||
request_id: Optional[str] = None,
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
self.status_code = status_code
|
||||
self.request_id = request_id
|
||||
self.timestamp = datetime.utcnow().isoformat()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON response."""
|
||||
response = {
|
||||
"error": {
|
||||
"code": self.error_code,
|
||||
"message": self.message,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
if self.details:
|
||||
response["error"]["details"] = self.details
|
||||
|
||||
if self.request_id:
|
||||
response["error"]["request_id"] = self.request_id
|
||||
|
||||
return response
|
||||
|
||||
def to_response(self) -> JSONResponse:
|
||||
"""Convert to FastAPI JSONResponse."""
|
||||
headers = {}
|
||||
if self.request_id:
|
||||
headers["X-Request-ID"] = self.request_id
|
||||
|
||||
return JSONResponse(
|
||||
status_code=self.status_code,
|
||||
content=self.to_dict(),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
class ErrorHandler:
|
||||
"""Central error handler for the application."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.include_traceback = settings.debug and settings.is_development
|
||||
self.log_errors = True
|
||||
|
||||
def handle_http_exception(self, request: Request, exc: HTTPException) -> ErrorResponse:
|
||||
"""Handle HTTP exceptions."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.warning(
|
||||
f"HTTP {exc.status_code}: {exc.detail} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}"
|
||||
)
|
||||
|
||||
# Determine error code
|
||||
error_code = self._get_error_code_for_status(exc.status_code)
|
||||
|
||||
# Build error details
|
||||
details = {}
|
||||
if hasattr(exc, "headers") and exc.headers:
|
||||
details["headers"] = exc.headers
|
||||
|
||||
if self.include_traceback and hasattr(exc, "__traceback__"):
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
error_code=error_code,
|
||||
message=str(exc.detail),
|
||||
details=details,
|
||||
status_code=exc.status_code,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_validation_error(self, request: Request, exc: RequestValidationError) -> ErrorResponse:
|
||||
"""Handle request validation errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.warning(
|
||||
f"Validation error: {exc.errors()} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}"
|
||||
)
|
||||
|
||||
# Format validation errors
|
||||
validation_details = []
|
||||
for error in exc.errors():
|
||||
validation_details.append({
|
||||
"field": ".".join(str(loc) for loc in error["loc"]),
|
||||
"message": error["msg"],
|
||||
"type": error["type"],
|
||||
"input": error.get("input"),
|
||||
})
|
||||
|
||||
details = {
|
||||
"validation_errors": validation_details,
|
||||
"error_count": len(validation_details)
|
||||
}
|
||||
|
||||
if self.include_traceback:
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="VALIDATION_ERROR",
|
||||
message="Request validation failed",
|
||||
details=details,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_pydantic_error(self, request: Request, exc: ValidationError) -> ErrorResponse:
|
||||
"""Handle Pydantic validation errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.warning(
|
||||
f"Pydantic validation error: {exc.errors()} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}"
|
||||
)
|
||||
|
||||
# Format validation errors
|
||||
validation_details = []
|
||||
for error in exc.errors():
|
||||
validation_details.append({
|
||||
"field": ".".join(str(loc) for loc in error["loc"]),
|
||||
"message": error["msg"],
|
||||
"type": error["type"],
|
||||
})
|
||||
|
||||
details = {
|
||||
"validation_errors": validation_details,
|
||||
"error_count": len(validation_details)
|
||||
}
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="DATA_VALIDATION_ERROR",
|
||||
message="Data validation failed",
|
||||
details=details,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_generic_exception(self, request: Request, exc: Exception) -> ErrorResponse:
|
||||
"""Handle generic exceptions."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__}: {exc} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Determine error details
|
||||
details = {
|
||||
"exception_type": type(exc).__name__,
|
||||
}
|
||||
|
||||
if self.include_traceback:
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
# Don't expose internal error details in production
|
||||
if self.settings.is_production:
|
||||
message = "An internal server error occurred"
|
||||
else:
|
||||
message = str(exc) or "An unexpected error occurred"
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="INTERNAL_SERVER_ERROR",
|
||||
message=message,
|
||||
details=details,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_database_error(self, request: Request, exc: Exception) -> ErrorResponse:
|
||||
"""Handle database-related errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.error(
|
||||
f"Database error: {type(exc).__name__}: {exc} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
details = {
|
||||
"exception_type": type(exc).__name__,
|
||||
"category": "database"
|
||||
}
|
||||
|
||||
if self.include_traceback:
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="DATABASE_ERROR",
|
||||
message="Database operation failed" if self.settings.is_production else str(exc),
|
||||
details=details,
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_external_service_error(self, request: Request, exc: Exception) -> ErrorResponse:
|
||||
"""Handle external service errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.error(
|
||||
f"External service error: {type(exc).__name__}: {exc} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
details = {
|
||||
"exception_type": type(exc).__name__,
|
||||
"category": "external_service"
|
||||
}
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="EXTERNAL_SERVICE_ERROR",
|
||||
message="External service unavailable" if self.settings.is_production else str(exc),
|
||||
details=details,
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def _get_error_code_for_status(self, status_code: int) -> str:
|
||||
"""Get error code for HTTP status code."""
|
||||
error_codes = {
|
||||
400: "BAD_REQUEST",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
405: "METHOD_NOT_ALLOWED",
|
||||
409: "CONFLICT",
|
||||
422: "UNPROCESSABLE_ENTITY",
|
||||
429: "TOO_MANY_REQUESTS",
|
||||
500: "INTERNAL_SERVER_ERROR",
|
||||
502: "BAD_GATEWAY",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
504: "GATEWAY_TIMEOUT",
|
||||
}
|
||||
|
||||
return error_codes.get(status_code, "HTTP_ERROR")
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware:
|
||||
"""Error handling middleware for FastAPI."""
|
||||
|
||||
def __init__(self, app, settings: Settings):
|
||||
self.app = app
|
||||
self.settings = settings
|
||||
self.error_handler = ErrorHandler(settings)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""Process request through error handling middleware."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as exc:
|
||||
# Create a mock request for error handling
|
||||
from starlette.requests import Request
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Handle different exception types
|
||||
if isinstance(exc, HTTPException):
|
||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||
elif isinstance(exc, RequestValidationError):
|
||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||
elif isinstance(exc, ValidationError):
|
||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||
else:
|
||||
# Check for specific error types
|
||||
if self._is_database_error(exc):
|
||||
error_response = self.error_handler.handle_database_error(request, exc)
|
||||
elif self._is_external_service_error(exc):
|
||||
error_response = self.error_handler.handle_external_service_error(request, exc)
|
||||
else:
|
||||
error_response = self.error_handler.handle_generic_exception(request, exc)
|
||||
|
||||
# Send the error response
|
||||
response = error_response.to_response()
|
||||
await response(scope, receive, send)
|
||||
|
||||
finally:
|
||||
# Log request processing time
|
||||
processing_time = time.time() - start_time
|
||||
logger.debug(f"Error handling middleware processing time: {processing_time:.3f}s")
|
||||
|
||||
def _is_database_error(self, exc: Exception) -> bool:
|
||||
"""Check if exception is database-related."""
|
||||
database_exceptions = [
|
||||
"sqlalchemy",
|
||||
"psycopg2",
|
||||
"pymongo",
|
||||
"redis",
|
||||
"ConnectionError",
|
||||
"OperationalError",
|
||||
"IntegrityError",
|
||||
]
|
||||
|
||||
exc_module = getattr(type(exc), "__module__", "")
|
||||
exc_name = type(exc).__name__
|
||||
|
||||
return any(
|
||||
db_exc in exc_module or db_exc in exc_name
|
||||
for db_exc in database_exceptions
|
||||
)
|
||||
|
||||
def _is_external_service_error(self, exc: Exception) -> bool:
|
||||
"""Check if exception is external service-related."""
|
||||
external_exceptions = [
|
||||
"requests",
|
||||
"httpx",
|
||||
"aiohttp",
|
||||
"urllib",
|
||||
"ConnectionError",
|
||||
"TimeoutError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
]
|
||||
|
||||
exc_module = getattr(type(exc), "__module__", "")
|
||||
exc_name = type(exc).__name__
|
||||
|
||||
return any(
|
||||
ext_exc in exc_module or ext_exc in exc_name
|
||||
for ext_exc in external_exceptions
|
||||
)
|
||||
|
||||
|
||||
def setup_error_handling(app, settings: Settings):
|
||||
"""Setup error handling for the application."""
|
||||
logger.info("Setting up error handling middleware")
|
||||
|
||||
error_handler = ErrorHandler(settings)
|
||||
|
||||
# Add exception handlers
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
error_response = error_handler.handle_http_exception(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
# Convert Starlette HTTPException to FastAPI HTTPException
|
||||
fastapi_exc = HTTPException(status_code=exc.status_code, detail=exc.detail)
|
||||
error_response = error_handler.handle_http_exception(request, fastapi_exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
error_response = error_handler.handle_validation_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(ValidationError)
|
||||
async def pydantic_exception_handler(request: Request, exc: ValidationError):
|
||||
error_response = error_handler.handle_pydantic_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception):
|
||||
error_response = error_handler.handle_generic_exception(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
# Add middleware for additional error handling
|
||||
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
|
||||
# The middleware approach is commented out but kept for reference
|
||||
# middleware = ErrorHandlingMiddleware(app, settings)
|
||||
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
|
||||
|
||||
logger.info("Error handling configured")
|
||||
|
||||
|
||||
class CustomHTTPException(HTTPException):
|
||||
"""Custom HTTP exception with additional context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str,
|
||||
error_code: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, detail=detail, headers=headers)
|
||||
self.error_code = error_code
|
||||
self.context = context or {}
|
||||
|
||||
|
||||
class BusinessLogicError(CustomHTTPException):
|
||||
"""Exception for business logic errors."""
|
||||
|
||||
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=message,
|
||||
error_code="BUSINESS_LOGIC_ERROR",
|
||||
context=context
|
||||
)
|
||||
|
||||
|
||||
class ResourceNotFoundError(CustomHTTPException):
|
||||
"""Exception for resource not found errors."""
|
||||
|
||||
def __init__(self, resource: str, identifier: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"{resource} not found",
|
||||
error_code="RESOURCE_NOT_FOUND",
|
||||
context={"resource": resource, "identifier": identifier}
|
||||
)
|
||||
|
||||
|
||||
class ConflictError(CustomHTTPException):
|
||||
"""Exception for conflict errors."""
|
||||
|
||||
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=message,
|
||||
error_code="CONFLICT_ERROR",
|
||||
context=context
|
||||
)
|
||||
|
||||
|
||||
class ServiceUnavailableError(CustomHTTPException):
|
||||
"""Exception for service unavailable errors."""
|
||||
|
||||
def __init__(self, service: str, reason: Optional[str] = None):
|
||||
detail = f"{service} service is unavailable"
|
||||
if reason:
|
||||
detail += f": {reason}"
|
||||
|
||||
super().__init__(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=detail,
|
||||
error_code="SERVICE_UNAVAILABLE",
|
||||
context={"service": service, "reason": reason}
|
||||
)
|
||||
465
v1/src/middleware/rate_limit.py
Normal file
465
v1/src/middleware/rate_limit.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Rate limiting middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""Rate limit information."""
|
||||
requests: int
|
||||
window_start: float
|
||||
window_size: int
|
||||
limit: int
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
return max(0, self.limit - self.requests)
|
||||
|
||||
@property
|
||||
def reset_time(self) -> float:
|
||||
"""Get time when window resets."""
|
||||
return self.window_start + self.window_size
|
||||
|
||||
@property
|
||||
def is_exceeded(self) -> bool:
|
||||
"""Check if rate limit is exceeded."""
|
||||
return self.requests >= self.limit
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
"""Token bucket algorithm for rate limiting."""
|
||||
|
||||
def __init__(self, capacity: int, refill_rate: float):
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.refill_rate = refill_rate
|
||||
self.last_refill = time.time()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def consume(self, tokens: int = 1) -> bool:
|
||||
"""Try to consume tokens from bucket."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
|
||||
# Refill tokens based on time elapsed
|
||||
time_passed = now - self.last_refill
|
||||
tokens_to_add = time_passed * self.refill_rate
|
||||
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
|
||||
self.last_refill = now
|
||||
|
||||
# Check if we have enough tokens
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""Get bucket information."""
|
||||
return {
|
||||
"capacity": self.capacity,
|
||||
"tokens": self.tokens,
|
||||
"refill_rate": self.refill_rate,
|
||||
"last_refill": self.last_refill
|
||||
}
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for rate limiting."""
|
||||
|
||||
def __init__(self, window_size: int, limit: int):
|
||||
self.window_size = window_size
|
||||
self.limit = limit
|
||||
self.requests = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self) -> Tuple[bool, RateLimitInfo]:
|
||||
"""Check if request is allowed."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - self.window_size
|
||||
|
||||
# Remove old requests outside the window
|
||||
while self.requests and self.requests[0] < window_start:
|
||||
self.requests.popleft()
|
||||
|
||||
# Check if limit is exceeded
|
||||
current_requests = len(self.requests)
|
||||
allowed = current_requests < self.limit
|
||||
|
||||
if allowed:
|
||||
self.requests.append(now)
|
||||
|
||||
rate_limit_info = RateLimitInfo(
|
||||
requests=current_requests + (1 if allowed else 0),
|
||||
window_start=window_start,
|
||||
window_size=self.window_size,
|
||||
limit=self.limit
|
||||
)
|
||||
|
||||
return allowed, rate_limit_info
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter with multiple algorithms."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.enabled = settings.enable_rate_limiting
|
||||
|
||||
# Rate limit configurations
|
||||
self.default_limit = settings.rate_limit_requests
|
||||
self.authenticated_limit = settings.rate_limit_authenticated_requests
|
||||
self.window_size = settings.rate_limit_window
|
||||
|
||||
# Storage for rate limit data
|
||||
self._sliding_windows: Dict[str, SlidingWindowCounter] = {}
|
||||
self._token_buckets: Dict[str, TokenBucket] = {}
|
||||
|
||||
# Cleanup task
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._cleanup_interval = 300 # 5 minutes
|
||||
|
||||
async def start(self):
|
||||
"""Start rate limiter background tasks."""
|
||||
if self.enabled:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("Rate limiter started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop rate limiter background tasks."""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Rate limiter stopped")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background task to cleanup old rate limit data."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
await self._cleanup_old_data()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in rate limiter cleanup: {e}")
|
||||
|
||||
async def _cleanup_old_data(self):
|
||||
"""Remove old rate limit data."""
|
||||
now = time.time()
|
||||
cutoff = now - (self.window_size * 2) # Keep data for 2 windows
|
||||
|
||||
# Cleanup sliding windows
|
||||
keys_to_remove = []
|
||||
for key, window in self._sliding_windows.items():
|
||||
# Remove old requests
|
||||
while window.requests and window.requests[0] < cutoff:
|
||||
window.requests.popleft()
|
||||
|
||||
# Remove empty windows
|
||||
if not window.requests:
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._sliding_windows[key]
|
||||
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} old rate limit windows")
|
||||
|
||||
def _get_client_identifier(self, request: Request) -> str:
|
||||
"""Get client identifier for rate limiting."""
|
||||
# Try to get user ID from authenticated request
|
||||
user = getattr(request.state, "user", None)
|
||||
if user:
|
||||
return f"user:{user.get('username', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
client_ip = self._get_client_ip(request)
|
||||
return f"ip:{client_ip}"
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get client IP address."""
|
||||
# Check for forwarded headers
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct connection
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _get_rate_limit(self, request: Request) -> int:
|
||||
"""Get rate limit for request."""
|
||||
# Check if user is authenticated
|
||||
user = getattr(request.state, "user", None)
|
||||
if user:
|
||||
return self.authenticated_limit
|
||||
|
||||
return self.default_limit
|
||||
|
||||
def _get_rate_limit_key(self, request: Request) -> str:
|
||||
"""Get rate limit key for request."""
|
||||
client_id = self._get_client_identifier(request)
|
||||
endpoint = f"{request.method}:{request.url.path}"
|
||||
return f"{client_id}:{endpoint}"
|
||||
|
||||
async def check_rate_limit(self, request: Request) -> Tuple[bool, RateLimitInfo]:
|
||||
"""Check if request is within rate limits."""
|
||||
if not self.enabled:
|
||||
# Return dummy info when rate limiting is disabled
|
||||
return True, RateLimitInfo(
|
||||
requests=0,
|
||||
window_start=time.time(),
|
||||
window_size=self.window_size,
|
||||
limit=float('inf')
|
||||
)
|
||||
|
||||
key = self._get_rate_limit_key(request)
|
||||
limit = self._get_rate_limit(request)
|
||||
|
||||
# Get or create sliding window counter
|
||||
if key not in self._sliding_windows:
|
||||
self._sliding_windows[key] = SlidingWindowCounter(self.window_size, limit)
|
||||
|
||||
window = self._sliding_windows[key]
|
||||
|
||||
# Update limit if it changed (e.g., user authenticated)
|
||||
window.limit = limit
|
||||
|
||||
return await window.is_allowed()
|
||||
|
||||
async def check_token_bucket(self, request: Request, tokens: int = 1) -> bool:
|
||||
"""Check rate limit using token bucket algorithm."""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
key = self._get_client_identifier(request)
|
||||
limit = self._get_rate_limit(request)
|
||||
|
||||
# Get or create token bucket
|
||||
if key not in self._token_buckets:
|
||||
# Refill rate: limit per window size
|
||||
refill_rate = limit / self.window_size
|
||||
self._token_buckets[key] = TokenBucket(limit, refill_rate)
|
||||
|
||||
bucket = self._token_buckets[key]
|
||||
return await bucket.consume(tokens)
|
||||
|
||||
def get_rate_limit_headers(self, rate_limit_info: RateLimitInfo) -> Dict[str, str]:
|
||||
"""Get rate limit headers for response."""
|
||||
return {
|
||||
"X-RateLimit-Limit": str(rate_limit_info.limit),
|
||||
"X-RateLimit-Remaining": str(rate_limit_info.remaining),
|
||||
"X-RateLimit-Reset": str(int(rate_limit_info.reset_time)),
|
||||
"X-RateLimit-Window": str(rate_limit_info.window_size),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get rate limiter statistics."""
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"default_limit": self.default_limit,
|
||||
"authenticated_limit": self.authenticated_limit,
|
||||
"window_size": self.window_size,
|
||||
"active_windows": len(self._sliding_windows),
|
||||
"active_buckets": len(self._token_buckets),
|
||||
}
|
||||
|
||||
|
||||
class RateLimitMiddleware:
|
||||
"""Rate limiting middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.rate_limiter = RateLimiter(settings)
|
||||
self.enabled = settings.enable_rate_limiting
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request through rate limiting middleware."""
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip rate limiting for certain paths
|
||||
if self._should_skip_rate_limit(request):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Check rate limit
|
||||
allowed, rate_limit_info = await self.rate_limiter.check_rate_limit(request)
|
||||
|
||||
if not allowed:
|
||||
# Rate limit exceeded
|
||||
logger.warning(
|
||||
f"Rate limit exceeded for {self.rate_limiter._get_client_identifier(request)} "
|
||||
f"on {request.method} {request.url.path}"
|
||||
)
|
||||
|
||||
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
|
||||
headers["Retry-After"] = str(int(rate_limit_info.reset_time - time.time()))
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers to response
|
||||
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
|
||||
for key, value in headers.items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limiting middleware error: {e}")
|
||||
# Continue without rate limiting on error
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_rate_limit(self, request: Request) -> bool:
|
||||
"""Check if rate limiting should be skipped for this request."""
|
||||
path = request.url.path
|
||||
|
||||
# Skip rate limiting for these paths
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/static",
|
||||
]
|
||||
|
||||
return any(path.startswith(skip_path) for skip_path in skip_paths)
|
||||
|
||||
async def start(self):
|
||||
"""Start rate limiting middleware."""
|
||||
await self.rate_limiter.start()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop rate limiting middleware."""
|
||||
await self.rate_limiter.stop()
|
||||
|
||||
|
||||
# Global rate limit middleware instance
|
||||
_rate_limit_middleware: Optional[RateLimitMiddleware] = None
|
||||
|
||||
|
||||
def get_rate_limit_middleware(settings: Settings) -> RateLimitMiddleware:
|
||||
"""Get rate limit middleware instance."""
|
||||
global _rate_limit_middleware
|
||||
if _rate_limit_middleware is None:
|
||||
_rate_limit_middleware = RateLimitMiddleware(settings)
|
||||
return _rate_limit_middleware
|
||||
|
||||
|
||||
def setup_rate_limiting(app: ASGIApp, settings: Settings) -> ASGIApp:
|
||||
"""Setup rate limiting middleware for the application."""
|
||||
if settings.enable_rate_limiting:
|
||||
logger.info("Setting up rate limiting middleware")
|
||||
|
||||
middleware = get_rate_limit_middleware(settings)
|
||||
|
||||
# Add middleware to app
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
return await middleware(request, call_next)
|
||||
|
||||
logger.info(
|
||||
f"Rate limiting enabled - Default: {settings.rate_limit_requests}/"
|
||||
f"{settings.rate_limit_window}s, Authenticated: "
|
||||
f"{settings.rate_limit_authenticated_requests}/{settings.rate_limit_window}s"
|
||||
)
|
||||
else:
|
||||
logger.info("Rate limiting disabled")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class RateLimitConfig:
|
||||
"""Rate limiting configuration helper."""
|
||||
|
||||
@staticmethod
|
||||
def development_config() -> dict:
|
||||
"""Get rate limiting configuration for development."""
|
||||
return {
|
||||
"enable_rate_limiting": False, # Disabled in development
|
||||
"rate_limit_requests": 1000,
|
||||
"rate_limit_authenticated_requests": 5000,
|
||||
"rate_limit_window": 3600, # 1 hour
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def production_config() -> dict:
|
||||
"""Get rate limiting configuration for production."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 100, # 100 requests per hour for unauthenticated
|
||||
"rate_limit_authenticated_requests": 1000, # 1000 requests per hour for authenticated
|
||||
"rate_limit_window": 3600, # 1 hour
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def api_config() -> dict:
|
||||
"""Get rate limiting configuration for API access."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 60, # 60 requests per minute
|
||||
"rate_limit_authenticated_requests": 300, # 300 requests per minute
|
||||
"rate_limit_window": 60, # 1 minute
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def strict_config() -> dict:
|
||||
"""Get strict rate limiting configuration."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 10, # 10 requests per minute
|
||||
"rate_limit_authenticated_requests": 100, # 100 requests per minute
|
||||
"rate_limit_window": 60, # 1 minute
|
||||
}
|
||||
|
||||
|
||||
def validate_rate_limit_config(settings: Settings) -> list:
|
||||
"""Validate rate limiting configuration."""
|
||||
issues = []
|
||||
|
||||
if settings.enable_rate_limiting:
|
||||
if settings.rate_limit_requests <= 0:
|
||||
issues.append("Rate limit requests must be positive")
|
||||
|
||||
if settings.rate_limit_authenticated_requests <= 0:
|
||||
issues.append("Authenticated rate limit requests must be positive")
|
||||
|
||||
if settings.rate_limit_window <= 0:
|
||||
issues.append("Rate limit window must be positive")
|
||||
|
||||
if settings.rate_limit_authenticated_requests < settings.rate_limit_requests:
|
||||
issues.append("Authenticated rate limit should be higher than default rate limit")
|
||||
|
||||
return issues
|
||||
0
v1/src/models/__init__.py
Normal file
0
v1/src/models/__init__.py
Normal file
279
v1/src/models/densepose_head.py
Normal file
279
v1/src/models/densepose_head.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""DensePose head for WiFi-DensePose system."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any, Tuple, List
|
||||
|
||||
|
||||
class DensePoseError(Exception):
|
||||
"""Exception raised for DensePose head errors."""
|
||||
pass
|
||||
|
||||
|
||||
class DensePoseHead(nn.Module):
|
||||
"""DensePose head for body part segmentation and UV coordinate regression."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize DensePose head.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary with head parameters
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._validate_config(config)
|
||||
self.config = config
|
||||
|
||||
self.input_channels = config['input_channels']
|
||||
self.num_body_parts = config['num_body_parts']
|
||||
self.num_uv_coordinates = config['num_uv_coordinates']
|
||||
self.hidden_channels = config.get('hidden_channels', [128, 64])
|
||||
self.kernel_size = config.get('kernel_size', 3)
|
||||
self.padding = config.get('padding', 1)
|
||||
self.dropout_rate = config.get('dropout_rate', 0.1)
|
||||
self.use_deformable_conv = config.get('use_deformable_conv', False)
|
||||
self.use_fpn = config.get('use_fpn', False)
|
||||
self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5])
|
||||
self.output_stride = config.get('output_stride', 4)
|
||||
|
||||
# Feature Pyramid Network (optional)
|
||||
if self.use_fpn:
|
||||
self.fpn = self._build_fpn()
|
||||
|
||||
# Shared feature processing
|
||||
self.shared_conv = self._build_shared_layers()
|
||||
|
||||
# Segmentation head for body part classification
|
||||
self.segmentation_head = self._build_segmentation_head()
|
||||
|
||||
# UV regression head for coordinate prediction
|
||||
self.uv_regression_head = self._build_uv_regression_head()
|
||||
|
||||
# Initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]):
|
||||
"""Validate configuration parameters."""
|
||||
required_fields = ['input_channels', 'num_body_parts', 'num_uv_coordinates']
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
if config['input_channels'] <= 0:
|
||||
raise ValueError("input_channels must be positive")
|
||||
|
||||
if config['num_body_parts'] <= 0:
|
||||
raise ValueError("num_body_parts must be positive")
|
||||
|
||||
if config['num_uv_coordinates'] <= 0:
|
||||
raise ValueError("num_uv_coordinates must be positive")
|
||||
|
||||
def _build_fpn(self) -> nn.Module:
|
||||
"""Build Feature Pyramid Network."""
|
||||
return nn.ModuleDict({
|
||||
f'level_{level}': nn.Conv2d(self.input_channels, self.input_channels, 1)
|
||||
for level in self.fpn_levels
|
||||
})
|
||||
|
||||
def _build_shared_layers(self) -> nn.Module:
|
||||
"""Build shared feature processing layers."""
|
||||
layers = []
|
||||
in_channels = self.input_channels
|
||||
|
||||
for hidden_dim in self.hidden_channels:
|
||||
layers.extend([
|
||||
nn.Conv2d(in_channels, hidden_dim,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.padding),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout2d(self.dropout_rate)
|
||||
])
|
||||
in_channels = hidden_dim
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _build_segmentation_head(self) -> nn.Module:
|
||||
"""Build segmentation head for body part classification."""
|
||||
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
|
||||
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(final_hidden, final_hidden // 2,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.padding),
|
||||
nn.BatchNorm2d(final_hidden // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout2d(self.dropout_rate),
|
||||
|
||||
# Upsampling to increase resolution
|
||||
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
|
||||
kernel_size=4, stride=2, padding=1),
|
||||
nn.BatchNorm2d(final_hidden // 4),
|
||||
nn.ReLU(inplace=True),
|
||||
|
||||
nn.Conv2d(final_hidden // 4, self.num_body_parts + 1, kernel_size=1),
|
||||
# +1 for background class
|
||||
)
|
||||
|
||||
def _build_uv_regression_head(self) -> nn.Module:
|
||||
"""Build UV regression head for coordinate prediction."""
|
||||
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
|
||||
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(final_hidden, final_hidden // 2,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.padding),
|
||||
nn.BatchNorm2d(final_hidden // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Dropout2d(self.dropout_rate),
|
||||
|
||||
# Upsampling to increase resolution
|
||||
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
|
||||
kernel_size=4, stride=2, padding=1),
|
||||
nn.BatchNorm2d(final_hidden // 4),
|
||||
nn.ReLU(inplace=True),
|
||||
|
||||
nn.Conv2d(final_hidden // 4, self.num_uv_coordinates, kernel_size=1),
|
||||
)
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""Initialize network weights."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
"""Forward pass through the DensePose head.
|
||||
|
||||
Args:
|
||||
x: Input feature tensor of shape (batch_size, channels, height, width)
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- segmentation: Body part logits (batch_size, num_parts+1, height, width)
|
||||
- uv_coordinates: UV coordinates (batch_size, 2, height, width)
|
||||
"""
|
||||
# Validate input shape
|
||||
if x.shape[1] != self.input_channels:
|
||||
raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
|
||||
|
||||
# Apply FPN if enabled
|
||||
if self.use_fpn:
|
||||
# Simple FPN processing - in practice this would be more sophisticated
|
||||
x = self.fpn['level_2'](x)
|
||||
|
||||
# Shared feature processing
|
||||
shared_features = self.shared_conv(x)
|
||||
|
||||
# Segmentation branch
|
||||
segmentation_logits = self.segmentation_head(shared_features)
|
||||
|
||||
# UV regression branch
|
||||
uv_coordinates = self.uv_regression_head(shared_features)
|
||||
uv_coordinates = torch.sigmoid(uv_coordinates) # Normalize to [0, 1]
|
||||
|
||||
return {
|
||||
'segmentation': segmentation_logits,
|
||||
'uv_coordinates': uv_coordinates
|
||||
}
|
||||
|
||||
def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute segmentation loss.
|
||||
|
||||
Args:
|
||||
pred_logits: Predicted segmentation logits
|
||||
target: Target segmentation masks
|
||||
|
||||
Returns:
|
||||
Computed cross-entropy loss
|
||||
"""
|
||||
return F.cross_entropy(pred_logits, target, ignore_index=-1)
|
||||
|
||||
def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute UV coordinate regression loss.
|
||||
|
||||
Args:
|
||||
pred_uv: Predicted UV coordinates
|
||||
target_uv: Target UV coordinates
|
||||
|
||||
Returns:
|
||||
Computed L1 loss
|
||||
"""
|
||||
return F.l1_loss(pred_uv, target_uv)
|
||||
|
||||
def compute_total_loss(self, predictions: Dict[str, torch.Tensor],
|
||||
seg_target: torch.Tensor,
|
||||
uv_target: torch.Tensor,
|
||||
seg_weight: float = 1.0,
|
||||
uv_weight: float = 1.0) -> torch.Tensor:
|
||||
"""Compute total loss combining segmentation and UV losses.
|
||||
|
||||
Args:
|
||||
predictions: Dictionary of predictions
|
||||
seg_target: Target segmentation masks
|
||||
uv_target: Target UV coordinates
|
||||
seg_weight: Weight for segmentation loss
|
||||
uv_weight: Weight for UV loss
|
||||
|
||||
Returns:
|
||||
Combined loss
|
||||
"""
|
||||
seg_loss = self.compute_segmentation_loss(predictions['segmentation'], seg_target)
|
||||
uv_loss = self.compute_uv_loss(predictions['uv_coordinates'], uv_target)
|
||||
|
||||
return seg_weight * seg_loss + uv_weight * uv_loss
|
||||
|
||||
def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Get prediction confidence scores.
|
||||
|
||||
Args:
|
||||
predictions: Dictionary of predictions
|
||||
|
||||
Returns:
|
||||
Dictionary of confidence scores
|
||||
"""
|
||||
seg_logits = predictions['segmentation']
|
||||
uv_coords = predictions['uv_coordinates']
|
||||
|
||||
# Segmentation confidence: max probability
|
||||
seg_probs = F.softmax(seg_logits, dim=1)
|
||||
seg_confidence = torch.max(seg_probs, dim=1)[0]
|
||||
|
||||
# UV confidence: inverse of prediction variance
|
||||
uv_variance = torch.var(uv_coords, dim=1, keepdim=True)
|
||||
uv_confidence = 1.0 / (1.0 + uv_variance)
|
||||
|
||||
return {
|
||||
'segmentation_confidence': seg_confidence,
|
||||
'uv_confidence': uv_confidence.squeeze(1)
|
||||
}
|
||||
|
||||
def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Post-process predictions for final output.
|
||||
|
||||
Args:
|
||||
predictions: Raw predictions from forward pass
|
||||
|
||||
Returns:
|
||||
Post-processed predictions
|
||||
"""
|
||||
seg_logits = predictions['segmentation']
|
||||
uv_coords = predictions['uv_coordinates']
|
||||
|
||||
# Convert logits to class predictions
|
||||
body_parts = torch.argmax(seg_logits, dim=1)
|
||||
|
||||
# Get confidence scores
|
||||
confidence = self.get_prediction_confidence(predictions)
|
||||
|
||||
return {
|
||||
'body_parts': body_parts,
|
||||
'uv_coordinates': uv_coords,
|
||||
'confidence_scores': confidence
|
||||
}
|
||||
301
v1/src/models/modality_translation.py
Normal file
301
v1/src/models/modality_translation.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Modality translation network for WiFi-DensePose system."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class ModalityTranslationError(Exception):
|
||||
"""Exception raised for modality translation errors."""
|
||||
pass
|
||||
|
||||
|
||||
class ModalityTranslationNetwork(nn.Module):
|
||||
"""Neural network for translating CSI data to visual feature space."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""Initialize modality translation network.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary with network parameters
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._validate_config(config)
|
||||
self.config = config
|
||||
|
||||
self.input_channels = config['input_channels']
|
||||
self.hidden_channels = config['hidden_channels']
|
||||
self.output_channels = config['output_channels']
|
||||
self.kernel_size = config.get('kernel_size', 3)
|
||||
self.stride = config.get('stride', 1)
|
||||
self.padding = config.get('padding', 1)
|
||||
self.dropout_rate = config.get('dropout_rate', 0.1)
|
||||
self.activation = config.get('activation', 'relu')
|
||||
self.normalization = config.get('normalization', 'batch')
|
||||
self.use_attention = config.get('use_attention', False)
|
||||
self.attention_heads = config.get('attention_heads', 8)
|
||||
|
||||
# Encoder: CSI -> Feature space
|
||||
self.encoder = self._build_encoder()
|
||||
|
||||
# Decoder: Feature space -> Visual-like features
|
||||
self.decoder = self._build_decoder()
|
||||
|
||||
# Attention mechanism
|
||||
if self.use_attention:
|
||||
self.attention = self._build_attention()
|
||||
|
||||
# Initialize weights
|
||||
self._initialize_weights()
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]):
|
||||
"""Validate configuration parameters."""
|
||||
required_fields = ['input_channels', 'hidden_channels', 'output_channels']
|
||||
for field in required_fields:
|
||||
if field not in config:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
if config['input_channels'] <= 0:
|
||||
raise ValueError("input_channels must be positive")
|
||||
|
||||
if not config['hidden_channels'] or len(config['hidden_channels']) == 0:
|
||||
raise ValueError("hidden_channels must be a non-empty list")
|
||||
|
||||
if config['output_channels'] <= 0:
|
||||
raise ValueError("output_channels must be positive")
|
||||
|
||||
def _build_encoder(self) -> nn.ModuleList:
|
||||
"""Build encoder network."""
|
||||
layers = nn.ModuleList()
|
||||
|
||||
# Initial convolution
|
||||
in_channels = self.input_channels
|
||||
|
||||
for i, out_channels in enumerate(self.hidden_channels):
|
||||
layer_block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride if i == 0 else 2,
|
||||
padding=self.padding),
|
||||
self._get_normalization(out_channels),
|
||||
self._get_activation(),
|
||||
nn.Dropout2d(self.dropout_rate)
|
||||
)
|
||||
layers.append(layer_block)
|
||||
in_channels = out_channels
|
||||
|
||||
return layers
|
||||
|
||||
def _build_decoder(self) -> nn.ModuleList:
|
||||
"""Build decoder network."""
|
||||
layers = nn.ModuleList()
|
||||
|
||||
# Start with the last hidden channel size
|
||||
in_channels = self.hidden_channels[-1]
|
||||
|
||||
# Progressive upsampling (reverse of encoder)
|
||||
for i, out_channels in enumerate(reversed(self.hidden_channels[:-1])):
|
||||
layer_block = nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels, out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=2,
|
||||
padding=self.padding,
|
||||
output_padding=1),
|
||||
self._get_normalization(out_channels),
|
||||
self._get_activation(),
|
||||
nn.Dropout2d(self.dropout_rate)
|
||||
)
|
||||
layers.append(layer_block)
|
||||
in_channels = out_channels
|
||||
|
||||
# Final output layer
|
||||
final_layer = nn.Sequential(
|
||||
nn.Conv2d(in_channels, self.output_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.padding),
|
||||
nn.Tanh() # Normalize output
|
||||
)
|
||||
layers.append(final_layer)
|
||||
|
||||
return layers
|
||||
|
||||
def _get_normalization(self, channels: int) -> nn.Module:
|
||||
"""Get normalization layer."""
|
||||
if self.normalization == 'batch':
|
||||
return nn.BatchNorm2d(channels)
|
||||
elif self.normalization == 'instance':
|
||||
return nn.InstanceNorm2d(channels)
|
||||
elif self.normalization == 'layer':
|
||||
return nn.GroupNorm(1, channels)
|
||||
else:
|
||||
return nn.Identity()
|
||||
|
||||
def _get_activation(self) -> nn.Module:
|
||||
"""Get activation function."""
|
||||
if self.activation == 'relu':
|
||||
return nn.ReLU(inplace=True)
|
||||
elif self.activation == 'leaky_relu':
|
||||
return nn.LeakyReLU(0.2, inplace=True)
|
||||
elif self.activation == 'gelu':
|
||||
return nn.GELU()
|
||||
else:
|
||||
return nn.ReLU(inplace=True)
|
||||
|
||||
def _build_attention(self) -> nn.Module:
|
||||
"""Build attention mechanism."""
|
||||
return nn.MultiheadAttention(
|
||||
embed_dim=self.hidden_channels[-1],
|
||||
num_heads=self.attention_heads,
|
||||
dropout=self.dropout_rate,
|
||||
batch_first=True
|
||||
)
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""Initialize network weights."""
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass through the network.
|
||||
|
||||
Args:
|
||||
x: Input CSI tensor of shape (batch_size, channels, height, width)
|
||||
|
||||
Returns:
|
||||
Translated features tensor
|
||||
"""
|
||||
# Validate input shape
|
||||
if x.shape[1] != self.input_channels:
|
||||
raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
|
||||
|
||||
# Encode CSI data
|
||||
encoded_features = self.encode(x)
|
||||
|
||||
# Decode to visual-like features
|
||||
decoded = self.decode(encoded_features)
|
||||
|
||||
return decoded
|
||||
|
||||
def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Encode input through encoder layers.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
|
||||
Returns:
|
||||
List of feature maps from each encoder layer
|
||||
"""
|
||||
features = []
|
||||
current = x
|
||||
|
||||
for layer in self.encoder:
|
||||
current = layer(current)
|
||||
features.append(current)
|
||||
|
||||
return features
|
||||
|
||||
def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor:
|
||||
"""Decode features through decoder layers.
|
||||
|
||||
Args:
|
||||
encoded_features: List of encoded feature maps
|
||||
|
||||
Returns:
|
||||
Decoded output tensor
|
||||
"""
|
||||
# Start with the last encoded feature
|
||||
current = encoded_features[-1]
|
||||
|
||||
# Apply attention if enabled
|
||||
if self.use_attention:
|
||||
batch_size, channels, height, width = current.shape
|
||||
# Reshape for attention: (batch, seq_len, embed_dim)
|
||||
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
|
||||
attended, _ = self.attention(current_flat, current_flat, current_flat)
|
||||
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
|
||||
|
||||
# Apply decoder layers
|
||||
for layer in self.decoder:
|
||||
current = layer(current)
|
||||
|
||||
return current
|
||||
|
||||
def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor:
|
||||
"""Compute translation loss between predicted and target features.
|
||||
|
||||
Args:
|
||||
predicted: Predicted feature tensor
|
||||
target: Target feature tensor
|
||||
loss_type: Type of loss ('mse', 'l1', 'smooth_l1')
|
||||
|
||||
Returns:
|
||||
Computed loss tensor
|
||||
"""
|
||||
if loss_type == 'mse':
|
||||
return F.mse_loss(predicted, target)
|
||||
elif loss_type == 'l1':
|
||||
return F.l1_loss(predicted, target)
|
||||
elif loss_type == 'smooth_l1':
|
||||
return F.smooth_l1_loss(predicted, target)
|
||||
else:
|
||||
return F.mse_loss(predicted, target)
|
||||
|
||||
def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]:
|
||||
"""Get statistics of feature tensor.
|
||||
|
||||
Args:
|
||||
features: Feature tensor to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary of feature statistics
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return {
|
||||
'mean': features.mean().item(),
|
||||
'std': features.std().item(),
|
||||
'min': features.min().item(),
|
||||
'max': features.max().item(),
|
||||
'sparsity': (features == 0).float().mean().item()
|
||||
}
|
||||
|
||||
def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]:
|
||||
"""Get intermediate features for visualization.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
|
||||
Returns:
|
||||
Dictionary containing intermediate features
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# Get encoder features
|
||||
encoder_features = self.encode(x)
|
||||
result['encoder_features'] = encoder_features
|
||||
|
||||
# Get decoder features
|
||||
decoder_features = []
|
||||
current = encoder_features[-1]
|
||||
|
||||
if self.use_attention:
|
||||
batch_size, channels, height, width = current.shape
|
||||
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
|
||||
attended, attention_weights = self.attention(current_flat, current_flat, current_flat)
|
||||
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
|
||||
result['attention_weights'] = attention_weights
|
||||
|
||||
for layer in self.decoder:
|
||||
current = layer(current)
|
||||
decoder_features.append(current)
|
||||
|
||||
result['decoder_features'] = decoder_features
|
||||
|
||||
return result
|
||||
19
v1/src/services/__init__.py
Normal file
19
v1/src/services/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Services package for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from .orchestrator import ServiceOrchestrator
|
||||
from .health_check import HealthCheckService
|
||||
from .metrics import MetricsService
|
||||
from .pose_service import PoseService
|
||||
from .stream_service import StreamService
|
||||
from .hardware_service import HardwareService
|
||||
|
||||
__all__ = [
|
||||
'ServiceOrchestrator',
|
||||
'HealthCheckService',
|
||||
'MetricsService',
|
||||
'PoseService',
|
||||
'StreamService',
|
||||
'HardwareService'
|
||||
]
|
||||
482
v1/src/services/hardware_service.py
Normal file
482
v1/src/services/hardware_service.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
Hardware interface service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.router_interface import RouterInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HardwareService:
|
||||
"""Service for hardware interface operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize hardware service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Router interfaces
|
||||
self.router_interfaces: Dict[str, RouterInterface] = {}
|
||||
|
||||
# Service state
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Data collection statistics
|
||||
self.stats = {
|
||||
"total_samples": 0,
|
||||
"successful_samples": 0,
|
||||
"failed_samples": 0,
|
||||
"average_sample_rate": 0.0,
|
||||
"last_sample_time": None,
|
||||
"connected_routers": 0
|
||||
}
|
||||
|
||||
# Background tasks
|
||||
self.collection_task = None
|
||||
self.monitoring_task = None
|
||||
|
||||
# Data buffers
|
||||
self.recent_samples = []
|
||||
self.max_recent_samples = 1000
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the hardware service."""
|
||||
await self.start()
|
||||
|
||||
async def start(self):
|
||||
"""Start the hardware service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
try:
|
||||
self.logger.info("Starting hardware service...")
|
||||
|
||||
# Initialize router interfaces
|
||||
await self._initialize_routers()
|
||||
|
||||
self.is_running = True
|
||||
|
||||
# Start background tasks
|
||||
if not self.settings.mock_hardware:
|
||||
self.collection_task = asyncio.create_task(self._data_collection_loop())
|
||||
|
||||
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
|
||||
|
||||
self.logger.info("Hardware service started successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to start hardware service: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the hardware service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background tasks
|
||||
if self.collection_task:
|
||||
self.collection_task.cancel()
|
||||
try:
|
||||
await self.collection_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self.monitoring_task:
|
||||
self.monitoring_task.cancel()
|
||||
try:
|
||||
await self.monitoring_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Disconnect from routers
|
||||
await self._disconnect_routers()
|
||||
|
||||
self.logger.info("Hardware service stopped")
|
||||
|
||||
async def _initialize_routers(self):
|
||||
"""Initialize router interfaces."""
|
||||
try:
|
||||
# Get router configurations from domain config
|
||||
routers = self.domain_config.get_all_routers()
|
||||
|
||||
for router_config in routers:
|
||||
if not router_config.enabled:
|
||||
continue
|
||||
|
||||
router_id = router_config.router_id
|
||||
|
||||
# Create router interface
|
||||
router_interface = RouterInterface(
|
||||
router_id=router_id,
|
||||
host=router_config.ip_address,
|
||||
port=22, # Default SSH port
|
||||
username="admin", # Default username
|
||||
password="admin", # Default password
|
||||
interface=router_config.interface,
|
||||
mock_mode=self.settings.mock_hardware
|
||||
)
|
||||
|
||||
# Connect to router (always connect, even in mock mode)
|
||||
await router_interface.connect()
|
||||
|
||||
self.router_interfaces[router_id] = router_interface
|
||||
self.logger.info(f"Router interface initialized: {router_id}")
|
||||
|
||||
self.stats["connected_routers"] = len(self.router_interfaces)
|
||||
|
||||
if not self.router_interfaces:
|
||||
self.logger.warning("No router interfaces configured")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize routers: {e}")
|
||||
raise
|
||||
|
||||
async def _disconnect_routers(self):
|
||||
"""Disconnect from all routers."""
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
await interface.disconnect()
|
||||
self.logger.info(f"Disconnected from router: {router_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error disconnecting from router {router_id}: {e}")
|
||||
|
||||
self.router_interfaces.clear()
|
||||
self.stats["connected_routers"] = 0
|
||||
|
||||
async def _data_collection_loop(self):
|
||||
"""Background loop for data collection."""
|
||||
try:
|
||||
while self.is_running:
|
||||
start_time = time.time()
|
||||
|
||||
# Collect data from all routers
|
||||
await self._collect_data_from_routers()
|
||||
|
||||
# Calculate sleep time to maintain polling interval
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)
|
||||
|
||||
if sleep_time > 0:
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Data collection loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in data collection loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _monitoring_loop(self):
|
||||
"""Background loop for hardware monitoring."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Monitor router connections
|
||||
await self._monitor_router_health()
|
||||
|
||||
# Update statistics
|
||||
self._update_sample_rate_stats()
|
||||
|
||||
# Wait before next check
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Monitoring loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in monitoring loop: {e}")
|
||||
|
||||
async def _collect_data_from_routers(self):
|
||||
"""Collect CSI data from all connected routers."""
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
# Get CSI data from router
|
||||
csi_data = await interface.get_csi_data()
|
||||
|
||||
if csi_data is not None:
|
||||
# Process the collected data
|
||||
await self._process_collected_data(router_id, csi_data)
|
||||
|
||||
self.stats["successful_samples"] += 1
|
||||
self.stats["last_sample_time"] = datetime.now().isoformat()
|
||||
else:
|
||||
self.stats["failed_samples"] += 1
|
||||
|
||||
self.stats["total_samples"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error collecting data from router {router_id}: {e}")
|
||||
self.stats["failed_samples"] += 1
|
||||
self.stats["total_samples"] += 1
|
||||
|
||||
async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
|
||||
"""Process collected CSI data."""
|
||||
try:
|
||||
# Create sample metadata
|
||||
metadata = {
|
||||
"router_id": router_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"sample_rate": self.stats["average_sample_rate"],
|
||||
"data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
|
||||
}
|
||||
|
||||
# Add to recent samples buffer
|
||||
sample = {
|
||||
"router_id": router_id,
|
||||
"timestamp": metadata["timestamp"],
|
||||
"data": csi_data,
|
||||
"metadata": metadata
|
||||
}
|
||||
|
||||
self.recent_samples.append(sample)
|
||||
|
||||
# Maintain buffer size
|
||||
if len(self.recent_samples) > self.max_recent_samples:
|
||||
self.recent_samples.pop(0)
|
||||
|
||||
# Notify other services (this would typically be done through an event system)
|
||||
# For now, we'll just log the data collection
|
||||
self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing collected data: {e}")
|
||||
|
||||
async def _monitor_router_health(self):
|
||||
"""Monitor health of router connections."""
|
||||
healthy_routers = 0
|
||||
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
is_healthy = await interface.check_health()
|
||||
|
||||
if is_healthy:
|
||||
healthy_routers += 1
|
||||
else:
|
||||
self.logger.warning(f"Router {router_id} is unhealthy")
|
||||
|
||||
# Try to reconnect if not in mock mode
|
||||
if not self.settings.mock_hardware:
|
||||
try:
|
||||
await interface.reconnect()
|
||||
self.logger.info(f"Reconnected to router {router_id}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to reconnect to router {router_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking health of router {router_id}: {e}")
|
||||
|
||||
self.stats["connected_routers"] = healthy_routers
|
||||
|
||||
def _update_sample_rate_stats(self):
|
||||
"""Update sample rate statistics."""
|
||||
if len(self.recent_samples) < 2:
|
||||
return
|
||||
|
||||
# Calculate sample rate from recent samples
|
||||
recent_count = min(100, len(self.recent_samples))
|
||||
recent_samples = self.recent_samples[-recent_count:]
|
||||
|
||||
if len(recent_samples) >= 2:
|
||||
# Calculate time differences
|
||||
time_diffs = []
|
||||
for i in range(1, len(recent_samples)):
|
||||
try:
|
||||
t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
|
||||
t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
|
||||
diff = (t2 - t1).total_seconds()
|
||||
if diff > 0:
|
||||
time_diffs.append(diff)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if time_diffs:
|
||||
avg_interval = sum(time_diffs) / len(time_diffs)
|
||||
self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0
|
||||
|
||||
async def get_router_status(self, router_id: str) -> Dict[str, Any]:
|
||||
"""Get status of a specific router."""
|
||||
if router_id not in self.router_interfaces:
|
||||
raise ValueError(f"Router {router_id} not found")
|
||||
|
||||
interface = self.router_interfaces[router_id]
|
||||
|
||||
try:
|
||||
is_healthy = await interface.check_health()
|
||||
status = await interface.get_status()
|
||||
|
||||
return {
|
||||
"router_id": router_id,
|
||||
"healthy": is_healthy,
|
||||
"connected": status.get("connected", False),
|
||||
"last_data_time": status.get("last_data_time"),
|
||||
"error_count": status.get("error_count", 0),
|
||||
"configuration": status.get("configuration", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"router_id": router_id,
|
||||
"healthy": False,
|
||||
"connected": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def get_all_router_status(self) -> List[Dict[str, Any]]:
|
||||
"""Get status of all routers."""
|
||||
statuses = []
|
||||
|
||||
for router_id in self.router_interfaces:
|
||||
try:
|
||||
status = await self.get_router_status(router_id)
|
||||
statuses.append(status)
|
||||
except Exception as e:
|
||||
statuses.append({
|
||||
"router_id": router_id,
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return statuses
|
||||
|
||||
async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get recent CSI data samples."""
|
||||
samples = self.recent_samples[-limit:] if limit else self.recent_samples
|
||||
|
||||
if router_id:
|
||||
samples = [s for s in samples if s["router_id"] == router_id]
|
||||
|
||||
# Convert numpy arrays to lists for JSON serialization
|
||||
result = []
|
||||
for sample in samples:
|
||||
sample_copy = sample.copy()
|
||||
if isinstance(sample_copy["data"], np.ndarray):
|
||||
sample_copy["data"] = sample_copy["data"].tolist()
|
||||
result.append(sample_copy)
|
||||
|
||||
return result
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_hardware": self.settings.mock_hardware,
|
||||
"wifi_interface": self.settings.wifi_interface,
|
||||
"polling_interval": self.settings.hardware_polling_interval,
|
||||
"buffer_size": self.settings.csi_buffer_size
|
||||
},
|
||||
"routers": await self.get_all_router_status()
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
total_samples = self.stats["total_samples"]
|
||||
success_rate = self.stats["successful_samples"] / max(1, total_samples)
|
||||
|
||||
return {
|
||||
"hardware_service": {
|
||||
"total_samples": total_samples,
|
||||
"successful_samples": self.stats["successful_samples"],
|
||||
"failed_samples": self.stats["failed_samples"],
|
||||
"success_rate": success_rate,
|
||||
"average_sample_rate": self.stats["average_sample_rate"],
|
||||
"connected_routers": self.stats["connected_routers"],
|
||||
"last_sample_time": self.stats["last_sample_time"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_samples": 0,
|
||||
"successful_samples": 0,
|
||||
"failed_samples": 0,
|
||||
"average_sample_rate": 0.0,
|
||||
"last_sample_time": None,
|
||||
"connected_routers": len(self.router_interfaces)
|
||||
}
|
||||
|
||||
self.recent_samples.clear()
|
||||
self.last_error = None
|
||||
|
||||
self.logger.info("Hardware service reset")
|
||||
|
||||
async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Manually trigger data collection."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Hardware service is not running")
|
||||
|
||||
results = {}
|
||||
|
||||
if router_id:
|
||||
# Collect from specific router
|
||||
if router_id not in self.router_interfaces:
|
||||
raise ValueError(f"Router {router_id} not found")
|
||||
|
||||
interface = self.router_interfaces[router_id]
|
||||
try:
|
||||
csi_data = await interface.get_csi_data()
|
||||
if csi_data is not None:
|
||||
await self._process_collected_data(router_id, csi_data)
|
||||
results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
|
||||
else:
|
||||
results[router_id] = {"success": False, "error": "No data received"}
|
||||
except Exception as e:
|
||||
results[router_id] = {"success": False, "error": str(e)}
|
||||
else:
|
||||
# Collect from all routers
|
||||
await self._collect_data_from_routers()
|
||||
results = {"message": "Manual collection triggered for all routers"}
|
||||
|
||||
return results
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
# Check router health
|
||||
healthy_routers = 0
|
||||
total_routers = len(self.router_interfaces)
|
||||
|
||||
for router_id, interface in self.router_interfaces.items():
|
||||
try:
|
||||
if await interface.check_health():
|
||||
healthy_routers += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Hardware service is running normally",
|
||||
"connected_routers": f"{healthy_routers}/{total_routers}",
|
||||
"metrics": {
|
||||
"total_samples": self.stats["total_samples"],
|
||||
"success_rate": (
|
||||
self.stats["successful_samples"] / max(1, self.stats["total_samples"])
|
||||
),
|
||||
"average_sample_rate": self.stats["average_sample_rate"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
"""Check if service is ready."""
|
||||
return self.is_running and len(self.router_interfaces) > 0
|
||||
465
v1/src/services/health_check.py
Normal file
465
v1/src/services/health_check.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Health check service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthStatus(Enum):
|
||||
"""Health status enumeration."""
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
UNHEALTHY = "unhealthy"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthCheck:
|
||||
"""Health check result."""
|
||||
name: str
|
||||
status: HealthStatus
|
||||
message: str
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
duration_ms: float = 0.0
|
||||
details: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceHealth:
|
||||
"""Service health information."""
|
||||
name: str
|
||||
status: HealthStatus
|
||||
last_check: Optional[datetime] = None
|
||||
checks: List[HealthCheck] = field(default_factory=list)
|
||||
uptime: float = 0.0
|
||||
error_count: int = 0
|
||||
last_error: Optional[str] = None
|
||||
|
||||
|
||||
class HealthCheckService:
|
||||
"""Service for monitoring application health."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self._services: Dict[str, ServiceHealth] = {}
|
||||
self._start_time = time.time()
|
||||
self._initialized = False
|
||||
self._running = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize health check service."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing health check service")
|
||||
|
||||
# Initialize service health tracking
|
||||
self._services = {
|
||||
"api": ServiceHealth("api", HealthStatus.UNKNOWN),
|
||||
"database": ServiceHealth("database", HealthStatus.UNKNOWN),
|
||||
"redis": ServiceHealth("redis", HealthStatus.UNKNOWN),
|
||||
"hardware": ServiceHealth("hardware", HealthStatus.UNKNOWN),
|
||||
"pose": ServiceHealth("pose", HealthStatus.UNKNOWN),
|
||||
"stream": ServiceHealth("stream", HealthStatus.UNKNOWN),
|
||||
}
|
||||
|
||||
self._initialized = True
|
||||
logger.info("Health check service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start health check service."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
self._running = True
|
||||
logger.info("Health check service started")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown health check service."""
|
||||
self._running = False
|
||||
logger.info("Health check service shut down")
|
||||
|
||||
async def perform_health_checks(self) -> Dict[str, HealthCheck]:
|
||||
"""Perform all health checks."""
|
||||
if not self._running:
|
||||
return {}
|
||||
|
||||
logger.debug("Performing health checks")
|
||||
results = {}
|
||||
|
||||
# Perform individual health checks
|
||||
checks = [
|
||||
self._check_api_health(),
|
||||
self._check_database_health(),
|
||||
self._check_redis_health(),
|
||||
self._check_hardware_health(),
|
||||
self._check_pose_health(),
|
||||
self._check_stream_health(),
|
||||
]
|
||||
|
||||
# Run checks concurrently
|
||||
check_results = await asyncio.gather(*checks, return_exceptions=True)
|
||||
|
||||
# Process results
|
||||
for i, result in enumerate(check_results):
|
||||
check_name = ["api", "database", "redis", "hardware", "pose", "stream"][i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
health_check = HealthCheck(
|
||||
name=check_name,
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Health check failed: {result}"
|
||||
)
|
||||
else:
|
||||
health_check = result
|
||||
|
||||
results[check_name] = health_check
|
||||
self._update_service_health(check_name, health_check)
|
||||
|
||||
logger.debug(f"Completed {len(results)} health checks")
|
||||
return results
|
||||
|
||||
async def _check_api_health(self) -> HealthCheck:
|
||||
"""Check API health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Basic API health check
|
||||
uptime = time.time() - self._start_time
|
||||
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "API is running normally"
|
||||
details = {
|
||||
"uptime_seconds": uptime,
|
||||
"uptime_formatted": str(timedelta(seconds=int(uptime)))
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"API health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="api",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_database_health(self) -> HealthCheck:
|
||||
"""Check database health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager()
|
||||
|
||||
if not db_manager.is_connected():
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = "Database is not connected"
|
||||
details = {"connected": False}
|
||||
else:
|
||||
# Test database connection
|
||||
await db_manager.test_connection()
|
||||
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Database is connected and responsive"
|
||||
details = {
|
||||
"connected": True,
|
||||
"pool_size": db_manager.get_pool_size(),
|
||||
"active_connections": db_manager.get_active_connections()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Database health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="database",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_redis_health(self) -> HealthCheck:
|
||||
"""Check Redis health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
redis_config = self.settings.get_redis_url()
|
||||
|
||||
if not redis_config:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Redis is not configured"
|
||||
details = {"configured": False}
|
||||
else:
|
||||
# Test Redis connection
|
||||
import redis.asyncio as redis
|
||||
|
||||
redis_client = redis.from_url(redis_config)
|
||||
await redis_client.ping()
|
||||
await redis_client.close()
|
||||
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Redis is connected and responsive"
|
||||
details = {"connected": True}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Redis health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="redis",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_hardware_health(self) -> HealthCheck:
|
||||
"""Check hardware service health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.dependencies import get_hardware_service
|
||||
|
||||
hardware_service = get_hardware_service()
|
||||
|
||||
if hasattr(hardware_service, 'get_status'):
|
||||
status_info = await hardware_service.get_status()
|
||||
|
||||
if status_info.get("status") == "healthy":
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Hardware service is operational"
|
||||
else:
|
||||
status = HealthStatus.DEGRADED
|
||||
message = f"Hardware service status: {status_info.get('status', 'unknown')}"
|
||||
|
||||
details = status_info
|
||||
else:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Hardware service status unavailable"
|
||||
details = {}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Hardware health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="hardware",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_pose_health(self) -> HealthCheck:
|
||||
"""Check pose service health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.dependencies import get_pose_service
|
||||
|
||||
pose_service = get_pose_service()
|
||||
|
||||
if hasattr(pose_service, 'get_status'):
|
||||
status_info = await pose_service.get_status()
|
||||
|
||||
if status_info.get("status") == "healthy":
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Pose service is operational"
|
||||
else:
|
||||
status = HealthStatus.DEGRADED
|
||||
message = f"Pose service status: {status_info.get('status', 'unknown')}"
|
||||
|
||||
details = status_info
|
||||
else:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Pose service status unavailable"
|
||||
details = {}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Pose health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="pose",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_stream_health(self) -> HealthCheck:
|
||||
"""Check stream service health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.dependencies import get_stream_service
|
||||
|
||||
stream_service = get_stream_service()
|
||||
|
||||
if hasattr(stream_service, 'get_status'):
|
||||
status_info = await stream_service.get_status()
|
||||
|
||||
if status_info.get("status") == "healthy":
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Stream service is operational"
|
||||
else:
|
||||
status = HealthStatus.DEGRADED
|
||||
message = f"Stream service status: {status_info.get('status', 'unknown')}"
|
||||
|
||||
details = status_info
|
||||
else:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Stream service status unavailable"
|
||||
details = {}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Stream health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="stream",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
def _update_service_health(self, service_name: str, health_check: HealthCheck):
|
||||
"""Update service health information."""
|
||||
if service_name not in self._services:
|
||||
self._services[service_name] = ServiceHealth(service_name, HealthStatus.UNKNOWN)
|
||||
|
||||
service_health = self._services[service_name]
|
||||
service_health.status = health_check.status
|
||||
service_health.last_check = health_check.timestamp
|
||||
service_health.uptime = time.time() - self._start_time
|
||||
|
||||
# Keep last 10 checks
|
||||
service_health.checks.append(health_check)
|
||||
if len(service_health.checks) > 10:
|
||||
service_health.checks.pop(0)
|
||||
|
||||
# Update error tracking
|
||||
if health_check.status == HealthStatus.UNHEALTHY:
|
||||
service_health.error_count += 1
|
||||
service_health.last_error = health_check.message
|
||||
|
||||
async def get_overall_health(self) -> Dict[str, Any]:
|
||||
"""Get overall system health."""
|
||||
if not self._services:
|
||||
return {
|
||||
"status": HealthStatus.UNKNOWN.value,
|
||||
"message": "Health checks not initialized"
|
||||
}
|
||||
|
||||
# Determine overall status
|
||||
statuses = [service.status for service in self._services.values()]
|
||||
|
||||
if all(status == HealthStatus.HEALTHY for status in statuses):
|
||||
overall_status = HealthStatus.HEALTHY
|
||||
message = "All services are healthy"
|
||||
elif any(status == HealthStatus.UNHEALTHY for status in statuses):
|
||||
overall_status = HealthStatus.UNHEALTHY
|
||||
unhealthy_services = [
|
||||
name for name, service in self._services.items()
|
||||
if service.status == HealthStatus.UNHEALTHY
|
||||
]
|
||||
message = f"Unhealthy services: {', '.join(unhealthy_services)}"
|
||||
elif any(status == HealthStatus.DEGRADED for status in statuses):
|
||||
overall_status = HealthStatus.DEGRADED
|
||||
degraded_services = [
|
||||
name for name, service in self._services.items()
|
||||
if service.status == HealthStatus.DEGRADED
|
||||
]
|
||||
message = f"Degraded services: {', '.join(degraded_services)}"
|
||||
else:
|
||||
overall_status = HealthStatus.UNKNOWN
|
||||
message = "System health status unknown"
|
||||
|
||||
return {
|
||||
"status": overall_status.value,
|
||||
"message": message,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"uptime": time.time() - self._start_time,
|
||||
"services": {
|
||||
name: {
|
||||
"status": service.status.value,
|
||||
"last_check": service.last_check.isoformat() if service.last_check else None,
|
||||
"error_count": service.error_count,
|
||||
"last_error": service.last_error
|
||||
}
|
||||
for name, service in self._services.items()
|
||||
}
|
||||
}
|
||||
|
||||
async def get_service_health(self, service_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get health information for a specific service."""
|
||||
service = self._services.get(service_name)
|
||||
if not service:
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": service.name,
|
||||
"status": service.status.value,
|
||||
"last_check": service.last_check.isoformat() if service.last_check else None,
|
||||
"uptime": service.uptime,
|
||||
"error_count": service.error_count,
|
||||
"last_error": service.last_error,
|
||||
"recent_checks": [
|
||||
{
|
||||
"timestamp": check.timestamp.isoformat(),
|
||||
"status": check.status.value,
|
||||
"message": check.message,
|
||||
"duration_ms": check.duration_ms,
|
||||
"details": check.details
|
||||
}
|
||||
for check in service.checks[-5:] # Last 5 checks
|
||||
]
|
||||
}
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get health check service status."""
|
||||
return {
|
||||
"status": "healthy" if self._running else "stopped",
|
||||
"initialized": self._initialized,
|
||||
"running": self._running,
|
||||
"services_monitored": len(self._services),
|
||||
"uptime": time.time() - self._start_time
|
||||
}
|
||||
431
v1/src/services/metrics.py
Normal file
431
v1/src/services/metrics.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Metrics collection service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import psutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricPoint:
|
||||
"""Single metric data point."""
|
||||
timestamp: datetime
|
||||
value: float
|
||||
labels: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricSeries:
|
||||
"""Time series of metric points."""
|
||||
name: str
|
||||
description: str
|
||||
unit: str
|
||||
points: deque = field(default_factory=lambda: deque(maxlen=1000))
|
||||
|
||||
def add_point(self, value: float, labels: Optional[Dict[str, str]] = None):
|
||||
"""Add a metric point."""
|
||||
point = MetricPoint(
|
||||
timestamp=datetime.utcnow(),
|
||||
value=value,
|
||||
labels=labels or {}
|
||||
)
|
||||
self.points.append(point)
|
||||
|
||||
def get_latest(self) -> Optional[MetricPoint]:
|
||||
"""Get the latest metric point."""
|
||||
return self.points[-1] if self.points else None
|
||||
|
||||
def get_average(self, duration: timedelta) -> Optional[float]:
|
||||
"""Get average value over a time duration."""
|
||||
cutoff = datetime.utcnow() - duration
|
||||
relevant_points = [
|
||||
point for point in self.points
|
||||
if point.timestamp >= cutoff
|
||||
]
|
||||
|
||||
if not relevant_points:
|
||||
return None
|
||||
|
||||
return sum(point.value for point in relevant_points) / len(relevant_points)
|
||||
|
||||
def get_max(self, duration: timedelta) -> Optional[float]:
|
||||
"""Get maximum value over a time duration."""
|
||||
cutoff = datetime.utcnow() - duration
|
||||
relevant_points = [
|
||||
point for point in self.points
|
||||
if point.timestamp >= cutoff
|
||||
]
|
||||
|
||||
if not relevant_points:
|
||||
return None
|
||||
|
||||
return max(point.value for point in relevant_points)
|
||||
|
||||
|
||||
class MetricsService:
|
||||
"""Service for collecting and managing application metrics."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self._metrics: Dict[str, MetricSeries] = {}
|
||||
self._counters: Dict[str, float] = defaultdict(float)
|
||||
self._gauges: Dict[str, float] = {}
|
||||
self._histograms: Dict[str, List[float]] = defaultdict(list)
|
||||
self._start_time = time.time()
|
||||
self._initialized = False
|
||||
self._running = False
|
||||
|
||||
# Initialize standard metrics
|
||||
self._initialize_standard_metrics()
|
||||
|
||||
def _initialize_standard_metrics(self):
|
||||
"""Initialize standard system and application metrics."""
|
||||
self._metrics.update({
|
||||
# System metrics
|
||||
"system_cpu_usage": MetricSeries(
|
||||
"system_cpu_usage", "System CPU usage percentage", "percent"
|
||||
),
|
||||
"system_memory_usage": MetricSeries(
|
||||
"system_memory_usage", "System memory usage percentage", "percent"
|
||||
),
|
||||
"system_disk_usage": MetricSeries(
|
||||
"system_disk_usage", "System disk usage percentage", "percent"
|
||||
),
|
||||
"system_network_bytes_sent": MetricSeries(
|
||||
"system_network_bytes_sent", "Network bytes sent", "bytes"
|
||||
),
|
||||
"system_network_bytes_recv": MetricSeries(
|
||||
"system_network_bytes_recv", "Network bytes received", "bytes"
|
||||
),
|
||||
|
||||
# Application metrics
|
||||
"app_requests_total": MetricSeries(
|
||||
"app_requests_total", "Total HTTP requests", "count"
|
||||
),
|
||||
"app_request_duration": MetricSeries(
|
||||
"app_request_duration", "HTTP request duration", "seconds"
|
||||
),
|
||||
"app_active_connections": MetricSeries(
|
||||
"app_active_connections", "Active WebSocket connections", "count"
|
||||
),
|
||||
"app_pose_detections": MetricSeries(
|
||||
"app_pose_detections", "Pose detections performed", "count"
|
||||
),
|
||||
"app_pose_processing_time": MetricSeries(
|
||||
"app_pose_processing_time", "Pose processing time", "seconds"
|
||||
),
|
||||
"app_csi_data_points": MetricSeries(
|
||||
"app_csi_data_points", "CSI data points processed", "count"
|
||||
),
|
||||
"app_stream_fps": MetricSeries(
|
||||
"app_stream_fps", "Streaming frames per second", "fps"
|
||||
),
|
||||
|
||||
# Error metrics
|
||||
"app_errors_total": MetricSeries(
|
||||
"app_errors_total", "Total application errors", "count"
|
||||
),
|
||||
"app_http_errors": MetricSeries(
|
||||
"app_http_errors", "HTTP errors", "count"
|
||||
),
|
||||
})
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize metrics service."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing metrics service")
|
||||
self._initialized = True
|
||||
logger.info("Metrics service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start metrics service."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
self._running = True
|
||||
logger.info("Metrics service started")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown metrics service."""
|
||||
self._running = False
|
||||
logger.info("Metrics service shut down")
|
||||
|
||||
async def collect_metrics(self):
|
||||
"""Collect all metrics."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.debug("Collecting metrics")
|
||||
|
||||
# Collect system metrics
|
||||
await self._collect_system_metrics()
|
||||
|
||||
# Collect application metrics
|
||||
await self._collect_application_metrics()
|
||||
|
||||
logger.debug("Metrics collection completed")
|
||||
|
||||
async def _collect_system_metrics(self):
|
||||
"""Collect system-level metrics."""
|
||||
try:
|
||||
# CPU usage
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
self._metrics["system_cpu_usage"].add_point(cpu_percent)
|
||||
|
||||
# Memory usage
|
||||
memory = psutil.virtual_memory()
|
||||
self._metrics["system_memory_usage"].add_point(memory.percent)
|
||||
|
||||
# Disk usage
|
||||
disk = psutil.disk_usage('/')
|
||||
disk_percent = (disk.used / disk.total) * 100
|
||||
self._metrics["system_disk_usage"].add_point(disk_percent)
|
||||
|
||||
# Network I/O
|
||||
network = psutil.net_io_counters()
|
||||
self._metrics["system_network_bytes_sent"].add_point(network.bytes_sent)
|
||||
self._metrics["system_network_bytes_recv"].add_point(network.bytes_recv)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting system metrics: {e}")
|
||||
|
||||
async def _collect_application_metrics(self):
|
||||
"""Collect application-specific metrics."""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
|
||||
# Active connections
|
||||
connection_stats = await connection_manager.get_connection_stats()
|
||||
active_connections = connection_stats.get("active_connections", 0)
|
||||
self._metrics["app_active_connections"].add_point(active_connections)
|
||||
|
||||
# Update counters as metrics
|
||||
for name, value in self._counters.items():
|
||||
if name in self._metrics:
|
||||
self._metrics[name].add_point(value)
|
||||
|
||||
# Update gauges as metrics
|
||||
for name, value in self._gauges.items():
|
||||
if name in self._metrics:
|
||||
self._metrics[name].add_point(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting application metrics: {e}")
|
||||
|
||||
def increment_counter(self, name: str, value: float = 1.0, labels: Optional[Dict[str, str]] = None):
|
||||
"""Increment a counter metric."""
|
||||
self._counters[name] += value
|
||||
|
||||
if name in self._metrics:
|
||||
self._metrics[name].add_point(self._counters[name], labels)
|
||||
|
||||
def set_gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
|
||||
"""Set a gauge metric value."""
|
||||
self._gauges[name] = value
|
||||
|
||||
if name in self._metrics:
|
||||
self._metrics[name].add_point(value, labels)
|
||||
|
||||
def record_histogram(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
|
||||
"""Record a histogram value."""
|
||||
self._histograms[name].append(value)
|
||||
|
||||
# Keep only last 1000 values
|
||||
if len(self._histograms[name]) > 1000:
|
||||
self._histograms[name] = self._histograms[name][-1000:]
|
||||
|
||||
if name in self._metrics:
|
||||
self._metrics[name].add_point(value, labels)
|
||||
|
||||
def time_function(self, metric_name: str):
|
||||
"""Decorator to time function execution."""
|
||||
def decorator(func):
|
||||
import functools
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
duration = time.time() - start_time
|
||||
self.record_histogram(metric_name, duration)
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
duration = time.time() - start_time
|
||||
self.record_histogram(metric_name, duration)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def get_metric(self, name: str) -> Optional[MetricSeries]:
|
||||
"""Get a metric series by name."""
|
||||
return self._metrics.get(name)
|
||||
|
||||
def get_metric_value(self, name: str) -> Optional[float]:
|
||||
"""Get the latest value of a metric."""
|
||||
metric = self._metrics.get(name)
|
||||
if metric:
|
||||
latest = metric.get_latest()
|
||||
return latest.value if latest else None
|
||||
return None
|
||||
|
||||
def get_counter_value(self, name: str) -> float:
|
||||
"""Get current counter value."""
|
||||
return self._counters.get(name, 0.0)
|
||||
|
||||
def get_gauge_value(self, name: str) -> Optional[float]:
|
||||
"""Get current gauge value."""
|
||||
return self._gauges.get(name)
|
||||
|
||||
def get_histogram_stats(self, name: str) -> Dict[str, float]:
|
||||
"""Get histogram statistics."""
|
||||
values = self._histograms.get(name, [])
|
||||
if not values:
|
||||
return {}
|
||||
|
||||
sorted_values = sorted(values)
|
||||
count = len(sorted_values)
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"sum": sum(sorted_values),
|
||||
"min": sorted_values[0],
|
||||
"max": sorted_values[-1],
|
||||
"mean": sum(sorted_values) / count,
|
||||
"p50": sorted_values[int(count * 0.5)],
|
||||
"p90": sorted_values[int(count * 0.9)],
|
||||
"p95": sorted_values[int(count * 0.95)],
|
||||
"p99": sorted_values[int(count * 0.99)],
|
||||
}
|
||||
|
||||
async def get_all_metrics(self) -> Dict[str, Any]:
|
||||
"""Get all current metrics."""
|
||||
metrics = {}
|
||||
|
||||
# Current metric values
|
||||
for name, metric_series in self._metrics.items():
|
||||
latest = metric_series.get_latest()
|
||||
if latest:
|
||||
metrics[name] = {
|
||||
"value": latest.value,
|
||||
"timestamp": latest.timestamp.isoformat(),
|
||||
"description": metric_series.description,
|
||||
"unit": metric_series.unit,
|
||||
"labels": latest.labels
|
||||
}
|
||||
|
||||
# Counter values
|
||||
metrics.update({
|
||||
f"counter_{name}": value
|
||||
for name, value in self._counters.items()
|
||||
})
|
||||
|
||||
# Gauge values
|
||||
metrics.update({
|
||||
f"gauge_{name}": value
|
||||
for name, value in self._gauges.items()
|
||||
})
|
||||
|
||||
# Histogram statistics
|
||||
for name, values in self._histograms.items():
|
||||
if values:
|
||||
stats = self.get_histogram_stats(name)
|
||||
metrics[f"histogram_{name}"] = stats
|
||||
|
||||
return metrics
|
||||
|
||||
async def get_system_metrics(self) -> Dict[str, Any]:
|
||||
"""Get system metrics summary."""
|
||||
return {
|
||||
"cpu_usage": self.get_metric_value("system_cpu_usage"),
|
||||
"memory_usage": self.get_metric_value("system_memory_usage"),
|
||||
"disk_usage": self.get_metric_value("system_disk_usage"),
|
||||
"network_bytes_sent": self.get_metric_value("system_network_bytes_sent"),
|
||||
"network_bytes_recv": self.get_metric_value("system_network_bytes_recv"),
|
||||
}
|
||||
|
||||
async def get_application_metrics(self) -> Dict[str, Any]:
|
||||
"""Get application metrics summary."""
|
||||
return {
|
||||
"requests_total": self.get_counter_value("app_requests_total"),
|
||||
"active_connections": self.get_metric_value("app_active_connections"),
|
||||
"pose_detections": self.get_counter_value("app_pose_detections"),
|
||||
"csi_data_points": self.get_counter_value("app_csi_data_points"),
|
||||
"errors_total": self.get_counter_value("app_errors_total"),
|
||||
"uptime_seconds": time.time() - self._start_time,
|
||||
"request_duration_stats": self.get_histogram_stats("app_request_duration"),
|
||||
"pose_processing_time_stats": self.get_histogram_stats("app_pose_processing_time"),
|
||||
}
|
||||
|
||||
async def get_performance_summary(self) -> Dict[str, Any]:
|
||||
"""Get performance metrics summary."""
|
||||
one_hour = timedelta(hours=1)
|
||||
|
||||
return {
|
||||
"system": {
|
||||
"cpu_avg_1h": self._metrics["system_cpu_usage"].get_average(one_hour),
|
||||
"memory_avg_1h": self._metrics["system_memory_usage"].get_average(one_hour),
|
||||
"cpu_max_1h": self._metrics["system_cpu_usage"].get_max(one_hour),
|
||||
"memory_max_1h": self._metrics["system_memory_usage"].get_max(one_hour),
|
||||
},
|
||||
"application": {
|
||||
"avg_request_duration": self.get_histogram_stats("app_request_duration").get("mean"),
|
||||
"avg_pose_processing_time": self.get_histogram_stats("app_pose_processing_time").get("mean"),
|
||||
"total_requests": self.get_counter_value("app_requests_total"),
|
||||
"total_errors": self.get_counter_value("app_errors_total"),
|
||||
"error_rate": (
|
||||
self.get_counter_value("app_errors_total") /
|
||||
max(self.get_counter_value("app_requests_total"), 1)
|
||||
) * 100,
|
||||
}
|
||||
}
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get metrics service status."""
|
||||
return {
|
||||
"status": "healthy" if self._running else "stopped",
|
||||
"initialized": self._initialized,
|
||||
"running": self._running,
|
||||
"metrics_count": len(self._metrics),
|
||||
"counters_count": len(self._counters),
|
||||
"gauges_count": len(self._gauges),
|
||||
"histograms_count": len(self._histograms),
|
||||
"uptime": time.time() - self._start_time
|
||||
}
|
||||
|
||||
def reset_metrics(self):
|
||||
"""Reset all metrics."""
|
||||
logger.info("Resetting all metrics")
|
||||
|
||||
# Clear metric points but keep series definitions
|
||||
for metric_series in self._metrics.values():
|
||||
metric_series.points.clear()
|
||||
|
||||
# Reset counters, gauges, and histograms
|
||||
self._counters.clear()
|
||||
self._gauges.clear()
|
||||
self._histograms.clear()
|
||||
|
||||
logger.info("All metrics reset")
|
||||
395
v1/src/services/orchestrator.py
Normal file
395
v1/src/services/orchestrator.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Main service orchestrator for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.services.health_check import HealthCheckService
|
||||
from src.services.metrics import MetricsService
|
||||
from src.api.dependencies import (
|
||||
get_hardware_service,
|
||||
get_pose_service,
|
||||
get_stream_service
|
||||
)
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.api.websocket.pose_stream import PoseStreamHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceOrchestrator:
|
||||
"""Main service orchestrator that manages all application services."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self._services: Dict[str, Any] = {}
|
||||
self._background_tasks: List[asyncio.Task] = []
|
||||
self._initialized = False
|
||||
self._started = False
|
||||
|
||||
# Core services
|
||||
self.health_service = HealthCheckService(settings)
|
||||
self.metrics_service = MetricsService(settings)
|
||||
|
||||
# Application services (will be initialized later)
|
||||
self.hardware_service = None
|
||||
self.pose_service = None
|
||||
self.stream_service = None
|
||||
self.pose_stream_handler = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize all services."""
|
||||
if self._initialized:
|
||||
logger.warning("Services already initialized")
|
||||
return
|
||||
|
||||
logger.info("Initializing services...")
|
||||
|
||||
try:
|
||||
# Initialize core services
|
||||
await self.health_service.initialize()
|
||||
await self.metrics_service.initialize()
|
||||
|
||||
# Initialize application services
|
||||
await self._initialize_application_services()
|
||||
|
||||
# Store services in registry
|
||||
self._services = {
|
||||
'health': self.health_service,
|
||||
'metrics': self.metrics_service,
|
||||
'hardware': self.hardware_service,
|
||||
'pose': self.pose_service,
|
||||
'stream': self.stream_service,
|
||||
'pose_stream_handler': self.pose_stream_handler,
|
||||
'connection_manager': connection_manager
|
||||
}
|
||||
|
||||
self._initialized = True
|
||||
logger.info("All services initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize services: {e}")
|
||||
await self.shutdown()
|
||||
raise
|
||||
|
||||
async def _initialize_application_services(self):
|
||||
"""Initialize application-specific services."""
|
||||
try:
|
||||
# Initialize hardware service
|
||||
self.hardware_service = get_hardware_service()
|
||||
await self.hardware_service.initialize()
|
||||
logger.info("Hardware service initialized")
|
||||
|
||||
# Initialize pose service
|
||||
self.pose_service = get_pose_service()
|
||||
await self.pose_service.initialize()
|
||||
logger.info("Pose service initialized")
|
||||
|
||||
# Initialize stream service
|
||||
self.stream_service = get_stream_service()
|
||||
await self.stream_service.initialize()
|
||||
logger.info("Stream service initialized")
|
||||
|
||||
# Initialize pose stream handler
|
||||
self.pose_stream_handler = PoseStreamHandler(
|
||||
connection_manager=connection_manager,
|
||||
pose_service=self.pose_service,
|
||||
stream_service=self.stream_service
|
||||
)
|
||||
logger.info("Pose stream handler initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize application services: {e}")
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""Start all services and background tasks."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if self._started:
|
||||
logger.warning("Services already started")
|
||||
return
|
||||
|
||||
logger.info("Starting services...")
|
||||
|
||||
try:
|
||||
# Start core services
|
||||
await self.health_service.start()
|
||||
await self.metrics_service.start()
|
||||
|
||||
# Start application services
|
||||
await self._start_application_services()
|
||||
|
||||
# Start background tasks
|
||||
await self._start_background_tasks()
|
||||
|
||||
self._started = True
|
||||
logger.info("All services started successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start services: {e}")
|
||||
await self.shutdown()
|
||||
raise
|
||||
|
||||
async def _start_application_services(self):
|
||||
"""Start application-specific services."""
|
||||
try:
|
||||
# Start hardware service
|
||||
if hasattr(self.hardware_service, 'start'):
|
||||
await self.hardware_service.start()
|
||||
|
||||
# Start pose service
|
||||
if hasattr(self.pose_service, 'start'):
|
||||
await self.pose_service.start()
|
||||
|
||||
# Start stream service
|
||||
if hasattr(self.stream_service, 'start'):
|
||||
await self.stream_service.start()
|
||||
|
||||
logger.info("Application services started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start application services: {e}")
|
||||
raise
|
||||
|
||||
async def _start_background_tasks(self):
|
||||
"""Start background tasks."""
|
||||
try:
|
||||
# Start health check monitoring
|
||||
if self.settings.health_check_interval > 0:
|
||||
task = asyncio.create_task(self._health_check_loop())
|
||||
self._background_tasks.append(task)
|
||||
|
||||
# Start metrics collection
|
||||
if self.settings.metrics_enabled:
|
||||
task = asyncio.create_task(self._metrics_collection_loop())
|
||||
self._background_tasks.append(task)
|
||||
|
||||
# Start pose streaming if enabled
|
||||
if self.settings.enable_real_time_processing:
|
||||
await self.pose_stream_handler.start_streaming()
|
||||
|
||||
logger.info(f"Started {len(self._background_tasks)} background tasks")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start background tasks: {e}")
|
||||
raise
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""Background health check loop."""
|
||||
logger.info("Starting health check loop")
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self.health_service.perform_health_checks()
|
||||
await asyncio.sleep(self.settings.health_check_interval)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Health check loop cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in health check loop: {e}")
|
||||
await asyncio.sleep(self.settings.health_check_interval)
|
||||
|
||||
async def _metrics_collection_loop(self):
|
||||
"""Background metrics collection loop."""
|
||||
logger.info("Starting metrics collection loop")
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self.metrics_service.collect_metrics()
|
||||
await asyncio.sleep(60) # Collect metrics every minute
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Metrics collection loop cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in metrics collection loop: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown all services and cleanup resources."""
|
||||
logger.info("Shutting down services...")
|
||||
|
||||
try:
|
||||
# Cancel background tasks
|
||||
for task in self._background_tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
# Stop pose streaming
|
||||
if self.pose_stream_handler:
|
||||
await self.pose_stream_handler.shutdown()
|
||||
|
||||
# Shutdown connection manager
|
||||
await connection_manager.shutdown()
|
||||
|
||||
# Shutdown application services
|
||||
await self._shutdown_application_services()
|
||||
|
||||
# Shutdown core services
|
||||
await self.health_service.shutdown()
|
||||
await self.metrics_service.shutdown()
|
||||
|
||||
self._started = False
|
||||
self._initialized = False
|
||||
|
||||
logger.info("All services shut down successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during shutdown: {e}")
|
||||
|
||||
async def _shutdown_application_services(self):
|
||||
"""Shutdown application-specific services."""
|
||||
try:
|
||||
# Shutdown services in reverse order
|
||||
if self.stream_service and hasattr(self.stream_service, 'shutdown'):
|
||||
await self.stream_service.shutdown()
|
||||
|
||||
if self.pose_service and hasattr(self.pose_service, 'shutdown'):
|
||||
await self.pose_service.shutdown()
|
||||
|
||||
if self.hardware_service and hasattr(self.hardware_service, 'shutdown'):
|
||||
await self.hardware_service.shutdown()
|
||||
|
||||
logger.info("Application services shut down")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down application services: {e}")
|
||||
|
||||
async def restart_service(self, service_name: str):
|
||||
"""Restart a specific service."""
|
||||
logger.info(f"Restarting service: {service_name}")
|
||||
|
||||
service = self._services.get(service_name)
|
||||
if not service:
|
||||
raise ValueError(f"Service not found: {service_name}")
|
||||
|
||||
try:
|
||||
# Stop service
|
||||
if hasattr(service, 'stop'):
|
||||
await service.stop()
|
||||
elif hasattr(service, 'shutdown'):
|
||||
await service.shutdown()
|
||||
|
||||
# Reinitialize service
|
||||
if hasattr(service, 'initialize'):
|
||||
await service.initialize()
|
||||
|
||||
# Start service
|
||||
if hasattr(service, 'start'):
|
||||
await service.start()
|
||||
|
||||
logger.info(f"Service restarted successfully: {service_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restart service {service_name}: {e}")
|
||||
raise
|
||||
|
||||
async def reset_services(self):
|
||||
"""Reset all services to initial state."""
|
||||
logger.info("Resetting all services")
|
||||
|
||||
try:
|
||||
# Reset application services
|
||||
if self.hardware_service and hasattr(self.hardware_service, 'reset'):
|
||||
await self.hardware_service.reset()
|
||||
|
||||
if self.pose_service and hasattr(self.pose_service, 'reset'):
|
||||
await self.pose_service.reset()
|
||||
|
||||
if self.stream_service and hasattr(self.stream_service, 'reset'):
|
||||
await self.stream_service.reset()
|
||||
|
||||
# Reset connection manager
|
||||
await connection_manager.reset()
|
||||
|
||||
logger.info("All services reset successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset services: {e}")
|
||||
raise
|
||||
|
||||
async def get_service_status(self) -> Dict[str, Any]:
|
||||
"""Get status of all services."""
|
||||
status = {}
|
||||
|
||||
for name, service in self._services.items():
|
||||
try:
|
||||
if hasattr(service, 'get_status'):
|
||||
status[name] = await service.get_status()
|
||||
else:
|
||||
status[name] = {"status": "unknown"}
|
||||
except Exception as e:
|
||||
status[name] = {"status": "error", "error": str(e)}
|
||||
|
||||
return status
|
||||
|
||||
async def get_service_metrics(self) -> Dict[str, Any]:
|
||||
"""Get metrics from all services."""
|
||||
metrics = {}
|
||||
|
||||
for name, service in self._services.items():
|
||||
try:
|
||||
if hasattr(service, 'get_metrics'):
|
||||
metrics[name] = await service.get_metrics()
|
||||
elif hasattr(service, 'get_performance_metrics'):
|
||||
metrics[name] = await service.get_performance_metrics()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get metrics from {name}: {e}")
|
||||
metrics[name] = {"error": str(e)}
|
||||
|
||||
return metrics
|
||||
|
||||
async def get_service_info(self) -> Dict[str, Any]:
|
||||
"""Get information about all services."""
|
||||
info = {
|
||||
"total_services": len(self._services),
|
||||
"initialized": self._initialized,
|
||||
"started": self._started,
|
||||
"background_tasks": len(self._background_tasks),
|
||||
"services": {}
|
||||
}
|
||||
|
||||
for name, service in self._services.items():
|
||||
service_info = {
|
||||
"type": type(service).__name__,
|
||||
"module": type(service).__module__
|
||||
}
|
||||
|
||||
# Add service-specific info if available
|
||||
if hasattr(service, 'get_info'):
|
||||
try:
|
||||
service_info.update(await service.get_info())
|
||||
except Exception as e:
|
||||
service_info["error"] = str(e)
|
||||
|
||||
info["services"][name] = service_info
|
||||
|
||||
return info
|
||||
|
||||
def get_service(self, name: str) -> Optional[Any]:
|
||||
"""Get a specific service by name."""
|
||||
return self._services.get(name)
|
||||
|
||||
@property
|
||||
def is_healthy(self) -> bool:
|
||||
"""Check if all services are healthy."""
|
||||
return self._initialized and self._started
|
||||
|
||||
@asynccontextmanager
|
||||
async def service_context(self):
|
||||
"""Context manager for service lifecycle."""
|
||||
try:
|
||||
await self.initialize()
|
||||
await self.start()
|
||||
yield self
|
||||
finally:
|
||||
await self.shutdown()
|
||||
757
v1/src/services/pose_service.py
Normal file
757
v1/src/services/pose_service.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Pose estimation service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoseService:
|
||||
"""Service for pose estimation operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize pose service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize components
|
||||
self.csi_processor = None
|
||||
self.phase_sanitizer = None
|
||||
self.densepose_model = None
|
||||
self.modality_translator = None
|
||||
|
||||
# Service state
|
||||
self.is_initialized = False
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Processing statistics
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the pose service."""
|
||||
try:
|
||||
self.logger.info("Initializing pose service...")
|
||||
|
||||
# Initialize CSI processor
|
||||
csi_config = {
|
||||
'buffer_size': self.settings.csi_buffer_size,
|
||||
'sampling_rate': getattr(self.settings, 'csi_sampling_rate', 1000),
|
||||
'window_size': getattr(self.settings, 'csi_window_size', 512),
|
||||
'overlap': getattr(self.settings, 'csi_overlap', 0.5),
|
||||
'noise_threshold': getattr(self.settings, 'csi_noise_threshold', 0.1),
|
||||
'human_detection_threshold': getattr(self.settings, 'csi_human_detection_threshold', 0.8),
|
||||
'smoothing_factor': getattr(self.settings, 'csi_smoothing_factor', 0.9),
|
||||
'max_history_size': getattr(self.settings, 'csi_max_history_size', 500),
|
||||
'num_subcarriers': 56,
|
||||
'num_antennas': 3
|
||||
}
|
||||
self.csi_processor = CSIProcessor(config=csi_config)
|
||||
|
||||
# Initialize phase sanitizer
|
||||
phase_config = {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 5,
|
||||
'enable_outlier_removal': True,
|
||||
'enable_smoothing': True,
|
||||
'enable_noise_filtering': True,
|
||||
'noise_threshold': getattr(self.settings, 'csi_noise_threshold', 0.1)
|
||||
}
|
||||
self.phase_sanitizer = PhaseSanitizer(config=phase_config)
|
||||
|
||||
# Initialize models if not mocking
|
||||
if not self.settings.mock_pose_data:
|
||||
await self._initialize_models()
|
||||
else:
|
||||
self.logger.info("Using mock pose data for development")
|
||||
|
||||
self.is_initialized = True
|
||||
self.logger.info("Pose service initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to initialize pose service: {e}")
|
||||
raise
|
||||
|
||||
async def _initialize_models(self):
|
||||
"""Initialize neural network models."""
|
||||
try:
|
||||
# Initialize DensePose model
|
||||
if self.settings.pose_model_path:
|
||||
self.densepose_model = DensePoseHead()
|
||||
# Load model weights if path is provided
|
||||
# model_state = torch.load(self.settings.pose_model_path)
|
||||
# self.densepose_model.load_state_dict(model_state)
|
||||
self.logger.info("DensePose model loaded")
|
||||
else:
|
||||
self.logger.warning("No pose model path provided, using default model")
|
||||
self.densepose_model = DensePoseHead()
|
||||
|
||||
# Initialize modality translation
|
||||
config = {
|
||||
'input_channels': 64, # CSI data channels
|
||||
'hidden_channels': [128, 256, 512],
|
||||
'output_channels': 256, # Visual feature channels
|
||||
'use_attention': True
|
||||
}
|
||||
self.modality_translator = ModalityTranslationNetwork(config)
|
||||
|
||||
# Set models to evaluation mode
|
||||
self.densepose_model.eval()
|
||||
self.modality_translator.eval()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize models: {e}")
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""Start the pose service."""
|
||||
if not self.is_initialized:
|
||||
await self.initialize()
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Pose service started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the pose service."""
|
||||
self.is_running = False
|
||||
self.logger.info("Pose service stopped")
|
||||
|
||||
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process CSI data and estimate poses."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Pose service is not running")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# Process CSI data
|
||||
processed_csi = await self._process_csi(csi_data, metadata)
|
||||
|
||||
# Estimate poses
|
||||
poses = await self._estimate_poses(processed_csi, metadata)
|
||||
|
||||
# Update statistics
|
||||
processing_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
self._update_stats(poses, processing_time)
|
||||
|
||||
return {
|
||||
"timestamp": start_time.isoformat(),
|
||||
"poses": poses,
|
||||
"metadata": metadata,
|
||||
"processing_time_ms": processing_time,
|
||||
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.stats["failed_detections"] += 1
|
||||
self.logger.error(f"Error processing CSI data: {e}")
|
||||
raise
|
||||
|
||||
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
|
||||
"""Process raw CSI data."""
|
||||
# Convert raw data to CSIData format
|
||||
from src.hardware.csi_extractor import CSIData
|
||||
|
||||
# Create CSIData object with proper fields
|
||||
# For mock data, create amplitude and phase from input
|
||||
if csi_data.ndim == 1:
|
||||
amplitude = np.abs(csi_data)
|
||||
phase = np.angle(csi_data) if np.iscomplexobj(csi_data) else np.zeros_like(csi_data)
|
||||
else:
|
||||
amplitude = csi_data
|
||||
phase = np.zeros_like(csi_data)
|
||||
|
||||
csi_data_obj = CSIData(
|
||||
timestamp=metadata.get("timestamp", datetime.now()),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=metadata.get("frequency", 5.0), # 5 GHz default
|
||||
bandwidth=metadata.get("bandwidth", 20.0), # 20 MHz default
|
||||
num_subcarriers=metadata.get("num_subcarriers", 56),
|
||||
num_antennas=metadata.get("num_antennas", 3),
|
||||
snr=metadata.get("snr", 20.0), # 20 dB default
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Process CSI data
|
||||
try:
|
||||
detection_result = await self.csi_processor.process_csi_data(csi_data_obj)
|
||||
|
||||
# Add to history for temporal analysis
|
||||
self.csi_processor.add_to_history(csi_data_obj)
|
||||
|
||||
# Extract amplitude data for pose estimation
|
||||
if detection_result and detection_result.features:
|
||||
amplitude_data = detection_result.features.amplitude_mean
|
||||
|
||||
# Apply phase sanitization if we have phase data
|
||||
if hasattr(detection_result.features, 'phase_difference'):
|
||||
phase_data = detection_result.features.phase_difference
|
||||
sanitized_phase = self.phase_sanitizer.sanitize(phase_data)
|
||||
# Combine amplitude and phase data
|
||||
return np.concatenate([amplitude_data, sanitized_phase])
|
||||
|
||||
return amplitude_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"CSI processing failed, using raw data: {e}")
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Estimate poses from processed CSI data."""
|
||||
if self.settings.mock_pose_data:
|
||||
return self._generate_mock_poses()
|
||||
|
||||
try:
|
||||
# Convert CSI data to tensor
|
||||
csi_tensor = torch.from_numpy(csi_data).float()
|
||||
|
||||
# Add batch dimension if needed
|
||||
if len(csi_tensor.shape) == 2:
|
||||
csi_tensor = csi_tensor.unsqueeze(0)
|
||||
|
||||
# Translate modality (CSI to visual-like features)
|
||||
with torch.no_grad():
|
||||
visual_features = self.modality_translator(csi_tensor)
|
||||
|
||||
# Estimate poses using DensePose
|
||||
pose_outputs = self.densepose_model(visual_features)
|
||||
|
||||
# Convert outputs to pose detections
|
||||
poses = self._parse_pose_outputs(pose_outputs)
|
||||
|
||||
# Filter by confidence threshold
|
||||
filtered_poses = [
|
||||
pose for pose in poses
|
||||
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
|
||||
]
|
||||
|
||||
# Limit number of persons
|
||||
if len(filtered_poses) > self.settings.pose_max_persons:
|
||||
filtered_poses = sorted(
|
||||
filtered_poses,
|
||||
key=lambda x: x.get("confidence", 0.0),
|
||||
reverse=True
|
||||
)[:self.settings.pose_max_persons]
|
||||
|
||||
return filtered_poses
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in pose estimation: {e}")
|
||||
return []
|
||||
|
||||
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Parse neural network outputs into pose detections."""
|
||||
poses = []
|
||||
|
||||
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||
# For now, generate mock poses based on the output shape
|
||||
batch_size = outputs.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract pose information (mock implementation)
|
||||
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||
"""Generate mock pose data for development."""
|
||||
import random
|
||||
|
||||
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
|
||||
poses = []
|
||||
|
||||
for i in range(num_persons):
|
||||
confidence = random.uniform(0.3, 0.95)
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_keypoints(self) -> List[Dict[str, Any]]:
|
||||
"""Generate keypoints for a person."""
|
||||
import random
|
||||
|
||||
keypoint_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
keypoints = []
|
||||
for name in keypoint_names:
|
||||
keypoints.append({
|
||||
"name": name,
|
||||
"x": random.uniform(0.1, 0.9),
|
||||
"y": random.uniform(0.1, 0.9),
|
||||
"confidence": random.uniform(0.5, 0.95)
|
||||
})
|
||||
|
||||
return keypoints
|
||||
|
||||
def _generate_bounding_box(self) -> Dict[str, float]:
|
||||
"""Generate bounding box for a person."""
|
||||
import random
|
||||
|
||||
x = random.uniform(0.1, 0.6)
|
||||
y = random.uniform(0.1, 0.6)
|
||||
width = random.uniform(0.2, 0.4)
|
||||
height = random.uniform(0.3, 0.5)
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height
|
||||
}
|
||||
|
||||
def _classify_activity(self, features: torch.Tensor) -> str:
|
||||
"""Classify activity from features."""
|
||||
# Simple mock classification
|
||||
import random
|
||||
activities = ["standing", "sitting", "walking", "lying", "unknown"]
|
||||
return random.choice(activities)
|
||||
|
||||
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
|
||||
"""Update processing statistics."""
|
||||
self.stats["total_processed"] += 1
|
||||
|
||||
if poses:
|
||||
self.stats["successful_detections"] += 1
|
||||
confidences = [pose.get("confidence", 0.0) for pose in poses]
|
||||
avg_confidence = sum(confidences) / len(confidences)
|
||||
|
||||
# Update running average
|
||||
total = self.stats["successful_detections"]
|
||||
current_avg = self.stats["average_confidence"]
|
||||
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
|
||||
else:
|
||||
self.stats["failed_detections"] += 1
|
||||
|
||||
# Update processing time (running average)
|
||||
total = self.stats["total_processed"]
|
||||
current_avg = self.stats["processing_time_ms"]
|
||||
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"initialized": self.is_initialized,
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_data": self.settings.mock_pose_data,
|
||||
"confidence_threshold": self.settings.pose_confidence_threshold,
|
||||
"max_persons": self.settings.pose_max_persons,
|
||||
"batch_size": self.settings.pose_processing_batch_size
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
return {
|
||||
"pose_service": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"successful_detections": self.stats["successful_detections"],
|
||||
"failed_detections": self.stats["failed_detections"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_confidence": self.stats["average_confidence"],
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
self.last_error = None
|
||||
self.logger.info("Pose service reset")
|
||||
|
||||
# API endpoint methods
|
||||
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Estimate poses with API parameters."""
|
||||
try:
|
||||
# Generate mock CSI data for estimation
|
||||
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
|
||||
metadata = {
|
||||
"timestamp": datetime.now(),
|
||||
"zone_ids": zone_ids or ["zone_1"],
|
||||
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
|
||||
"max_persons": max_persons or self.settings.pose_max_persons
|
||||
}
|
||||
|
||||
# Process the data
|
||||
result = await self.process_csi_data(mock_csi, metadata)
|
||||
|
||||
# Format for API response
|
||||
persons = []
|
||||
for i, pose in enumerate(result["poses"]):
|
||||
person = {
|
||||
"person_id": str(pose["person_id"]),
|
||||
"confidence": pose["confidence"],
|
||||
"bounding_box": pose["bounding_box"],
|
||||
"zone_id": zone_ids[0] if zone_ids else "zone_1",
|
||||
"activity": pose["activity"],
|
||||
"timestamp": datetime.fromisoformat(pose["timestamp"])
|
||||
}
|
||||
|
||||
if include_keypoints:
|
||||
person["keypoints"] = pose["keypoints"]
|
||||
|
||||
if include_segmentation:
|
||||
person["segmentation"] = {"mask": "mock_segmentation_data"}
|
||||
|
||||
persons.append(person)
|
||||
|
||||
# Zone summary
|
||||
zone_summary = {}
|
||||
for zone_id in (zone_ids or ["zone_1"]):
|
||||
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(),
|
||||
"frame_id": f"frame_{int(datetime.now().timestamp())}",
|
||||
"persons": persons,
|
||||
"zone_summary": zone_summary,
|
||||
"processing_time_ms": result["processing_time_ms"],
|
||||
"metadata": {"mock_data": self.settings.mock_pose_data}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in estimate_poses: {e}")
|
||||
raise
|
||||
|
||||
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Analyze pose data with custom parameters."""
|
||||
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
|
||||
include_keypoints, include_segmentation)
|
||||
|
||||
async def get_zone_occupancy(self, zone_id: str):
|
||||
"""Get current occupancy for a specific zone."""
|
||||
try:
|
||||
# Mock occupancy data
|
||||
import random
|
||||
count = random.randint(0, 5)
|
||||
persons = []
|
||||
|
||||
for i in range(count):
|
||||
persons.append({
|
||||
"person_id": f"person_{i}",
|
||||
"confidence": random.uniform(0.7, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"max_occupancy": 10,
|
||||
"persons": persons,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zone occupancy: {e}")
|
||||
return None
|
||||
|
||||
async def get_zones_summary(self):
|
||||
"""Get occupancy summary for all zones."""
|
||||
try:
|
||||
import random
|
||||
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
|
||||
zone_data = {}
|
||||
total_persons = 0
|
||||
active_zones = 0
|
||||
|
||||
for zone_id in zones:
|
||||
count = random.randint(0, 3)
|
||||
zone_data[zone_id] = {
|
||||
"occupancy": count,
|
||||
"max_occupancy": 10,
|
||||
"status": "active" if count > 0 else "inactive"
|
||||
}
|
||||
total_persons += count
|
||||
if count > 0:
|
||||
active_zones += 1
|
||||
|
||||
return {
|
||||
"total_persons": total_persons,
|
||||
"zones": zone_data,
|
||||
"active_zones": active_zones
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zones summary: {e}")
|
||||
raise
|
||||
|
||||
async def get_historical_data(self, start_time, end_time, zone_ids=None,
|
||||
aggregation_interval=300, include_raw_data=False):
|
||||
"""Get historical pose estimation data."""
|
||||
try:
|
||||
# Mock historical data
|
||||
import random
|
||||
from datetime import timedelta
|
||||
|
||||
current_time = start_time
|
||||
aggregated_data = []
|
||||
raw_data = [] if include_raw_data else None
|
||||
|
||||
while current_time < end_time:
|
||||
# Generate aggregated data point
|
||||
data_point = {
|
||||
"timestamp": current_time,
|
||||
"total_persons": random.randint(0, 8),
|
||||
"zones": {}
|
||||
}
|
||||
|
||||
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
|
||||
data_point["zones"][zone_id] = {
|
||||
"occupancy": random.randint(0, 3),
|
||||
"avg_confidence": random.uniform(0.7, 0.95)
|
||||
}
|
||||
|
||||
aggregated_data.append(data_point)
|
||||
|
||||
# Generate raw data if requested
|
||||
if include_raw_data:
|
||||
for _ in range(random.randint(0, 5)):
|
||||
raw_data.append({
|
||||
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
|
||||
"person_id": f"person_{random.randint(1, 10)}",
|
||||
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
|
||||
"confidence": random.uniform(0.5, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
current_time += timedelta(seconds=aggregation_interval)
|
||||
|
||||
return {
|
||||
"aggregated_data": aggregated_data,
|
||||
"raw_data": raw_data,
|
||||
"total_records": len(aggregated_data)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting historical data: {e}")
|
||||
raise
|
||||
|
||||
async def get_recent_activities(self, zone_id=None, limit=10):
|
||||
"""Get recently detected activities."""
|
||||
try:
|
||||
import random
|
||||
activities = []
|
||||
|
||||
for i in range(limit):
|
||||
activity = {
|
||||
"activity_id": f"activity_{i}",
|
||||
"person_id": f"person_{random.randint(1, 5)}",
|
||||
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"confidence": random.uniform(0.6, 0.95),
|
||||
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
|
||||
"duration_seconds": random.randint(10, 300)
|
||||
}
|
||||
activities.append(activity)
|
||||
|
||||
return activities
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting recent activities: {e}")
|
||||
raise
|
||||
|
||||
async def is_calibrating(self):
|
||||
"""Check if calibration is in progress."""
|
||||
return False # Mock implementation
|
||||
|
||||
async def start_calibration(self):
|
||||
"""Start calibration process."""
|
||||
import uuid
|
||||
calibration_id = str(uuid.uuid4())
|
||||
self.logger.info(f"Started calibration: {calibration_id}")
|
||||
return calibration_id
|
||||
|
||||
async def run_calibration(self, calibration_id):
|
||||
"""Run calibration process."""
|
||||
self.logger.info(f"Running calibration: {calibration_id}")
|
||||
# Mock calibration process
|
||||
await asyncio.sleep(5)
|
||||
self.logger.info(f"Calibration completed: {calibration_id}")
|
||||
|
||||
async def get_calibration_status(self):
|
||||
"""Get current calibration status."""
|
||||
return {
|
||||
"is_calibrating": False,
|
||||
"calibration_id": None,
|
||||
"progress_percent": 100,
|
||||
"current_step": "completed",
|
||||
"estimated_remaining_minutes": 0,
|
||||
"last_calibration": datetime.now() - timedelta(hours=1)
|
||||
}
|
||||
|
||||
async def get_statistics(self, start_time, end_time):
|
||||
"""Get pose estimation statistics."""
|
||||
try:
|
||||
import random
|
||||
|
||||
# Mock statistics
|
||||
total_detections = random.randint(100, 1000)
|
||||
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
|
||||
|
||||
return {
|
||||
"total_detections": total_detections,
|
||||
"successful_detections": successful_detections,
|
||||
"failed_detections": total_detections - successful_detections,
|
||||
"success_rate": successful_detections / total_detections,
|
||||
"average_confidence": random.uniform(0.75, 0.90),
|
||||
"average_processing_time_ms": random.uniform(50, 200),
|
||||
"unique_persons": random.randint(5, 20),
|
||||
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity_distribution": {
|
||||
"standing": random.uniform(0.3, 0.5),
|
||||
"sitting": random.uniform(0.2, 0.4),
|
||||
"walking": random.uniform(0.1, 0.3),
|
||||
"lying": random.uniform(0.0, 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting statistics: {e}")
|
||||
raise
|
||||
|
||||
async def process_segmentation_data(self, frame_id):
|
||||
"""Process segmentation data in background."""
|
||||
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
|
||||
# Mock background processing
|
||||
await asyncio.sleep(2)
|
||||
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
|
||||
|
||||
# WebSocket streaming methods
|
||||
async def get_current_pose_data(self):
|
||||
"""Get current pose data for streaming."""
|
||||
try:
|
||||
# Generate current pose data
|
||||
result = await self.estimate_poses()
|
||||
|
||||
# Format data by zones for WebSocket streaming
|
||||
zone_data = {}
|
||||
|
||||
# Group persons by zone
|
||||
for person in result["persons"]:
|
||||
zone_id = person.get("zone_id", "zone_1")
|
||||
|
||||
if zone_id not in zone_data:
|
||||
zone_data[zone_id] = {
|
||||
"pose": {
|
||||
"persons": [],
|
||||
"count": 0
|
||||
},
|
||||
"confidence": 0.0,
|
||||
"activity": None,
|
||||
"metadata": {
|
||||
"frame_id": result["frame_id"],
|
||||
"processing_time_ms": result["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
zone_data[zone_id]["pose"]["persons"].append(person)
|
||||
zone_data[zone_id]["pose"]["count"] += 1
|
||||
|
||||
# Update zone confidence (average)
|
||||
current_confidence = zone_data[zone_id]["confidence"]
|
||||
person_confidence = person.get("confidence", 0.0)
|
||||
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
|
||||
|
||||
# Set activity if not already set
|
||||
if not zone_data[zone_id]["activity"] and person.get("activity"):
|
||||
zone_data[zone_id]["activity"] = person["activity"]
|
||||
|
||||
return zone_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting current pose data: {e}")
|
||||
# Return empty zone data on error
|
||||
return {}
|
||||
|
||||
# Health check methods
|
||||
async def health_check(self):
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Service is running normally",
|
||||
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
|
||||
"metrics": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self):
|
||||
"""Check if service is ready."""
|
||||
return self.is_initialized and self.is_running
|
||||
397
v1/src/services/stream_service.py
Normal file
397
v1/src/services/stream_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Real-time streaming service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamService:
|
||||
"""Service for real-time data streaming."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize stream service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# WebSocket connections
|
||||
self.connections: Set[WebSocket] = set()
|
||||
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
|
||||
|
||||
# Stream buffers
|
||||
self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||
self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)
|
||||
|
||||
# Service state
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Streaming statistics
|
||||
self.stats = {
|
||||
"active_connections": 0,
|
||||
"total_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"data_points_streamed": 0,
|
||||
"average_latency_ms": 0.0
|
||||
}
|
||||
|
||||
# Background tasks
|
||||
self.streaming_task = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the stream service."""
|
||||
self.logger.info("Stream service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start the stream service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Stream service started")
|
||||
|
||||
# Start background streaming task
|
||||
if self.settings.enable_real_time_processing:
|
||||
self.streaming_task = asyncio.create_task(self._streaming_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the stream service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background task
|
||||
if self.streaming_task:
|
||||
self.streaming_task.cancel()
|
||||
try:
|
||||
await self.streaming_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close all connections
|
||||
await self._close_all_connections()
|
||||
|
||||
self.logger.info("Stream service stopped")
|
||||
|
||||
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
|
||||
"""Add a new WebSocket connection."""
|
||||
try:
|
||||
await websocket.accept()
|
||||
self.connections.add(websocket)
|
||||
self.connection_metadata[websocket] = metadata or {}
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
self.stats["total_connections"] += 1
|
||||
|
||||
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
|
||||
|
||||
# Send initial data if available
|
||||
await self._send_initial_data(websocket)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error adding WebSocket connection: {e}")
|
||||
raise
|
||||
|
||||
async def remove_connection(self, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
self.connections.remove(websocket)
|
||||
self.connection_metadata.pop(websocket, None)
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
|
||||
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error removing WebSocket connection: {e}")
|
||||
|
||||
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
|
||||
"""Broadcast pose data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Add to buffer
|
||||
self.pose_buffer.append({
|
||||
"type": "pose_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "pose_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
|
||||
"""Broadcast CSI data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Convert numpy array to list for JSON serialization
|
||||
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
|
||||
|
||||
# Add to buffer
|
||||
self.csi_buffer.append({
|
||||
"type": "csi_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "csi_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
async def broadcast_system_status(self, status_data: Dict[str, Any]):
|
||||
"""Broadcast system status to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
await self._broadcast_message({
|
||||
"type": "system_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status_data
|
||||
})
|
||||
|
||||
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
|
||||
"""Send message to a specific connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def _broadcast_message(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients."""
|
||||
if not self.connections:
|
||||
return
|
||||
|
||||
disconnected = set()
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to send message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
disconnected.add(websocket)
|
||||
|
||||
# Remove disconnected clients
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
if message.get("type") in ["pose_update", "csi_update"]:
|
||||
self.stats["data_points_streamed"] += 1
|
||||
|
||||
async def _send_initial_data(self, websocket: WebSocket):
|
||||
"""Send initial data to a new connection."""
|
||||
try:
|
||||
# Send recent pose data
|
||||
if self.pose_buffer:
|
||||
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_poses",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_poses
|
||||
})
|
||||
|
||||
# Send recent CSI data
|
||||
if self.csi_buffer:
|
||||
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_csi",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_csi
|
||||
})
|
||||
|
||||
# Send service status
|
||||
status = await self.get_status()
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "service_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending initial data: {e}")
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""Background streaming loop for periodic updates."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Send periodic heartbeat
|
||||
if self.connections:
|
||||
await self._broadcast_message({
|
||||
"type": "heartbeat",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"active_connections": len(self.connections)
|
||||
})
|
||||
|
||||
# Wait for next iteration
|
||||
await asyncio.sleep(self.settings.websocket_ping_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Streaming loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in streaming loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _close_all_connections(self):
|
||||
"""Close all WebSocket connections."""
|
||||
disconnected = []
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.close()
|
||||
disconnected.append(websocket)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing connection: {e}")
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clear all connections
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"connections": {
|
||||
"active": len(self.connections),
|
||||
"total": self.stats["total_connections"]
|
||||
},
|
||||
"buffers": {
|
||||
"pose_buffer_size": len(self.pose_buffer),
|
||||
"csi_buffer_size": len(self.csi_buffer),
|
||||
"max_buffer_size": self.settings.stream_buffer_size
|
||||
},
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"stream_fps": self.settings.stream_fps,
|
||||
"buffer_size": self.settings.stream_buffer_size,
|
||||
"ping_interval": self.settings.websocket_ping_interval,
|
||||
"timeout": self.settings.websocket_timeout
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
total_messages = self.stats["messages_sent"] + self.stats["messages_failed"]
|
||||
success_rate = self.stats["messages_sent"] / max(1, total_messages)
|
||||
|
||||
return {
|
||||
"stream_service": {
|
||||
"active_connections": self.stats["active_connections"],
|
||||
"total_connections": self.stats["total_connections"],
|
||||
"messages_sent": self.stats["messages_sent"],
|
||||
"messages_failed": self.stats["messages_failed"],
|
||||
"message_success_rate": success_rate,
|
||||
"data_points_streamed": self.stats["data_points_streamed"],
|
||||
"average_latency_ms": self.stats["average_latency_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def get_connection_info(self) -> List[Dict[str, Any]]:
|
||||
"""Get information about active connections."""
|
||||
connections_info = []
|
||||
|
||||
for websocket in self.connections:
|
||||
metadata = self.connection_metadata.get(websocket, {})
|
||||
|
||||
connection_info = {
|
||||
"id": id(websocket),
|
||||
"connected_at": metadata.get("connected_at", "unknown"),
|
||||
"user_agent": metadata.get("user_agent", "unknown"),
|
||||
"ip_address": metadata.get("ip_address", "unknown"),
|
||||
"subscription_types": metadata.get("subscription_types", [])
|
||||
}
|
||||
|
||||
connections_info.append(connection_info)
|
||||
|
||||
return connections_info
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
# Clear buffers
|
||||
self.pose_buffer.clear()
|
||||
self.csi_buffer.clear()
|
||||
|
||||
# Reset statistics
|
||||
self.stats = {
|
||||
"active_connections": len(self.connections),
|
||||
"total_connections": 0,
|
||||
"messages_sent": 0,
|
||||
"messages_failed": 0,
|
||||
"data_points_streamed": 0,
|
||||
"average_latency_ms": 0.0
|
||||
}
|
||||
|
||||
self.last_error = None
|
||||
self.logger.info("Stream service reset")
|
||||
|
||||
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get data from buffers."""
|
||||
if buffer_type == "pose":
|
||||
return list(self.pose_buffer)[-limit:]
|
||||
elif buffer_type == "csi":
|
||||
return list(self.csi_buffer)[-limit:]
|
||||
else:
|
||||
return []
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if stream service is active."""
|
||||
return self.is_running
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Stream service is running normally",
|
||||
"active_connections": len(self.connections),
|
||||
"metrics": {
|
||||
"messages_sent": self.stats["messages_sent"],
|
||||
"messages_failed": self.stats["messages_failed"],
|
||||
"data_points_streamed": self.stats["data_points_streamed"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self) -> bool:
|
||||
"""Check if service is ready."""
|
||||
return self.is_running
|
||||
612
v1/src/tasks/backup.py
Normal file
612
v1/src/tasks/backup.py
Normal file
@@ -0,0 +1,612 @@
|
||||
"""
|
||||
Backup tasks for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import gzip
|
||||
import json
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.database.connection import get_database_manager
|
||||
from src.database.models import Device, Session, CSIData, PoseDetection, SystemMetric, AuditLog
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BackupTask:
|
||||
"""Base class for backup tasks."""
|
||||
|
||||
def __init__(self, name: str, settings: Settings):
|
||||
self.name = name
|
||||
self.settings = settings
|
||||
self.enabled = True
|
||||
self.last_run = None
|
||||
self.run_count = 0
|
||||
self.error_count = 0
|
||||
self.backup_dir = Path(settings.backup_directory)
|
||||
self.backup_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute the backup task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def run(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Run the backup task with error handling."""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting backup task: {self.name}")
|
||||
|
||||
result = await self.execute_backup(session)
|
||||
|
||||
self.last_run = start_time
|
||||
self.run_count += 1
|
||||
|
||||
logger.info(
|
||||
f"Backup task {self.name} completed: "
|
||||
f"backed up {result.get('backup_size_mb', 0):.2f} MB"
|
||||
)
|
||||
|
||||
return {
|
||||
"task": self.name,
|
||||
"status": "success",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
|
||||
**result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.error_count += 1
|
||||
logger.error(f"Backup task {self.name} failed: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"task": self.name,
|
||||
"status": "error",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
|
||||
"error": str(e),
|
||||
"backup_size_mb": 0
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get task statistics."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"enabled": self.enabled,
|
||||
"last_run": self.last_run.isoformat() if self.last_run else None,
|
||||
"run_count": self.run_count,
|
||||
"error_count": self.error_count,
|
||||
"backup_directory": str(self.backup_dir),
|
||||
}
|
||||
|
||||
def _get_backup_filename(self, prefix: str, extension: str = ".gz") -> str:
|
||||
"""Generate backup filename with timestamp."""
|
||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{prefix}_{timestamp}{extension}"
|
||||
|
||||
def _get_file_size_mb(self, file_path: Path) -> float:
|
||||
"""Get file size in MB."""
|
||||
if file_path.exists():
|
||||
return file_path.stat().st_size / (1024 * 1024)
|
||||
return 0.0
|
||||
|
||||
def _cleanup_old_backups(self, pattern: str, retention_days: int):
|
||||
"""Clean up old backup files."""
|
||||
if retention_days <= 0:
|
||||
return
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=retention_days)
|
||||
|
||||
for backup_file in self.backup_dir.glob(pattern):
|
||||
if backup_file.stat().st_mtime < cutoff_date.timestamp():
|
||||
try:
|
||||
backup_file.unlink()
|
||||
logger.debug(f"Deleted old backup: {backup_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old backup {backup_file}: {e}")
|
||||
|
||||
|
||||
class DatabaseBackup(BackupTask):
|
||||
"""Full database backup using pg_dump."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("database_backup", settings)
|
||||
self.retention_days = settings.database_backup_retention_days
|
||||
|
||||
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute database backup."""
|
||||
backup_filename = self._get_backup_filename("database_full", ".sql.gz")
|
||||
backup_path = self.backup_dir / backup_filename
|
||||
|
||||
# Build pg_dump command
|
||||
pg_dump_cmd = [
|
||||
"pg_dump",
|
||||
"--verbose",
|
||||
"--no-password",
|
||||
"--format=custom",
|
||||
"--compress=9",
|
||||
"--file", str(backup_path),
|
||||
]
|
||||
|
||||
# Add connection parameters
|
||||
if self.settings.database_url:
|
||||
pg_dump_cmd.append(self.settings.database_url)
|
||||
else:
|
||||
pg_dump_cmd.extend([
|
||||
"--host", self.settings.db_host,
|
||||
"--port", str(self.settings.db_port),
|
||||
"--username", self.settings.db_user,
|
||||
"--dbname", self.settings.db_name,
|
||||
])
|
||||
|
||||
# Set environment variables
|
||||
env = os.environ.copy()
|
||||
if self.settings.db_password:
|
||||
env["PGPASSWORD"] = self.settings.db_password
|
||||
|
||||
# Execute pg_dump
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*pg_dump_cmd,
|
||||
env=env,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = stderr.decode() if stderr else "Unknown pg_dump error"
|
||||
raise Exception(f"pg_dump failed: {error_msg}")
|
||||
|
||||
backup_size_mb = self._get_file_size_mb(backup_path)
|
||||
|
||||
# Clean up old backups
|
||||
self._cleanup_old_backups("database_full_*.sql.gz", self.retention_days)
|
||||
|
||||
return {
|
||||
"backup_file": backup_filename,
|
||||
"backup_path": str(backup_path),
|
||||
"backup_size_mb": backup_size_mb,
|
||||
"retention_days": self.retention_days,
|
||||
}
|
||||
|
||||
|
||||
class ConfigurationBackup(BackupTask):
|
||||
"""Backup configuration files and settings."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("configuration_backup", settings)
|
||||
self.retention_days = settings.config_backup_retention_days
|
||||
self.config_files = [
|
||||
"src/config/settings.py",
|
||||
".env",
|
||||
"pyproject.toml",
|
||||
"docker-compose.yml",
|
||||
"Dockerfile",
|
||||
]
|
||||
|
||||
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute configuration backup."""
|
||||
backup_filename = self._get_backup_filename("configuration", ".tar.gz")
|
||||
backup_path = self.backup_dir / backup_filename
|
||||
|
||||
# Create temporary directory for config files
|
||||
temp_dir = self.backup_dir / "temp_config"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
try:
|
||||
copied_files = []
|
||||
|
||||
# Copy configuration files
|
||||
for config_file in self.config_files:
|
||||
source_path = Path(config_file)
|
||||
if source_path.exists():
|
||||
dest_path = temp_dir / source_path.name
|
||||
shutil.copy2(source_path, dest_path)
|
||||
copied_files.append(config_file)
|
||||
|
||||
# Create settings dump
|
||||
settings_dump = {
|
||||
"backup_timestamp": datetime.utcnow().isoformat(),
|
||||
"environment": self.settings.environment,
|
||||
"debug": self.settings.debug,
|
||||
"version": self.settings.version,
|
||||
"database_settings": {
|
||||
"db_host": self.settings.db_host,
|
||||
"db_port": self.settings.db_port,
|
||||
"db_name": self.settings.db_name,
|
||||
"db_pool_size": self.settings.db_pool_size,
|
||||
},
|
||||
"redis_settings": {
|
||||
"redis_enabled": self.settings.redis_enabled,
|
||||
"redis_host": self.settings.redis_host,
|
||||
"redis_port": self.settings.redis_port,
|
||||
"redis_db": self.settings.redis_db,
|
||||
},
|
||||
"monitoring_settings": {
|
||||
"monitoring_interval_seconds": self.settings.monitoring_interval_seconds,
|
||||
"cleanup_interval_seconds": self.settings.cleanup_interval_seconds,
|
||||
},
|
||||
}
|
||||
|
||||
settings_file = temp_dir / "settings_dump.json"
|
||||
with open(settings_file, 'w') as f:
|
||||
json.dump(settings_dump, f, indent=2)
|
||||
|
||||
# Create tar.gz archive
|
||||
tar_cmd = [
|
||||
"tar", "-czf", str(backup_path),
|
||||
"-C", str(temp_dir),
|
||||
"."
|
||||
]
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*tar_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = stderr.decode() if stderr else "Unknown tar error"
|
||||
raise Exception(f"tar failed: {error_msg}")
|
||||
|
||||
backup_size_mb = self._get_file_size_mb(backup_path)
|
||||
|
||||
# Clean up old backups
|
||||
self._cleanup_old_backups("configuration_*.tar.gz", self.retention_days)
|
||||
|
||||
return {
|
||||
"backup_file": backup_filename,
|
||||
"backup_path": str(backup_path),
|
||||
"backup_size_mb": backup_size_mb,
|
||||
"copied_files": copied_files,
|
||||
"retention_days": self.retention_days,
|
||||
}
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
class DataExportBackup(BackupTask):
|
||||
"""Export specific data tables to JSON format."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("data_export_backup", settings)
|
||||
self.retention_days = settings.data_export_retention_days
|
||||
self.export_batch_size = 1000
|
||||
|
||||
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute data export backup."""
|
||||
backup_filename = self._get_backup_filename("data_export", ".json.gz")
|
||||
backup_path = self.backup_dir / backup_filename
|
||||
|
||||
export_data = {
|
||||
"backup_timestamp": datetime.utcnow().isoformat(),
|
||||
"export_version": "1.0",
|
||||
"tables": {}
|
||||
}
|
||||
|
||||
# Export devices
|
||||
devices_data = await self._export_table_data(session, Device, "devices")
|
||||
export_data["tables"]["devices"] = devices_data
|
||||
|
||||
# Export sessions
|
||||
sessions_data = await self._export_table_data(session, Session, "sessions")
|
||||
export_data["tables"]["sessions"] = sessions_data
|
||||
|
||||
# Export recent CSI data (last 7 days)
|
||||
recent_date = datetime.utcnow() - timedelta(days=7)
|
||||
csi_query = select(CSIData).where(CSIData.created_at >= recent_date)
|
||||
csi_data = await self._export_query_data(session, csi_query, "csi_data")
|
||||
export_data["tables"]["csi_data_recent"] = csi_data
|
||||
|
||||
# Export recent pose detections (last 7 days)
|
||||
pose_query = select(PoseDetection).where(PoseDetection.created_at >= recent_date)
|
||||
pose_data = await self._export_query_data(session, pose_query, "pose_detections")
|
||||
export_data["tables"]["pose_detections_recent"] = pose_data
|
||||
|
||||
# Write compressed JSON
|
||||
with gzip.open(backup_path, 'wt', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, indent=2, default=str)
|
||||
|
||||
backup_size_mb = self._get_file_size_mb(backup_path)
|
||||
|
||||
# Clean up old backups
|
||||
self._cleanup_old_backups("data_export_*.json.gz", self.retention_days)
|
||||
|
||||
total_records = sum(
|
||||
table_data["record_count"]
|
||||
for table_data in export_data["tables"].values()
|
||||
)
|
||||
|
||||
return {
|
||||
"backup_file": backup_filename,
|
||||
"backup_path": str(backup_path),
|
||||
"backup_size_mb": backup_size_mb,
|
||||
"total_records": total_records,
|
||||
"tables_exported": list(export_data["tables"].keys()),
|
||||
"retention_days": self.retention_days,
|
||||
}
|
||||
|
||||
async def _export_table_data(self, session: AsyncSession, model_class, table_name: str) -> Dict[str, Any]:
|
||||
"""Export all data from a table."""
|
||||
query = select(model_class)
|
||||
return await self._export_query_data(session, query, table_name)
|
||||
|
||||
async def _export_query_data(self, session: AsyncSession, query, table_name: str) -> Dict[str, Any]:
|
||||
"""Export data from a query."""
|
||||
result = await session.execute(query)
|
||||
records = result.scalars().all()
|
||||
|
||||
exported_records = []
|
||||
for record in records:
|
||||
if hasattr(record, 'to_dict'):
|
||||
exported_records.append(record.to_dict())
|
||||
else:
|
||||
# Fallback for records without to_dict method
|
||||
record_dict = {}
|
||||
for column in record.__table__.columns:
|
||||
value = getattr(record, column.name)
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
record_dict[column.name] = value
|
||||
exported_records.append(record_dict)
|
||||
|
||||
return {
|
||||
"table_name": table_name,
|
||||
"record_count": len(exported_records),
|
||||
"export_timestamp": datetime.utcnow().isoformat(),
|
||||
"records": exported_records,
|
||||
}
|
||||
|
||||
|
||||
class LogsBackup(BackupTask):
|
||||
"""Backup application logs."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("logs_backup", settings)
|
||||
self.retention_days = settings.logs_backup_retention_days
|
||||
self.logs_directory = Path(settings.log_directory)
|
||||
|
||||
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute logs backup."""
|
||||
if not self.logs_directory.exists():
|
||||
return {
|
||||
"backup_file": None,
|
||||
"backup_path": None,
|
||||
"backup_size_mb": 0,
|
||||
"message": "Logs directory does not exist",
|
||||
}
|
||||
|
||||
backup_filename = self._get_backup_filename("logs", ".tar.gz")
|
||||
backup_path = self.backup_dir / backup_filename
|
||||
|
||||
# Create tar.gz archive of logs
|
||||
tar_cmd = [
|
||||
"tar", "-czf", str(backup_path),
|
||||
"-C", str(self.logs_directory.parent),
|
||||
self.logs_directory.name
|
||||
]
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*tar_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = stderr.decode() if stderr else "Unknown tar error"
|
||||
raise Exception(f"tar failed: {error_msg}")
|
||||
|
||||
backup_size_mb = self._get_file_size_mb(backup_path)
|
||||
|
||||
# Count log files
|
||||
log_files = list(self.logs_directory.glob("*.log*"))
|
||||
|
||||
# Clean up old backups
|
||||
self._cleanup_old_backups("logs_*.tar.gz", self.retention_days)
|
||||
|
||||
return {
|
||||
"backup_file": backup_filename,
|
||||
"backup_path": str(backup_path),
|
||||
"backup_size_mb": backup_size_mb,
|
||||
"log_files_count": len(log_files),
|
||||
"retention_days": self.retention_days,
|
||||
}
|
||||
|
||||
|
||||
class BackupManager:
|
||||
"""Manager for all backup tasks."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.db_manager = get_database_manager(settings)
|
||||
self.tasks = self._initialize_tasks()
|
||||
self.running = False
|
||||
self.last_run = None
|
||||
self.run_count = 0
|
||||
self.total_backup_size = 0
|
||||
|
||||
def _initialize_tasks(self) -> List[BackupTask]:
|
||||
"""Initialize all backup tasks."""
|
||||
tasks = [
|
||||
DatabaseBackup(self.settings),
|
||||
ConfigurationBackup(self.settings),
|
||||
DataExportBackup(self.settings),
|
||||
LogsBackup(self.settings),
|
||||
]
|
||||
|
||||
# Filter enabled tasks
|
||||
enabled_tasks = [task for task in tasks if task.enabled]
|
||||
|
||||
logger.info(f"Initialized {len(enabled_tasks)} backup tasks")
|
||||
return enabled_tasks
|
||||
|
||||
async def run_all_tasks(self) -> Dict[str, Any]:
|
||||
"""Run all backup tasks."""
|
||||
if self.running:
|
||||
return {"status": "already_running", "message": "Backup already in progress"}
|
||||
|
||||
self.running = True
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.info("Starting backup tasks")
|
||||
|
||||
results = []
|
||||
total_backup_size = 0
|
||||
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
for task in self.tasks:
|
||||
if not task.enabled:
|
||||
continue
|
||||
|
||||
result = await task.run(session)
|
||||
results.append(result)
|
||||
total_backup_size += result.get("backup_size_mb", 0)
|
||||
|
||||
self.last_run = start_time
|
||||
self.run_count += 1
|
||||
self.total_backup_size += total_backup_size
|
||||
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Backup tasks completed: created {total_backup_size:.2f} MB "
|
||||
f"in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_seconds": duration,
|
||||
"total_backup_size_mb": total_backup_size,
|
||||
"task_results": results,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Backup tasks failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_seconds": (datetime.utcnow() - start_time).total_seconds(),
|
||||
"error": str(e),
|
||||
"total_backup_size_mb": 0,
|
||||
}
|
||||
|
||||
finally:
|
||||
self.running = False
|
||||
|
||||
async def run_task(self, task_name: str) -> Dict[str, Any]:
|
||||
"""Run a specific backup task."""
|
||||
task = next((t for t in self.tasks if t.name == task_name), None)
|
||||
|
||||
if not task:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Task '{task_name}' not found",
|
||||
"available_tasks": [t.name for t in self.tasks]
|
||||
}
|
||||
|
||||
if not task.enabled:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Task '{task_name}' is disabled"
|
||||
}
|
||||
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
return await task.run(session)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get backup manager statistics."""
|
||||
return {
|
||||
"manager": {
|
||||
"running": self.running,
|
||||
"last_run": self.last_run.isoformat() if self.last_run else None,
|
||||
"run_count": self.run_count,
|
||||
"total_backup_size_mb": self.total_backup_size,
|
||||
},
|
||||
"tasks": [task.get_stats() for task in self.tasks],
|
||||
}
|
||||
|
||||
def list_backups(self) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""List all backup files."""
|
||||
backup_files = {}
|
||||
|
||||
for task in self.tasks:
|
||||
task_backups = []
|
||||
|
||||
# Define patterns for each task type
|
||||
patterns = {
|
||||
"database_backup": "database_full_*.sql.gz",
|
||||
"configuration_backup": "configuration_*.tar.gz",
|
||||
"data_export_backup": "data_export_*.json.gz",
|
||||
"logs_backup": "logs_*.tar.gz",
|
||||
}
|
||||
|
||||
pattern = patterns.get(task.name, f"{task.name}_*")
|
||||
|
||||
for backup_file in task.backup_dir.glob(pattern):
|
||||
stat = backup_file.stat()
|
||||
task_backups.append({
|
||||
"filename": backup_file.name,
|
||||
"path": str(backup_file),
|
||||
"size_mb": stat.st_size / (1024 * 1024),
|
||||
"created_at": datetime.fromtimestamp(stat.st_mtime).isoformat(),
|
||||
})
|
||||
|
||||
# Sort by creation time (newest first)
|
||||
task_backups.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
backup_files[task.name] = task_backups
|
||||
|
||||
return backup_files
|
||||
|
||||
|
||||
# Global backup manager instance
|
||||
_backup_manager: Optional[BackupManager] = None
|
||||
|
||||
|
||||
def get_backup_manager(settings: Settings) -> BackupManager:
|
||||
"""Get backup manager instance."""
|
||||
global _backup_manager
|
||||
if _backup_manager is None:
|
||||
_backup_manager = BackupManager(settings)
|
||||
return _backup_manager
|
||||
|
||||
|
||||
async def run_periodic_backup(settings: Settings):
|
||||
"""Run periodic backup tasks."""
|
||||
backup_manager = get_backup_manager(settings)
|
||||
|
||||
while True:
|
||||
try:
|
||||
await backup_manager.run_all_tasks()
|
||||
|
||||
# Wait for next backup interval
|
||||
await asyncio.sleep(settings.backup_interval_seconds)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Periodic backup cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Periodic backup error: {e}", exc_info=True)
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(300) # 5 minutes
|
||||
598
v1/src/tasks/cleanup.py
Normal file
598
v1/src/tasks/cleanup.py
Normal file
@@ -0,0 +1,598 @@
|
||||
"""
|
||||
Periodic cleanup tasks for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import delete, select, func, and_, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.database.connection import get_database_manager
|
||||
from src.database.models import (
|
||||
CSIData, PoseDetection, SystemMetric, AuditLog, Session, Device
|
||||
)
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CleanupTask:
|
||||
"""Base class for cleanup tasks."""
|
||||
|
||||
def __init__(self, name: str, settings: Settings):
|
||||
self.name = name
|
||||
self.settings = settings
|
||||
self.enabled = True
|
||||
self.last_run = None
|
||||
self.run_count = 0
|
||||
self.error_count = 0
|
||||
self.total_cleaned = 0
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute the cleanup task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def run(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Run the cleanup task with error handling."""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting cleanup task: {self.name}")
|
||||
|
||||
result = await self.execute(session)
|
||||
|
||||
self.last_run = start_time
|
||||
self.run_count += 1
|
||||
|
||||
if result.get("cleaned_count", 0) > 0:
|
||||
self.total_cleaned += result["cleaned_count"]
|
||||
logger.info(
|
||||
f"Cleanup task {self.name} completed: "
|
||||
f"cleaned {result['cleaned_count']} items"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Cleanup task {self.name} completed: no items to clean")
|
||||
|
||||
return {
|
||||
"task": self.name,
|
||||
"status": "success",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
|
||||
**result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.error_count += 1
|
||||
logger.error(f"Cleanup task {self.name} failed: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"task": self.name,
|
||||
"status": "error",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
|
||||
"error": str(e),
|
||||
"cleaned_count": 0
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get task statistics."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"enabled": self.enabled,
|
||||
"last_run": self.last_run.isoformat() if self.last_run else None,
|
||||
"run_count": self.run_count,
|
||||
"error_count": self.error_count,
|
||||
"total_cleaned": self.total_cleaned,
|
||||
}
|
||||
|
||||
|
||||
class OldCSIDataCleanup(CleanupTask):
|
||||
"""Cleanup old CSI data records."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("old_csi_data_cleanup", settings)
|
||||
self.retention_days = settings.csi_data_retention_days
|
||||
self.batch_size = settings.cleanup_batch_size
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute CSI data cleanup."""
|
||||
if self.retention_days <= 0:
|
||||
return {"cleaned_count": 0, "message": "CSI data retention disabled"}
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
|
||||
|
||||
# Count records to be deleted
|
||||
count_query = select(func.count(CSIData.id)).where(
|
||||
CSIData.created_at < cutoff_date
|
||||
)
|
||||
total_count = await session.scalar(count_query)
|
||||
|
||||
if total_count == 0:
|
||||
return {"cleaned_count": 0, "message": "No old CSI data to clean"}
|
||||
|
||||
# Delete in batches
|
||||
cleaned_count = 0
|
||||
while cleaned_count < total_count:
|
||||
# Get batch of IDs to delete
|
||||
id_query = select(CSIData.id).where(
|
||||
CSIData.created_at < cutoff_date
|
||||
).limit(self.batch_size)
|
||||
|
||||
result = await session.execute(id_query)
|
||||
ids_to_delete = [row[0] for row in result.fetchall()]
|
||||
|
||||
if not ids_to_delete:
|
||||
break
|
||||
|
||||
# Delete batch
|
||||
delete_query = delete(CSIData).where(CSIData.id.in_(ids_to_delete))
|
||||
await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
batch_size = len(ids_to_delete)
|
||||
cleaned_count += batch_size
|
||||
|
||||
logger.debug(f"Deleted {batch_size} CSI data records (total: {cleaned_count})")
|
||||
|
||||
# Small delay to avoid overwhelming the database
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"cleaned_count": cleaned_count,
|
||||
"retention_days": self.retention_days,
|
||||
"cutoff_date": cutoff_date.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class OldPoseDetectionCleanup(CleanupTask):
|
||||
"""Cleanup old pose detection records."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("old_pose_detection_cleanup", settings)
|
||||
self.retention_days = settings.pose_detection_retention_days
|
||||
self.batch_size = settings.cleanup_batch_size
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute pose detection cleanup."""
|
||||
if self.retention_days <= 0:
|
||||
return {"cleaned_count": 0, "message": "Pose detection retention disabled"}
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
|
||||
|
||||
# Count records to be deleted
|
||||
count_query = select(func.count(PoseDetection.id)).where(
|
||||
PoseDetection.created_at < cutoff_date
|
||||
)
|
||||
total_count = await session.scalar(count_query)
|
||||
|
||||
if total_count == 0:
|
||||
return {"cleaned_count": 0, "message": "No old pose detections to clean"}
|
||||
|
||||
# Delete in batches
|
||||
cleaned_count = 0
|
||||
while cleaned_count < total_count:
|
||||
# Get batch of IDs to delete
|
||||
id_query = select(PoseDetection.id).where(
|
||||
PoseDetection.created_at < cutoff_date
|
||||
).limit(self.batch_size)
|
||||
|
||||
result = await session.execute(id_query)
|
||||
ids_to_delete = [row[0] for row in result.fetchall()]
|
||||
|
||||
if not ids_to_delete:
|
||||
break
|
||||
|
||||
# Delete batch
|
||||
delete_query = delete(PoseDetection).where(PoseDetection.id.in_(ids_to_delete))
|
||||
await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
batch_size = len(ids_to_delete)
|
||||
cleaned_count += batch_size
|
||||
|
||||
logger.debug(f"Deleted {batch_size} pose detection records (total: {cleaned_count})")
|
||||
|
||||
# Small delay to avoid overwhelming the database
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"cleaned_count": cleaned_count,
|
||||
"retention_days": self.retention_days,
|
||||
"cutoff_date": cutoff_date.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class OldMetricsCleanup(CleanupTask):
|
||||
"""Cleanup old system metrics."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("old_metrics_cleanup", settings)
|
||||
self.retention_days = settings.metrics_retention_days
|
||||
self.batch_size = settings.cleanup_batch_size
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute metrics cleanup."""
|
||||
if self.retention_days <= 0:
|
||||
return {"cleaned_count": 0, "message": "Metrics retention disabled"}
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
|
||||
|
||||
# Count records to be deleted
|
||||
count_query = select(func.count(SystemMetric.id)).where(
|
||||
SystemMetric.created_at < cutoff_date
|
||||
)
|
||||
total_count = await session.scalar(count_query)
|
||||
|
||||
if total_count == 0:
|
||||
return {"cleaned_count": 0, "message": "No old metrics to clean"}
|
||||
|
||||
# Delete in batches
|
||||
cleaned_count = 0
|
||||
while cleaned_count < total_count:
|
||||
# Get batch of IDs to delete
|
||||
id_query = select(SystemMetric.id).where(
|
||||
SystemMetric.created_at < cutoff_date
|
||||
).limit(self.batch_size)
|
||||
|
||||
result = await session.execute(id_query)
|
||||
ids_to_delete = [row[0] for row in result.fetchall()]
|
||||
|
||||
if not ids_to_delete:
|
||||
break
|
||||
|
||||
# Delete batch
|
||||
delete_query = delete(SystemMetric).where(SystemMetric.id.in_(ids_to_delete))
|
||||
await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
batch_size = len(ids_to_delete)
|
||||
cleaned_count += batch_size
|
||||
|
||||
logger.debug(f"Deleted {batch_size} metric records (total: {cleaned_count})")
|
||||
|
||||
# Small delay to avoid overwhelming the database
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"cleaned_count": cleaned_count,
|
||||
"retention_days": self.retention_days,
|
||||
"cutoff_date": cutoff_date.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class OldAuditLogCleanup(CleanupTask):
|
||||
"""Cleanup old audit logs."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("old_audit_log_cleanup", settings)
|
||||
self.retention_days = settings.audit_log_retention_days
|
||||
self.batch_size = settings.cleanup_batch_size
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute audit log cleanup."""
|
||||
if self.retention_days <= 0:
|
||||
return {"cleaned_count": 0, "message": "Audit log retention disabled"}
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
|
||||
|
||||
# Count records to be deleted
|
||||
count_query = select(func.count(AuditLog.id)).where(
|
||||
AuditLog.created_at < cutoff_date
|
||||
)
|
||||
total_count = await session.scalar(count_query)
|
||||
|
||||
if total_count == 0:
|
||||
return {"cleaned_count": 0, "message": "No old audit logs to clean"}
|
||||
|
||||
# Delete in batches
|
||||
cleaned_count = 0
|
||||
while cleaned_count < total_count:
|
||||
# Get batch of IDs to delete
|
||||
id_query = select(AuditLog.id).where(
|
||||
AuditLog.created_at < cutoff_date
|
||||
).limit(self.batch_size)
|
||||
|
||||
result = await session.execute(id_query)
|
||||
ids_to_delete = [row[0] for row in result.fetchall()]
|
||||
|
||||
if not ids_to_delete:
|
||||
break
|
||||
|
||||
# Delete batch
|
||||
delete_query = delete(AuditLog).where(AuditLog.id.in_(ids_to_delete))
|
||||
await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
batch_size = len(ids_to_delete)
|
||||
cleaned_count += batch_size
|
||||
|
||||
logger.debug(f"Deleted {batch_size} audit log records (total: {cleaned_count})")
|
||||
|
||||
# Small delay to avoid overwhelming the database
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
return {
|
||||
"cleaned_count": cleaned_count,
|
||||
"retention_days": self.retention_days,
|
||||
"cutoff_date": cutoff_date.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class OrphanedSessionCleanup(CleanupTask):
|
||||
"""Cleanup orphaned sessions (sessions without associated data)."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("orphaned_session_cleanup", settings)
|
||||
self.orphan_threshold_days = settings.orphaned_session_threshold_days
|
||||
self.batch_size = settings.cleanup_batch_size
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute orphaned session cleanup."""
|
||||
if self.orphan_threshold_days <= 0:
|
||||
return {"cleaned_count": 0, "message": "Orphaned session cleanup disabled"}
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.orphan_threshold_days)
|
||||
|
||||
# Find sessions that are old and have no associated CSI data or pose detections
|
||||
orphaned_sessions_query = select(Session.id).where(
|
||||
and_(
|
||||
Session.created_at < cutoff_date,
|
||||
Session.status.in_(["completed", "failed", "cancelled"]),
|
||||
~Session.id.in_(select(CSIData.session_id).where(CSIData.session_id.isnot(None))),
|
||||
~Session.id.in_(select(PoseDetection.session_id))
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(orphaned_sessions_query)
|
||||
orphaned_ids = [row[0] for row in result.fetchall()]
|
||||
|
||||
if not orphaned_ids:
|
||||
return {"cleaned_count": 0, "message": "No orphaned sessions to clean"}
|
||||
|
||||
# Delete orphaned sessions
|
||||
delete_query = delete(Session).where(Session.id.in_(orphaned_ids))
|
||||
await session.execute(delete_query)
|
||||
await session.commit()
|
||||
|
||||
cleaned_count = len(orphaned_ids)
|
||||
|
||||
return {
|
||||
"cleaned_count": cleaned_count,
|
||||
"orphan_threshold_days": self.orphan_threshold_days,
|
||||
"cutoff_date": cutoff_date.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class InvalidDataCleanup(CleanupTask):
|
||||
"""Cleanup invalid or corrupted data records."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("invalid_data_cleanup", settings)
|
||||
self.batch_size = settings.cleanup_batch_size
|
||||
|
||||
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Execute invalid data cleanup."""
|
||||
total_cleaned = 0
|
||||
|
||||
# Clean invalid CSI data
|
||||
invalid_csi_query = select(CSIData.id).where(
|
||||
or_(
|
||||
CSIData.is_valid == False,
|
||||
CSIData.amplitude == None,
|
||||
CSIData.phase == None,
|
||||
CSIData.frequency <= 0,
|
||||
CSIData.bandwidth <= 0,
|
||||
CSIData.num_subcarriers <= 0
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(invalid_csi_query)
|
||||
invalid_csi_ids = [row[0] for row in result.fetchall()]
|
||||
|
||||
if invalid_csi_ids:
|
||||
delete_query = delete(CSIData).where(CSIData.id.in_(invalid_csi_ids))
|
||||
await session.execute(delete_query)
|
||||
total_cleaned += len(invalid_csi_ids)
|
||||
logger.debug(f"Deleted {len(invalid_csi_ids)} invalid CSI data records")
|
||||
|
||||
# Clean invalid pose detections
|
||||
invalid_pose_query = select(PoseDetection.id).where(
|
||||
or_(
|
||||
PoseDetection.is_valid == False,
|
||||
PoseDetection.person_count < 0,
|
||||
and_(
|
||||
PoseDetection.detection_confidence.isnot(None),
|
||||
or_(
|
||||
PoseDetection.detection_confidence < 0,
|
||||
PoseDetection.detection_confidence > 1
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(invalid_pose_query)
|
||||
invalid_pose_ids = [row[0] for row in result.fetchall()]
|
||||
|
||||
if invalid_pose_ids:
|
||||
delete_query = delete(PoseDetection).where(PoseDetection.id.in_(invalid_pose_ids))
|
||||
await session.execute(delete_query)
|
||||
total_cleaned += len(invalid_pose_ids)
|
||||
logger.debug(f"Deleted {len(invalid_pose_ids)} invalid pose detection records")
|
||||
|
||||
await session.commit()
|
||||
|
||||
return {
|
||||
"cleaned_count": total_cleaned,
|
||||
"invalid_csi_count": len(invalid_csi_ids) if invalid_csi_ids else 0,
|
||||
"invalid_pose_count": len(invalid_pose_ids) if invalid_pose_ids else 0,
|
||||
}
|
||||
|
||||
|
||||
class CleanupManager:
|
||||
"""Manager for all cleanup tasks."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.db_manager = get_database_manager(settings)
|
||||
self.tasks = self._initialize_tasks()
|
||||
self.running = False
|
||||
self.last_run = None
|
||||
self.run_count = 0
|
||||
self.total_cleaned = 0
|
||||
|
||||
def _initialize_tasks(self) -> List[CleanupTask]:
|
||||
"""Initialize all cleanup tasks."""
|
||||
tasks = [
|
||||
OldCSIDataCleanup(self.settings),
|
||||
OldPoseDetectionCleanup(self.settings),
|
||||
OldMetricsCleanup(self.settings),
|
||||
OldAuditLogCleanup(self.settings),
|
||||
OrphanedSessionCleanup(self.settings),
|
||||
InvalidDataCleanup(self.settings),
|
||||
]
|
||||
|
||||
# Filter enabled tasks
|
||||
enabled_tasks = [task for task in tasks if task.enabled]
|
||||
|
||||
logger.info(f"Initialized {len(enabled_tasks)} cleanup tasks")
|
||||
return enabled_tasks
|
||||
|
||||
async def run_all_tasks(self) -> Dict[str, Any]:
|
||||
"""Run all cleanup tasks."""
|
||||
if self.running:
|
||||
return {"status": "already_running", "message": "Cleanup already in progress"}
|
||||
|
||||
self.running = True
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.info("Starting cleanup tasks")
|
||||
|
||||
results = []
|
||||
total_cleaned = 0
|
||||
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
for task in self.tasks:
|
||||
if not task.enabled:
|
||||
continue
|
||||
|
||||
result = await task.run(session)
|
||||
results.append(result)
|
||||
total_cleaned += result.get("cleaned_count", 0)
|
||||
|
||||
self.last_run = start_time
|
||||
self.run_count += 1
|
||||
self.total_cleaned += total_cleaned
|
||||
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Cleanup tasks completed: cleaned {total_cleaned} items "
|
||||
f"in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_seconds": duration,
|
||||
"total_cleaned": total_cleaned,
|
||||
"task_results": results,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup tasks failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_seconds": (datetime.utcnow() - start_time).total_seconds(),
|
||||
"error": str(e),
|
||||
"total_cleaned": 0,
|
||||
}
|
||||
|
||||
finally:
|
||||
self.running = False
|
||||
|
||||
async def run_task(self, task_name: str) -> Dict[str, Any]:
|
||||
"""Run a specific cleanup task."""
|
||||
task = next((t for t in self.tasks if t.name == task_name), None)
|
||||
|
||||
if not task:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Task '{task_name}' not found",
|
||||
"available_tasks": [t.name for t in self.tasks]
|
||||
}
|
||||
|
||||
if not task.enabled:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Task '{task_name}' is disabled"
|
||||
}
|
||||
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
return await task.run(session)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get cleanup manager statistics."""
|
||||
return {
|
||||
"manager": {
|
||||
"running": self.running,
|
||||
"last_run": self.last_run.isoformat() if self.last_run else None,
|
||||
"run_count": self.run_count,
|
||||
"total_cleaned": self.total_cleaned,
|
||||
},
|
||||
"tasks": [task.get_stats() for task in self.tasks],
|
||||
}
|
||||
|
||||
def enable_task(self, task_name: str) -> bool:
|
||||
"""Enable a specific task."""
|
||||
task = next((t for t in self.tasks if t.name == task_name), None)
|
||||
if task:
|
||||
task.enabled = True
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_task(self, task_name: str) -> bool:
|
||||
"""Disable a specific task."""
|
||||
task = next((t for t in self.tasks if t.name == task_name), None)
|
||||
if task:
|
||||
task.enabled = False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Global cleanup manager instance
|
||||
_cleanup_manager: Optional[CleanupManager] = None
|
||||
|
||||
|
||||
def get_cleanup_manager(settings: Settings) -> CleanupManager:
|
||||
"""Get cleanup manager instance."""
|
||||
global _cleanup_manager
|
||||
if _cleanup_manager is None:
|
||||
_cleanup_manager = CleanupManager(settings)
|
||||
return _cleanup_manager
|
||||
|
||||
|
||||
async def run_periodic_cleanup(settings: Settings):
|
||||
"""Run periodic cleanup tasks."""
|
||||
cleanup_manager = get_cleanup_manager(settings)
|
||||
|
||||
while True:
|
||||
try:
|
||||
await cleanup_manager.run_all_tasks()
|
||||
|
||||
# Wait for next cleanup interval
|
||||
await asyncio.sleep(settings.cleanup_interval_seconds)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Periodic cleanup cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Periodic cleanup error: {e}", exc_info=True)
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(60)
|
||||
773
v1/src/tasks/monitoring.py
Normal file
773
v1/src/tasks/monitoring.py
Normal file
@@ -0,0 +1,773 @@
|
||||
"""
|
||||
Monitoring tasks for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select, func, and_, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.database.connection import get_database_manager
|
||||
from src.database.models import SystemMetric, Device, Session, CSIData, PoseDetection
|
||||
from src.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MonitoringTask:
|
||||
"""Base class for monitoring tasks."""
|
||||
|
||||
def __init__(self, name: str, settings: Settings):
|
||||
self.name = name
|
||||
self.settings = settings
|
||||
self.enabled = True
|
||||
self.last_run = None
|
||||
self.run_count = 0
|
||||
self.error_count = 0
|
||||
self.interval_seconds = 60 # Default interval
|
||||
|
||||
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
|
||||
"""Collect metrics for this task."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def run(self, session: AsyncSession) -> Dict[str, Any]:
|
||||
"""Run the monitoring task with error handling."""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.debug(f"Starting monitoring task: {self.name}")
|
||||
|
||||
metrics = await self.collect_metrics(session)
|
||||
|
||||
# Store metrics in database
|
||||
for metric_data in metrics:
|
||||
metric = SystemMetric(
|
||||
metric_name=metric_data["name"],
|
||||
metric_type=metric_data["type"],
|
||||
value=metric_data["value"],
|
||||
unit=metric_data.get("unit"),
|
||||
labels=metric_data.get("labels"),
|
||||
tags=metric_data.get("tags"),
|
||||
source=metric_data.get("source", self.name),
|
||||
component=metric_data.get("component"),
|
||||
description=metric_data.get("description"),
|
||||
meta_data=metric_data.get("metadata"),
|
||||
)
|
||||
session.add(metric)
|
||||
|
||||
await session.commit()
|
||||
|
||||
self.last_run = start_time
|
||||
self.run_count += 1
|
||||
|
||||
logger.debug(f"Monitoring task {self.name} completed: collected {len(metrics)} metrics")
|
||||
|
||||
return {
|
||||
"task": self.name,
|
||||
"status": "success",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
|
||||
"metrics_collected": len(metrics),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.error_count += 1
|
||||
logger.error(f"Monitoring task {self.name} failed: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"task": self.name,
|
||||
"status": "error",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
|
||||
"error": str(e),
|
||||
"metrics_collected": 0,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get task statistics."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"enabled": self.enabled,
|
||||
"interval_seconds": self.interval_seconds,
|
||||
"last_run": self.last_run.isoformat() if self.last_run else None,
|
||||
"run_count": self.run_count,
|
||||
"error_count": self.error_count,
|
||||
}
|
||||
|
||||
|
||||
class SystemResourceMonitoring(MonitoringTask):
|
||||
"""Monitor system resources (CPU, memory, disk, network)."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("system_resources", settings)
|
||||
self.interval_seconds = settings.system_monitoring_interval
|
||||
|
||||
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
|
||||
"""Collect system resource metrics."""
|
||||
metrics = []
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
# CPU metrics
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
cpu_count = psutil.cpu_count()
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "system_cpu_usage_percent",
|
||||
"type": "gauge",
|
||||
"value": cpu_percent,
|
||||
"unit": "percent",
|
||||
"component": "cpu",
|
||||
"description": "CPU usage percentage",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_cpu_count",
|
||||
"type": "gauge",
|
||||
"value": cpu_count,
|
||||
"unit": "count",
|
||||
"component": "cpu",
|
||||
"description": "Number of CPU cores",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
if cpu_freq:
|
||||
metrics.append({
|
||||
"name": "system_cpu_frequency_mhz",
|
||||
"type": "gauge",
|
||||
"value": cpu_freq.current,
|
||||
"unit": "mhz",
|
||||
"component": "cpu",
|
||||
"description": "Current CPU frequency",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Memory metrics
|
||||
memory = psutil.virtual_memory()
|
||||
swap = psutil.swap_memory()
|
||||
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "system_memory_total_bytes",
|
||||
"type": "gauge",
|
||||
"value": memory.total,
|
||||
"unit": "bytes",
|
||||
"component": "memory",
|
||||
"description": "Total system memory",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_memory_used_bytes",
|
||||
"type": "gauge",
|
||||
"value": memory.used,
|
||||
"unit": "bytes",
|
||||
"component": "memory",
|
||||
"description": "Used system memory",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_memory_available_bytes",
|
||||
"type": "gauge",
|
||||
"value": memory.available,
|
||||
"unit": "bytes",
|
||||
"component": "memory",
|
||||
"description": "Available system memory",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_memory_usage_percent",
|
||||
"type": "gauge",
|
||||
"value": memory.percent,
|
||||
"unit": "percent",
|
||||
"component": "memory",
|
||||
"description": "Memory usage percentage",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_swap_total_bytes",
|
||||
"type": "gauge",
|
||||
"value": swap.total,
|
||||
"unit": "bytes",
|
||||
"component": "memory",
|
||||
"description": "Total swap memory",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_swap_used_bytes",
|
||||
"type": "gauge",
|
||||
"value": swap.used,
|
||||
"unit": "bytes",
|
||||
"component": "memory",
|
||||
"description": "Used swap memory",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
# Disk metrics
|
||||
disk_usage = psutil.disk_usage('/')
|
||||
disk_io = psutil.disk_io_counters()
|
||||
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "system_disk_total_bytes",
|
||||
"type": "gauge",
|
||||
"value": disk_usage.total,
|
||||
"unit": "bytes",
|
||||
"component": "disk",
|
||||
"description": "Total disk space",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_disk_used_bytes",
|
||||
"type": "gauge",
|
||||
"value": disk_usage.used,
|
||||
"unit": "bytes",
|
||||
"component": "disk",
|
||||
"description": "Used disk space",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_disk_free_bytes",
|
||||
"type": "gauge",
|
||||
"value": disk_usage.free,
|
||||
"unit": "bytes",
|
||||
"component": "disk",
|
||||
"description": "Free disk space",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_disk_usage_percent",
|
||||
"type": "gauge",
|
||||
"value": (disk_usage.used / disk_usage.total) * 100,
|
||||
"unit": "percent",
|
||||
"component": "disk",
|
||||
"description": "Disk usage percentage",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
if disk_io:
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "system_disk_read_bytes_total",
|
||||
"type": "counter",
|
||||
"value": disk_io.read_bytes,
|
||||
"unit": "bytes",
|
||||
"component": "disk",
|
||||
"description": "Total bytes read from disk",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_disk_write_bytes_total",
|
||||
"type": "counter",
|
||||
"value": disk_io.write_bytes,
|
||||
"unit": "bytes",
|
||||
"component": "disk",
|
||||
"description": "Total bytes written to disk",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
# Network metrics
|
||||
network_io = psutil.net_io_counters()
|
||||
|
||||
if network_io:
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "system_network_bytes_sent_total",
|
||||
"type": "counter",
|
||||
"value": network_io.bytes_sent,
|
||||
"unit": "bytes",
|
||||
"component": "network",
|
||||
"description": "Total bytes sent over network",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_network_bytes_recv_total",
|
||||
"type": "counter",
|
||||
"value": network_io.bytes_recv,
|
||||
"unit": "bytes",
|
||||
"component": "network",
|
||||
"description": "Total bytes received over network",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_network_packets_sent_total",
|
||||
"type": "counter",
|
||||
"value": network_io.packets_sent,
|
||||
"unit": "count",
|
||||
"component": "network",
|
||||
"description": "Total packets sent over network",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "system_network_packets_recv_total",
|
||||
"type": "counter",
|
||||
"value": network_io.packets_recv,
|
||||
"unit": "count",
|
||||
"component": "network",
|
||||
"description": "Total packets received over network",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class DatabaseMonitoring(MonitoringTask):
|
||||
"""Monitor database performance and statistics."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("database", settings)
|
||||
self.interval_seconds = settings.database_monitoring_interval
|
||||
|
||||
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
|
||||
"""Collect database metrics."""
|
||||
metrics = []
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
# Get database connection stats
|
||||
db_manager = get_database_manager(self.settings)
|
||||
connection_stats = await db_manager.get_connection_stats()
|
||||
|
||||
# PostgreSQL connection metrics
|
||||
if "postgresql" in connection_stats:
|
||||
pg_stats = connection_stats["postgresql"]
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "database_connections_total",
|
||||
"type": "gauge",
|
||||
"value": pg_stats.get("total_connections", 0),
|
||||
"unit": "count",
|
||||
"component": "postgresql",
|
||||
"description": "Total database connections",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "database_connections_active",
|
||||
"type": "gauge",
|
||||
"value": pg_stats.get("checked_out", 0),
|
||||
"unit": "count",
|
||||
"component": "postgresql",
|
||||
"description": "Active database connections",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "database_connections_available",
|
||||
"type": "gauge",
|
||||
"value": pg_stats.get("available_connections", 0),
|
||||
"unit": "count",
|
||||
"component": "postgresql",
|
||||
"description": "Available database connections",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
# Redis connection metrics
|
||||
if "redis" in connection_stats and not connection_stats["redis"].get("error"):
|
||||
redis_stats = connection_stats["redis"]
|
||||
metrics.extend([
|
||||
{
|
||||
"name": "redis_connections_active",
|
||||
"type": "gauge",
|
||||
"value": redis_stats.get("connected_clients", 0),
|
||||
"unit": "count",
|
||||
"component": "redis",
|
||||
"description": "Active Redis connections",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
},
|
||||
{
|
||||
"name": "redis_connections_blocked",
|
||||
"type": "gauge",
|
||||
"value": redis_stats.get("blocked_clients", 0),
|
||||
"unit": "count",
|
||||
"component": "redis",
|
||||
"description": "Blocked Redis connections",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
}
|
||||
])
|
||||
|
||||
# Table row counts
|
||||
table_counts = await self._get_table_counts(session)
|
||||
for table_name, count in table_counts.items():
|
||||
metrics.append({
|
||||
"name": f"database_table_rows_{table_name}",
|
||||
"type": "gauge",
|
||||
"value": count,
|
||||
"unit": "count",
|
||||
"component": "postgresql",
|
||||
"description": f"Number of rows in {table_name} table",
|
||||
"metadata": {"timestamp": timestamp.isoformat(), "table": table_name}
|
||||
})
|
||||
|
||||
return metrics
|
||||
|
||||
async def _get_table_counts(self, session: AsyncSession) -> Dict[str, int]:
|
||||
"""Get row counts for all tables."""
|
||||
counts = {}
|
||||
|
||||
# Count devices
|
||||
result = await session.execute(select(func.count(Device.id)))
|
||||
counts["devices"] = result.scalar() or 0
|
||||
|
||||
# Count sessions
|
||||
result = await session.execute(select(func.count(Session.id)))
|
||||
counts["sessions"] = result.scalar() or 0
|
||||
|
||||
# Count CSI data
|
||||
result = await session.execute(select(func.count(CSIData.id)))
|
||||
counts["csi_data"] = result.scalar() or 0
|
||||
|
||||
# Count pose detections
|
||||
result = await session.execute(select(func.count(PoseDetection.id)))
|
||||
counts["pose_detections"] = result.scalar() or 0
|
||||
|
||||
# Count system metrics
|
||||
result = await session.execute(select(func.count(SystemMetric.id)))
|
||||
counts["system_metrics"] = result.scalar() or 0
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
class ApplicationMonitoring(MonitoringTask):
|
||||
"""Monitor application-specific metrics."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("application", settings)
|
||||
self.interval_seconds = settings.application_monitoring_interval
|
||||
self.start_time = datetime.utcnow()
|
||||
|
||||
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
|
||||
"""Collect application metrics."""
|
||||
metrics = []
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
# Application uptime
|
||||
uptime_seconds = (timestamp - self.start_time).total_seconds()
|
||||
metrics.append({
|
||||
"name": "application_uptime_seconds",
|
||||
"type": "gauge",
|
||||
"value": uptime_seconds,
|
||||
"unit": "seconds",
|
||||
"component": "application",
|
||||
"description": "Application uptime in seconds",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Active sessions count
|
||||
active_sessions_query = select(func.count(Session.id)).where(
|
||||
Session.status == "active"
|
||||
)
|
||||
result = await session.execute(active_sessions_query)
|
||||
active_sessions = result.scalar() or 0
|
||||
|
||||
metrics.append({
|
||||
"name": "application_active_sessions",
|
||||
"type": "gauge",
|
||||
"value": active_sessions,
|
||||
"unit": "count",
|
||||
"component": "application",
|
||||
"description": "Number of active sessions",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Active devices count
|
||||
active_devices_query = select(func.count(Device.id)).where(
|
||||
Device.status == "active"
|
||||
)
|
||||
result = await session.execute(active_devices_query)
|
||||
active_devices = result.scalar() or 0
|
||||
|
||||
metrics.append({
|
||||
"name": "application_active_devices",
|
||||
"type": "gauge",
|
||||
"value": active_devices,
|
||||
"unit": "count",
|
||||
"component": "application",
|
||||
"description": "Number of active devices",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Recent data processing metrics (last hour)
|
||||
one_hour_ago = timestamp - timedelta(hours=1)
|
||||
|
||||
# Recent CSI data count
|
||||
recent_csi_query = select(func.count(CSIData.id)).where(
|
||||
CSIData.created_at >= one_hour_ago
|
||||
)
|
||||
result = await session.execute(recent_csi_query)
|
||||
recent_csi_count = result.scalar() or 0
|
||||
|
||||
metrics.append({
|
||||
"name": "application_csi_data_hourly",
|
||||
"type": "gauge",
|
||||
"value": recent_csi_count,
|
||||
"unit": "count",
|
||||
"component": "application",
|
||||
"description": "CSI data records created in the last hour",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Recent pose detections count
|
||||
recent_pose_query = select(func.count(PoseDetection.id)).where(
|
||||
PoseDetection.created_at >= one_hour_ago
|
||||
)
|
||||
result = await session.execute(recent_pose_query)
|
||||
recent_pose_count = result.scalar() or 0
|
||||
|
||||
metrics.append({
|
||||
"name": "application_pose_detections_hourly",
|
||||
"type": "gauge",
|
||||
"value": recent_pose_count,
|
||||
"unit": "count",
|
||||
"component": "application",
|
||||
"description": "Pose detections created in the last hour",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Processing status metrics
|
||||
processing_statuses = ["pending", "processing", "completed", "failed"]
|
||||
for status in processing_statuses:
|
||||
status_query = select(func.count(CSIData.id)).where(
|
||||
CSIData.processing_status == status
|
||||
)
|
||||
result = await session.execute(status_query)
|
||||
status_count = result.scalar() or 0
|
||||
|
||||
metrics.append({
|
||||
"name": f"application_csi_processing_{status}",
|
||||
"type": "gauge",
|
||||
"value": status_count,
|
||||
"unit": "count",
|
||||
"component": "application",
|
||||
"description": f"CSI data records with {status} processing status",
|
||||
"metadata": {"timestamp": timestamp.isoformat(), "status": status}
|
||||
})
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
class PerformanceMonitoring(MonitoringTask):
|
||||
"""Monitor performance metrics and response times."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
super().__init__("performance", settings)
|
||||
self.interval_seconds = settings.performance_monitoring_interval
|
||||
self.response_times = []
|
||||
self.error_counts = {}
|
||||
|
||||
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
|
||||
"""Collect performance metrics."""
|
||||
metrics = []
|
||||
timestamp = datetime.utcnow()
|
||||
|
||||
# Database query performance test
|
||||
start_time = time.time()
|
||||
test_query = select(func.count(Device.id))
|
||||
await session.execute(test_query)
|
||||
db_response_time = (time.time() - start_time) * 1000 # Convert to milliseconds
|
||||
|
||||
metrics.append({
|
||||
"name": "performance_database_query_time_ms",
|
||||
"type": "gauge",
|
||||
"value": db_response_time,
|
||||
"unit": "milliseconds",
|
||||
"component": "database",
|
||||
"description": "Database query response time",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Average response time (if we have data)
|
||||
if self.response_times:
|
||||
avg_response_time = sum(self.response_times) / len(self.response_times)
|
||||
metrics.append({
|
||||
"name": "performance_avg_response_time_ms",
|
||||
"type": "gauge",
|
||||
"value": avg_response_time,
|
||||
"unit": "milliseconds",
|
||||
"component": "api",
|
||||
"description": "Average API response time",
|
||||
"metadata": {"timestamp": timestamp.isoformat()}
|
||||
})
|
||||
|
||||
# Clear old response times (keep only recent ones)
|
||||
self.response_times = self.response_times[-100:] # Keep last 100
|
||||
|
||||
# Error rates
|
||||
for error_type, count in self.error_counts.items():
|
||||
metrics.append({
|
||||
"name": f"performance_errors_{error_type}_total",
|
||||
"type": "counter",
|
||||
"value": count,
|
||||
"unit": "count",
|
||||
"component": "api",
|
||||
"description": f"Total {error_type} errors",
|
||||
"metadata": {"timestamp": timestamp.isoformat(), "error_type": error_type}
|
||||
})
|
||||
|
||||
return metrics
|
||||
|
||||
def record_response_time(self, response_time_ms: float):
|
||||
"""Record an API response time."""
|
||||
self.response_times.append(response_time_ms)
|
||||
|
||||
def record_error(self, error_type: str):
|
||||
"""Record an error occurrence."""
|
||||
self.error_counts[error_type] = self.error_counts.get(error_type, 0) + 1
|
||||
|
||||
|
||||
class MonitoringManager:
|
||||
"""Manager for all monitoring tasks."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.db_manager = get_database_manager(settings)
|
||||
self.tasks = self._initialize_tasks()
|
||||
self.running = False
|
||||
self.last_run = None
|
||||
self.run_count = 0
|
||||
|
||||
def _initialize_tasks(self) -> List[MonitoringTask]:
|
||||
"""Initialize all monitoring tasks."""
|
||||
tasks = [
|
||||
SystemResourceMonitoring(self.settings),
|
||||
DatabaseMonitoring(self.settings),
|
||||
ApplicationMonitoring(self.settings),
|
||||
PerformanceMonitoring(self.settings),
|
||||
]
|
||||
|
||||
# Filter enabled tasks
|
||||
enabled_tasks = [task for task in tasks if task.enabled]
|
||||
|
||||
logger.info(f"Initialized {len(enabled_tasks)} monitoring tasks")
|
||||
return enabled_tasks
|
||||
|
||||
async def run_all_tasks(self) -> Dict[str, Any]:
|
||||
"""Run all monitoring tasks."""
|
||||
if self.running:
|
||||
return {"status": "already_running", "message": "Monitoring already in progress"}
|
||||
|
||||
self.running = True
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.debug("Starting monitoring tasks")
|
||||
|
||||
results = []
|
||||
total_metrics = 0
|
||||
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
for task in self.tasks:
|
||||
if not task.enabled:
|
||||
continue
|
||||
|
||||
result = await task.run(session)
|
||||
results.append(result)
|
||||
total_metrics += result.get("metrics_collected", 0)
|
||||
|
||||
self.last_run = start_time
|
||||
self.run_count += 1
|
||||
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
|
||||
logger.debug(
|
||||
f"Monitoring tasks completed: collected {total_metrics} metrics "
|
||||
f"in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_seconds": duration,
|
||||
"total_metrics": total_metrics,
|
||||
"task_results": results,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Monitoring tasks failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"start_time": start_time.isoformat(),
|
||||
"duration_seconds": (datetime.utcnow() - start_time).total_seconds(),
|
||||
"error": str(e),
|
||||
"total_metrics": 0,
|
||||
}
|
||||
|
||||
finally:
|
||||
self.running = False
|
||||
|
||||
async def run_task(self, task_name: str) -> Dict[str, Any]:
|
||||
"""Run a specific monitoring task."""
|
||||
task = next((t for t in self.tasks if t.name == task_name), None)
|
||||
|
||||
if not task:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Task '{task_name}' not found",
|
||||
"available_tasks": [t.name for t in self.tasks]
|
||||
}
|
||||
|
||||
if not task.enabled:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Task '{task_name}' is disabled"
|
||||
}
|
||||
|
||||
async with self.db_manager.get_async_session() as session:
|
||||
return await task.run(session)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get monitoring manager statistics."""
|
||||
return {
|
||||
"manager": {
|
||||
"running": self.running,
|
||||
"last_run": self.last_run.isoformat() if self.last_run else None,
|
||||
"run_count": self.run_count,
|
||||
},
|
||||
"tasks": [task.get_stats() for task in self.tasks],
|
||||
}
|
||||
|
||||
def get_performance_task(self) -> Optional[PerformanceMonitoring]:
|
||||
"""Get the performance monitoring task for recording metrics."""
|
||||
return next((t for t in self.tasks if isinstance(t, PerformanceMonitoring)), None)
|
||||
|
||||
|
||||
# Global monitoring manager instance
|
||||
_monitoring_manager: Optional[MonitoringManager] = None
|
||||
|
||||
|
||||
def get_monitoring_manager(settings: Settings) -> MonitoringManager:
|
||||
"""Get monitoring manager instance."""
|
||||
global _monitoring_manager
|
||||
if _monitoring_manager is None:
|
||||
_monitoring_manager = MonitoringManager(settings)
|
||||
return _monitoring_manager
|
||||
|
||||
|
||||
async def run_periodic_monitoring(settings: Settings):
|
||||
"""Run periodic monitoring tasks."""
|
||||
monitoring_manager = get_monitoring_manager(settings)
|
||||
|
||||
while True:
|
||||
try:
|
||||
await monitoring_manager.run_all_tasks()
|
||||
|
||||
# Wait for next monitoring interval
|
||||
await asyncio.sleep(settings.monitoring_interval_seconds)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Periodic monitoring cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Periodic monitoring error: {e}", exc_info=True)
|
||||
# Wait before retrying
|
||||
await asyncio.sleep(30)
|
||||
Reference in New Issue
Block a user