feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

266
v1/src/__init__.py Normal file
View File

@@ -0,0 +1,266 @@
"""
WiFi-DensePose API Package
==========================
A comprehensive system for WiFi-based human pose estimation using CSI data
and DensePose neural networks.
This package provides:
- Real-time CSI data collection from WiFi routers
- Advanced signal processing and phase sanitization
- DensePose neural network integration for pose estimation
- RESTful API for data access and control
- Background task management for data processing
- Comprehensive monitoring and logging
Example usage:
>>> from src.app import app
>>> from src.config.settings import get_settings
>>>
>>> settings = get_settings()
>>> # Run with: uvicorn src.app:app --host 0.0.0.0 --port 8000
For CLI usage:
$ wifi-densepose start --host 0.0.0.0 --port 8000
$ wifi-densepose status
$ wifi-densepose stop
Author: WiFi-DensePose Team
License: MIT
"""
__version__ = "1.1.0"
__author__ = "WiFi-DensePose Team"
__email__ = "team@wifi-densepose.com"
__license__ = "MIT"
__copyright__ = "Copyright 2024 WiFi-DensePose Team"
# Package metadata
__title__ = "wifi-densepose"
__description__ = "WiFi-based human pose estimation using CSI data and DensePose neural networks"
__url__ = "https://github.com/wifi-densepose/wifi-densepose"
__download_url__ = "https://github.com/wifi-densepose/wifi-densepose/archive/main.zip"
# Version info tuple
__version_info__ = tuple(int(x) for x in __version__.split('.'))
# Import key components for easy access
try:
from src.app import app
from src.config.settings import get_settings, Settings
from src.logger import setup_logging, get_logger
# Core components
from src.core.csi_processor import CSIProcessor
from src.core.phase_sanitizer import PhaseSanitizer
from src.core.pose_estimator import PoseEstimator
from src.core.router_interface import RouterInterface
# Services
from src.services.orchestrator import ServiceOrchestrator
from src.services.health_check import HealthCheckService
from src.services.metrics import MetricsService
# Database
from src.database.connection import get_database_manager
from src.database.models import (
Device, Session, CSIData, PoseDetection,
SystemMetric, AuditLog
)
__all__ = [
# Core app
'app',
'get_settings',
'Settings',
'setup_logging',
'get_logger',
# Core processing
'CSIProcessor',
'PhaseSanitizer',
'PoseEstimator',
'RouterInterface',
# Services
'ServiceOrchestrator',
'HealthCheckService',
'MetricsService',
# Database
'get_database_manager',
'Device',
'Session',
'CSIData',
'PoseDetection',
'SystemMetric',
'AuditLog',
# Metadata
'__version__',
'__version_info__',
'__author__',
'__email__',
'__license__',
'__copyright__',
]
except ImportError as e:
# Handle import errors gracefully during package installation
import warnings
warnings.warn(
f"Some components could not be imported: {e}. "
"This is normal during package installation.",
ImportWarning
)
__all__ = [
'__version__',
'__version_info__',
'__author__',
'__email__',
'__license__',
'__copyright__',
]
def get_version():
"""Get the package version."""
return __version__
def get_version_info():
"""Get the package version as a tuple."""
return __version_info__
def get_package_info():
"""Get comprehensive package information."""
return {
'name': __title__,
'version': __version__,
'version_info': __version_info__,
'description': __description__,
'author': __author__,
'author_email': __email__,
'license': __license__,
'copyright': __copyright__,
'url': __url__,
'download_url': __download_url__,
}
def check_dependencies():
"""Check if all required dependencies are available."""
missing_deps = []
optional_deps = []
# Core dependencies
required_modules = [
('fastapi', 'FastAPI'),
('uvicorn', 'Uvicorn'),
('pydantic', 'Pydantic'),
('sqlalchemy', 'SQLAlchemy'),
('numpy', 'NumPy'),
('torch', 'PyTorch'),
('cv2', 'OpenCV'),
('scipy', 'SciPy'),
('pandas', 'Pandas'),
('redis', 'Redis'),
('psutil', 'psutil'),
('click', 'Click'),
]
for module_name, display_name in required_modules:
try:
__import__(module_name)
except ImportError:
missing_deps.append(display_name)
# Optional dependencies
optional_modules = [
('scapy', 'Scapy (for network packet capture)'),
('paramiko', 'Paramiko (for SSH connections)'),
('serial', 'PySerial (for serial communication)'),
('matplotlib', 'Matplotlib (for plotting)'),
('prometheus_client', 'Prometheus Client (for metrics)'),
]
for module_name, display_name in optional_modules:
try:
__import__(module_name)
except ImportError:
optional_deps.append(display_name)
return {
'missing_required': missing_deps,
'missing_optional': optional_deps,
'all_required_available': len(missing_deps) == 0,
}
def print_system_info():
"""Print system and package information."""
import sys
import platform
info = get_package_info()
deps = check_dependencies()
print(f"WiFi-DensePose v{info['version']}")
print(f"Python {sys.version}")
print(f"Platform: {platform.platform()}")
print(f"Architecture: {platform.architecture()[0]}")
print()
if deps['all_required_available']:
print("✅ All required dependencies are available")
else:
print("❌ Missing required dependencies:")
for dep in deps['missing_required']:
print(f" - {dep}")
if deps['missing_optional']:
print("\n⚠️ Missing optional dependencies:")
for dep in deps['missing_optional']:
print(f" - {dep}")
print(f"\nFor more information, visit: {info['url']}")
# Package-level configuration
import logging
# Set up basic logging configuration
logging.getLogger(__name__).addHandler(logging.NullHandler())
# Suppress some noisy third-party loggers
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('requests').setLevel(logging.WARNING)
logging.getLogger('asyncio').setLevel(logging.WARNING)
# Package initialization message
if __name__ != '__main__':
logger = logging.getLogger(__name__)
logger.debug(f"WiFi-DensePose package v{__version__} initialized")
# Compatibility aliases for backward compatibility
try:
WifiDensePose = app # Legacy alias
except NameError:
WifiDensePose = None # Will be None if app import failed
try:
get_config = get_settings # Legacy alias
except NameError:
get_config = None # Will be None if get_settings import failed
def main():
"""Main entry point for the package when run as a module."""
print_system_info()
if __name__ == '__main__':
main()

7
v1/src/api/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
WiFi-DensePose FastAPI application package
"""
# API package - routers and dependencies are imported by app.py
__all__ = []

447
v1/src/api/dependencies.py Normal file
View File

@@ -0,0 +1,447 @@
"""
Dependency injection for WiFi-DensePose API
"""
import logging
from typing import Optional, Dict, Any
from functools import lru_cache
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from src.config.settings import get_settings
from src.config.domains import get_domain_config
from src.services.pose_service import PoseService
from src.services.stream_service import StreamService
from src.services.hardware_service import HardwareService
logger = logging.getLogger(__name__)
# Security scheme for JWT authentication
security = HTTPBearer(auto_error=False)
# Service dependencies
@lru_cache()
def get_pose_service() -> PoseService:
"""Get pose service instance."""
settings = get_settings()
domain_config = get_domain_config()
return PoseService(
settings=settings,
domain_config=domain_config
)
@lru_cache()
def get_stream_service() -> StreamService:
"""Get stream service instance."""
settings = get_settings()
domain_config = get_domain_config()
return StreamService(
settings=settings,
domain_config=domain_config
)
@lru_cache()
def get_hardware_service() -> HardwareService:
"""Get hardware service instance."""
settings = get_settings()
domain_config = get_domain_config()
return HardwareService(
settings=settings,
domain_config=domain_config
)
# Authentication dependencies
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict[str, Any]]:
"""Get current authenticated user."""
settings = get_settings()
# Skip authentication if disabled
if not settings.enable_authentication:
return None
# Check if user is already set by middleware
if hasattr(request.state, 'user') and request.state.user:
return request.state.user
# No credentials provided
if not credentials:
return None
# This would normally validate the JWT token
# For now, return a mock user for development
if settings.is_development:
return {
"id": "dev-user",
"username": "developer",
"email": "dev@example.com",
"is_admin": True,
"permissions": ["read", "write", "admin"]
}
# In production, implement proper JWT validation
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication not implemented",
headers={"WWW-Authenticate": "Bearer"},
)
async def get_current_active_user(
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get current active user (required authentication)."""
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"},
)
# Check if user is active
if not current_user.get("is_active", True):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
async def get_admin_user(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Get current admin user (admin privileges required)."""
if not current_user.get("is_admin", False):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required"
)
return current_user
# Permission dependencies
def require_permission(permission: str):
"""Dependency factory for permission checking."""
async def check_permission(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Check if user has required permission."""
user_permissions = current_user.get("permissions", [])
# Admin users have all permissions
if current_user.get("is_admin", False):
return current_user
# Check specific permission
if permission not in user_permissions:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{permission}' required"
)
return current_user
return check_permission
# Zone access dependencies
async def validate_zone_access(
zone_id: str,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> str:
"""Validate user access to a specific zone."""
domain_config = get_domain_config()
# Check if zone exists
zone = domain_config.get_zone(zone_id)
if not zone:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Zone '{zone_id}' not found"
)
# Check if zone is enabled
if not zone.enabled:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Zone '{zone_id}' is disabled"
)
# If authentication is enabled, check user access
if current_user:
# Admin users have access to all zones
if current_user.get("is_admin", False):
return zone_id
# Check user's zone permissions
user_zones = current_user.get("zones", [])
if user_zones and zone_id not in user_zones:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to zone '{zone_id}'"
)
return zone_id
# Router access dependencies
async def validate_router_access(
router_id: str,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> str:
"""Validate user access to a specific router."""
domain_config = get_domain_config()
# Check if router exists
router = domain_config.get_router(router_id)
if not router:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Router '{router_id}' not found"
)
# Check if router is enabled
if not router.enabled:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Router '{router_id}' is disabled"
)
# If authentication is enabled, check user access
if current_user:
# Admin users have access to all routers
if current_user.get("is_admin", False):
return router_id
# Check user's router permissions
user_routers = current_user.get("routers", [])
if user_routers and router_id not in user_routers:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to router '{router_id}'"
)
return router_id
# Service health dependencies
async def check_service_health(
request: Request,
service_name: str
) -> bool:
"""Check if a service is healthy."""
try:
if service_name == "pose":
service = getattr(request.app.state, 'pose_service', None)
elif service_name == "stream":
service = getattr(request.app.state, 'stream_service', None)
elif service_name == "hardware":
service = getattr(request.app.state, 'hardware_service', None)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown service: {service_name}"
)
if not service:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Service '{service_name}' not available"
)
# Check service health
status_info = await service.get_status()
if status_info.get("status") != "healthy":
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Service '{service_name}' is unhealthy: {status_info.get('error', 'Unknown error')}"
)
return True
except HTTPException:
raise
except Exception as e:
logger.error(f"Error checking service health for {service_name}: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Service '{service_name}' health check failed"
)
# Rate limiting dependencies
async def check_rate_limit(
request: Request,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> bool:
"""Check rate limiting status."""
settings = get_settings()
# Skip if rate limiting is disabled
if not settings.enable_rate_limiting:
return True
# Rate limiting is handled by middleware
# This dependency can be used for additional checks
return True
# Configuration dependencies
def get_zone_config(zone_id: str = Depends(validate_zone_access)):
"""Get zone configuration."""
domain_config = get_domain_config()
return domain_config.get_zone(zone_id)
def get_router_config(router_id: str = Depends(validate_router_access)):
"""Get router configuration."""
domain_config = get_domain_config()
return domain_config.get_router(router_id)
# Pagination dependencies
class PaginationParams:
"""Pagination parameters."""
def __init__(
self,
page: int = 1,
size: int = 20,
max_size: int = 100
):
if page < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Page must be >= 1"
)
if size < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Size must be >= 1"
)
if size > max_size:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Size must be <= {max_size}"
)
self.page = page
self.size = size
self.offset = (page - 1) * size
self.limit = size
def get_pagination_params(
page: int = 1,
size: int = 20
) -> PaginationParams:
"""Get pagination parameters."""
return PaginationParams(page=page, size=size)
# Query filter dependencies
class QueryFilters:
"""Common query filters."""
def __init__(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
min_confidence: Optional[float] = None,
activity: Optional[str] = None
):
self.start_time = start_time
self.end_time = end_time
self.min_confidence = min_confidence
self.activity = activity
# Validate confidence
if min_confidence is not None:
if not 0.0 <= min_confidence <= 1.0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="min_confidence must be between 0.0 and 1.0"
)
def get_query_filters(
start_time: Optional[str] = None,
end_time: Optional[str] = None,
min_confidence: Optional[float] = None,
activity: Optional[str] = None
) -> QueryFilters:
"""Get query filters."""
return QueryFilters(
start_time=start_time,
end_time=end_time,
min_confidence=min_confidence,
activity=activity
)
# WebSocket dependencies
async def get_websocket_user(
websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get user from WebSocket token."""
settings = get_settings()
# Skip authentication if disabled
if not settings.enable_authentication:
return None
# For development, return mock user
if settings.is_development:
return {
"id": "ws-user",
"username": "websocket_user",
"is_admin": False,
"permissions": ["read"]
}
# In production, implement proper token validation
return None
async def get_current_user_ws(
websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get current user for WebSocket connections."""
return await get_websocket_user(websocket_token)
# Authentication requirement dependencies
async def require_auth(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Require authentication for endpoint access."""
return current_user
# Development dependencies
async def development_only():
"""Dependency that only allows access in development."""
settings = get_settings()
if not settings.is_development:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Endpoint not available in production"
)
return True

421
v1/src/api/main.py Normal file
View File

@@ -0,0 +1,421 @@
"""
FastAPI application for WiFi-DensePose API
"""
import asyncio
import logging
import logging.config
from contextlib import asynccontextmanager
from typing import Dict, Any
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from src.config.settings import get_settings
from src.config.domains import get_domain_config
from src.api.routers import pose, stream, health
from src.api.middleware.auth import AuthMiddleware
from src.api.middleware.rate_limit import RateLimitMiddleware
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
from src.api.websocket.connection_manager import connection_manager
from src.api.websocket.pose_stream import PoseStreamHandler
# Configure logging
settings = get_settings()
logging.config.dictConfig(settings.get_logging_config())
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager."""
logger.info("Starting WiFi-DensePose API...")
try:
# Initialize services
await initialize_services(app)
# Start background tasks
await start_background_tasks(app)
logger.info("WiFi-DensePose API started successfully")
yield
except Exception as e:
logger.error(f"Failed to start application: {e}")
raise
finally:
# Cleanup on shutdown
logger.info("Shutting down WiFi-DensePose API...")
await cleanup_services(app)
logger.info("WiFi-DensePose API shutdown complete")
async def initialize_services(app: FastAPI):
"""Initialize application services."""
try:
# Initialize hardware service
hardware_service = get_hardware_service()
await hardware_service.initialize()
# Initialize pose service
pose_service = get_pose_service()
await pose_service.initialize()
# Initialize stream service
stream_service = get_stream_service()
await stream_service.initialize()
# Initialize pose stream handler
pose_stream_handler = PoseStreamHandler(
connection_manager=connection_manager,
pose_service=pose_service,
stream_service=stream_service
)
# Store in app state for access in routes
app.state.hardware_service = hardware_service
app.state.pose_service = pose_service
app.state.stream_service = stream_service
app.state.pose_stream_handler = pose_stream_handler
logger.info("Services initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize services: {e}")
raise
async def start_background_tasks(app: FastAPI):
"""Start background tasks."""
try:
# Start pose service
pose_service = app.state.pose_service
await pose_service.start()
logger.info("Pose service started")
# Start pose streaming if enabled
if settings.enable_real_time_processing:
pose_stream_handler = app.state.pose_stream_handler
await pose_stream_handler.start_streaming()
logger.info("Background tasks started")
except Exception as e:
logger.error(f"Failed to start background tasks: {e}")
raise
async def cleanup_services(app: FastAPI):
"""Cleanup services on shutdown."""
try:
# Stop pose streaming
if hasattr(app.state, 'pose_stream_handler'):
await app.state.pose_stream_handler.shutdown()
# Shutdown connection manager
await connection_manager.shutdown()
# Cleanup services
if hasattr(app.state, 'stream_service'):
await app.state.stream_service.shutdown()
if hasattr(app.state, 'pose_service'):
await app.state.pose_service.stop()
if hasattr(app.state, 'hardware_service'):
await app.state.hardware_service.shutdown()
logger.info("Services cleaned up successfully")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
# Create FastAPI application
app = FastAPI(
title=settings.app_name,
version=settings.version,
description="WiFi-based human pose estimation and activity recognition API",
docs_url=settings.docs_url if not settings.is_production else None,
redoc_url=settings.redoc_url if not settings.is_production else None,
openapi_url=settings.openapi_url if not settings.is_production else None,
lifespan=lifespan
)
# Add middleware
if settings.enable_rate_limiting:
app.add_middleware(RateLimitMiddleware)
if settings.enable_authentication:
app.add_middleware(AuthMiddleware)
# Add CORS middleware
cors_config = settings.get_cors_config()
app.add_middleware(
CORSMiddleware,
**cors_config
)
# Add trusted host middleware for production
if settings.is_production:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=settings.allowed_hosts
)
# Exception handlers
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""Handle HTTP exceptions."""
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": exc.status_code,
"message": exc.detail,
"type": "http_error"
}
}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle request validation errors."""
return JSONResponse(
status_code=422,
content={
"error": {
"code": 422,
"message": "Validation error",
"type": "validation_error",
"details": exc.errors()
}
}
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""Handle general exceptions."""
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": {
"code": 500,
"message": "Internal server error",
"type": "internal_error"
}
}
)
# Middleware for request logging
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Log all requests."""
start_time = asyncio.get_event_loop().time()
# Process request
response = await call_next(request)
# Calculate processing time
process_time = asyncio.get_event_loop().time() - start_time
# Log request
logger.info(
f"{request.method} {request.url.path} - "
f"Status: {response.status_code} - "
f"Time: {process_time:.3f}s"
)
# Add processing time header
response.headers["X-Process-Time"] = str(process_time)
return response
# Include routers
app.include_router(
health.router,
prefix="/health",
tags=["Health"]
)
app.include_router(
pose.router,
prefix=f"{settings.api_prefix}/pose",
tags=["Pose Estimation"]
)
app.include_router(
stream.router,
prefix=f"{settings.api_prefix}/stream",
tags=["Streaming"]
)
# Root endpoint
@app.get("/")
async def root():
"""Root endpoint with API information."""
return {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"docs_url": settings.docs_url,
"api_prefix": settings.api_prefix,
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"real_time_processing": settings.enable_real_time_processing
}
}
# API information endpoint
@app.get(f"{settings.api_prefix}/info")
async def api_info():
"""Get detailed API information."""
domain_config = get_domain_config()
return {
"api": {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"prefix": settings.api_prefix
},
"configuration": {
"zones": len(domain_config.zones),
"routers": len(domain_config.routers),
"pose_models": len(domain_config.pose_models)
},
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"real_time_processing": settings.enable_real_time_processing,
"historical_data": settings.enable_historical_data
},
"limits": {
"rate_limit_requests": settings.rate_limit_requests,
"rate_limit_window": settings.rate_limit_window,
"max_websocket_connections": domain_config.streaming.max_connections
}
}
# Status endpoint
@app.get(f"{settings.api_prefix}/status")
async def api_status(request: Request):
"""Get current API status."""
try:
# Get services from app state
hardware_service = getattr(request.app.state, 'hardware_service', None)
pose_service = getattr(request.app.state, 'pose_service', None)
stream_service = getattr(request.app.state, 'stream_service', None)
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
# Get service statuses
status = {
"api": {
"status": "healthy",
"uptime": "unknown",
"version": settings.version
},
"services": {
"hardware": await hardware_service.get_status() if hardware_service else {"status": "unavailable"},
"pose": await pose_service.get_status() if pose_service else {"status": "unavailable"},
"stream": await stream_service.get_status() if stream_service else {"status": "unavailable"}
},
"streaming": pose_stream_handler.get_stream_status() if pose_stream_handler else {"is_streaming": False},
"connections": await connection_manager.get_connection_stats()
}
return status
except Exception as e:
logger.error(f"Error getting API status: {e}")
return {
"api": {
"status": "error",
"error": str(e)
}
}
# Metrics endpoint (if enabled)
if settings.metrics_enabled:
@app.get(f"{settings.api_prefix}/metrics")
async def api_metrics(request: Request):
"""Get API metrics."""
try:
# Get services from app state
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
metrics = {
"connections": await connection_manager.get_metrics(),
"streaming": await pose_stream_handler.get_performance_metrics() if pose_stream_handler else {}
}
return metrics
except Exception as e:
logger.error(f"Error getting metrics: {e}")
return {"error": str(e)}
# Development endpoints (only in development)
if settings.is_development and settings.enable_test_endpoints:
@app.get(f"{settings.api_prefix}/dev/config")
async def dev_config():
"""Get current configuration (development only)."""
domain_config = get_domain_config()
return {
"settings": settings.dict(),
"domain_config": domain_config.to_dict()
}
@app.post(f"{settings.api_prefix}/dev/reset")
async def dev_reset(request: Request):
"""Reset services (development only)."""
try:
# Reset services
hardware_service = getattr(request.app.state, 'hardware_service', None)
pose_service = getattr(request.app.state, 'pose_service', None)
if hardware_service:
await hardware_service.reset()
if pose_service:
await pose_service.reset()
return {"message": "Services reset successfully"}
except Exception as e:
logger.error(f"Error resetting services: {e}")
return {"error": str(e)}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"src.api.main:app",
host=settings.host,
port=settings.port,
reload=settings.reload,
workers=settings.workers if not settings.reload else 1,
log_level=settings.log_level.lower()
)

View File

@@ -0,0 +1,8 @@
"""
FastAPI middleware package
"""
from .auth import AuthMiddleware
from .rate_limit import RateLimitMiddleware
__all__ = ["AuthMiddleware", "RateLimitMiddleware"]

View File

@@ -0,0 +1,322 @@
"""
JWT Authentication middleware for WiFi-DensePose API
"""
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from jose import JWTError, jwt
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
class AuthMiddleware(BaseHTTPMiddleware):
"""JWT Authentication middleware."""
def __init__(self, app):
super().__init__(app)
self.settings = get_settings()
# Paths that don't require authentication
self.public_paths = {
"/",
"/docs",
"/redoc",
"/openapi.json",
"/health",
"/ready",
"/live",
"/version",
"/metrics"
}
# Paths that require authentication
self.protected_paths = {
"/api/v1/pose/analyze",
"/api/v1/pose/calibrate",
"/api/v1/pose/historical",
"/api/v1/stream/start",
"/api/v1/stream/stop",
"/api/v1/stream/clients",
"/api/v1/stream/broadcast"
}
async def dispatch(self, request: Request, call_next):
"""Process request through authentication middleware."""
# Skip authentication for public paths
if self._is_public_path(request.url.path):
return await call_next(request)
# Extract and validate token
token = self._extract_token(request)
if token:
try:
# Verify token and add user info to request state
user_data = await self._verify_token(token)
request.state.user = user_data
request.state.authenticated = True
logger.debug(f"Authenticated user: {user_data.get('id')}")
except Exception as e:
logger.warning(f"Token validation failed: {e}")
# For protected paths, return 401
if self._is_protected_path(request.url.path):
return JSONResponse(
status_code=401,
content={
"error": {
"code": 401,
"message": "Invalid or expired token",
"type": "authentication_error"
}
}
)
# For other paths, continue without authentication
request.state.user = None
request.state.authenticated = False
else:
# No token provided
if self._is_protected_path(request.url.path):
return JSONResponse(
status_code=401,
content={
"error": {
"code": 401,
"message": "Authentication required",
"type": "authentication_error"
}
},
headers={"WWW-Authenticate": "Bearer"}
)
request.state.user = None
request.state.authenticated = False
# Continue with request processing
response = await call_next(request)
# Add authentication headers to response
if hasattr(request.state, 'user') and request.state.user:
response.headers["X-User-ID"] = request.state.user.get("id", "")
response.headers["X-Authenticated"] = "true"
else:
response.headers["X-Authenticated"] = "false"
return response
def _is_public_path(self, path: str) -> bool:
"""Check if path is public (doesn't require authentication)."""
# Exact match
if path in self.public_paths:
return True
# Pattern matching for public paths
public_patterns = [
"/health",
"/metrics",
"/api/v1/pose/current", # Allow anonymous access to current pose data
"/api/v1/pose/zones/", # Allow anonymous access to zone data
"/api/v1/pose/activities", # Allow anonymous access to activities
"/api/v1/pose/stats", # Allow anonymous access to stats
"/api/v1/stream/status" # Allow anonymous access to stream status
]
for pattern in public_patterns:
if path.startswith(pattern):
return True
return False
def _is_protected_path(self, path: str) -> bool:
"""Check if path requires authentication."""
# Exact match
if path in self.protected_paths:
return True
# Pattern matching for protected paths
protected_patterns = [
"/api/v1/pose/analyze",
"/api/v1/pose/calibrate",
"/api/v1/pose/historical",
"/api/v1/stream/start",
"/api/v1/stream/stop",
"/api/v1/stream/clients",
"/api/v1/stream/broadcast"
]
for pattern in protected_patterns:
if path.startswith(pattern):
return True
return False
def _extract_token(self, request: Request) -> Optional[str]:
"""Extract JWT token from request."""
# Check Authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
# Check query parameter (for WebSocket connections)
token = request.query_params.get("token")
if token:
return token
# Check cookie
token = request.cookies.get("access_token")
if token:
return token
return None
async def _verify_token(self, token: str) -> Dict[str, Any]:
"""Verify JWT token and return user data."""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.settings.secret_key,
algorithms=[self.settings.jwt_algorithm]
)
# Extract user information
user_id = payload.get("sub")
if not user_id:
raise ValueError("Token missing user ID")
# Check token expiration
exp = payload.get("exp")
if exp and datetime.utcnow() > datetime.fromtimestamp(exp):
raise ValueError("Token expired")
# Build user object
user_data = {
"id": user_id,
"username": payload.get("username"),
"email": payload.get("email"),
"is_admin": payload.get("is_admin", False),
"permissions": payload.get("permissions", []),
"accessible_zones": payload.get("accessible_zones", []),
"token_issued_at": payload.get("iat"),
"token_expires_at": payload.get("exp"),
"session_id": payload.get("session_id")
}
return user_data
except JWTError as e:
raise ValueError(f"JWT validation failed: {e}")
except Exception as e:
raise ValueError(f"Token verification error: {e}")
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
"""Log authentication events for security monitoring."""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
log_data = {
"event_type": event_type,
"timestamp": datetime.utcnow().isoformat(),
"client_ip": client_ip,
"user_agent": user_agent,
"path": request.url.path,
"method": request.method
}
if details:
log_data.update(details)
if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
logger.warning(f"Auth event: {log_data}")
else:
logger.info(f"Auth event: {log_data}")
class TokenBlacklist:
"""Simple in-memory token blacklist for logout functionality."""
def __init__(self):
self._blacklisted_tokens = set()
self._cleanup_interval = 3600 # 1 hour
self._last_cleanup = datetime.utcnow()
def add_token(self, token: str):
"""Add token to blacklist."""
self._blacklisted_tokens.add(token)
self._cleanup_if_needed()
def is_blacklisted(self, token: str) -> bool:
"""Check if token is blacklisted."""
self._cleanup_if_needed()
return token in self._blacklisted_tokens
def _cleanup_if_needed(self):
"""Clean up expired tokens from blacklist."""
now = datetime.utcnow()
if (now - self._last_cleanup).total_seconds() > self._cleanup_interval:
# In a real implementation, you would check token expiration
# For now, we'll just clear old tokens periodically
self._blacklisted_tokens.clear()
self._last_cleanup = now
# Global token blacklist instance
token_blacklist = TokenBlacklist()
class SecurityHeaders:
"""Security headers for API responses."""
@staticmethod
def add_security_headers(response: Response) -> Response:
"""Add security headers to response."""
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data:; "
"connect-src 'self' ws: wss:;"
)
return response
class APIKeyAuth:
"""Alternative API key authentication for service-to-service communication."""
def __init__(self, api_keys: Dict[str, Dict[str, Any]] = None):
self.api_keys = api_keys or {}
def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Verify API key and return associated service info."""
if api_key in self.api_keys:
return self.api_keys[api_key]
return None
def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
"""Add new API key."""
self.api_keys[api_key] = service_info
def revoke_api_key(self, api_key: str):
"""Revoke API key."""
if api_key in self.api_keys:
del self.api_keys[api_key]
# Global API key auth instance
api_key_auth = APIKeyAuth()

View File

@@ -0,0 +1,429 @@
"""
Rate limiting middleware for WiFi-DensePose API
"""
import logging
import time
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from collections import defaultdict, deque
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware with sliding window algorithm."""
def __init__(self, app):
super().__init__(app)
self.settings = get_settings()
# Rate limit storage (in production, use Redis)
self.request_counts = defaultdict(lambda: deque())
self.blocked_clients = {}
# Rate limit configurations
self.rate_limits = {
"anonymous": {
"requests": self.settings.rate_limit_requests,
"window": self.settings.rate_limit_window,
"burst": 10 # Allow burst of 10 requests
},
"authenticated": {
"requests": self.settings.rate_limit_authenticated_requests,
"window": self.settings.rate_limit_window,
"burst": 50
},
"admin": {
"requests": 10000, # Very high limit for admins
"window": self.settings.rate_limit_window,
"burst": 100
}
}
# Path-specific rate limits
self.path_limits = {
"/api/v1/pose/current": {"requests": 60, "window": 60}, # 1 per second
"/api/v1/pose/analyze": {"requests": 10, "window": 60}, # 10 per minute
"/api/v1/pose/calibrate": {"requests": 1, "window": 300}, # 1 per 5 minutes
"/api/v1/stream/start": {"requests": 5, "window": 60}, # 5 per minute
"/api/v1/stream/stop": {"requests": 5, "window": 60}, # 5 per minute
}
# Exempt paths from rate limiting
self.exempt_paths = {
"/health",
"/ready",
"/live",
"/version",
"/metrics"
}
async def dispatch(self, request: Request, call_next):
"""Process request through rate limiting middleware."""
# Skip rate limiting for exempt paths
if self._is_exempt_path(request.url.path):
return await call_next(request)
# Get client identifier
client_id = self._get_client_id(request)
# Check if client is temporarily blocked
if self._is_client_blocked(client_id):
return self._create_rate_limit_response(
"Client temporarily blocked due to excessive requests"
)
# Get user type for rate limiting
user_type = self._get_user_type(request)
# Check rate limits
rate_limit_result = self._check_rate_limits(
client_id,
request.url.path,
user_type
)
if not rate_limit_result["allowed"]:
# Log rate limit violation
self._log_rate_limit_violation(request, client_id, rate_limit_result)
# Check if client should be temporarily blocked
if rate_limit_result.get("violations", 0) > 5:
self._block_client(client_id, duration=300) # 5 minutes
return self._create_rate_limit_response(
rate_limit_result["message"],
retry_after=rate_limit_result.get("retry_after", 60)
)
# Record the request
self._record_request(client_id, request.url.path)
# Process request
response = await call_next(request)
# Add rate limit headers
self._add_rate_limit_headers(response, client_id, user_type)
return response
def _is_exempt_path(self, path: str) -> bool:
"""Check if path is exempt from rate limiting."""
return path in self.exempt_paths
def _get_client_id(self, request: Request) -> str:
"""Get unique client identifier for rate limiting."""
# Try to get user ID from request state (set by auth middleware)
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user['id']}"
# Fall back to IP address
client_ip = request.client.host if request.client else "unknown"
# Include user agent for better identification
user_agent = request.headers.get("user-agent", "")
user_agent_hash = str(hash(user_agent))[:8]
return f"ip:{client_ip}:{user_agent_hash}"
def _get_user_type(self, request: Request) -> str:
"""Determine user type for rate limiting."""
if hasattr(request.state, 'user') and request.state.user:
if request.state.user.get("is_admin", False):
return "admin"
return "authenticated"
return "anonymous"
def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
"""Check if request is within rate limits."""
now = time.time()
# Get applicable rate limits
general_limit = self.rate_limits[user_type]
path_limit = self.path_limits.get(path)
# Check general rate limit
general_result = self._check_limit(
client_id,
"general",
general_limit["requests"],
general_limit["window"],
now
)
if not general_result["allowed"]:
return general_result
# Check path-specific rate limit if exists
if path_limit:
path_result = self._check_limit(
client_id,
f"path:{path}",
path_limit["requests"],
path_limit["window"],
now
)
if not path_result["allowed"]:
return path_result
return {"allowed": True}
def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
"""Check specific rate limit using sliding window."""
key = f"{client_id}:{limit_type}"
requests = self.request_counts[key]
# Remove old requests outside the window
cutoff = now - window
while requests and requests[0] <= cutoff:
requests.popleft()
# Check if limit exceeded
if len(requests) >= max_requests:
# Calculate retry after time
oldest_request = requests[0] if requests else now
retry_after = int(oldest_request + window - now) + 1
return {
"allowed": False,
"message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
"retry_after": retry_after,
"current_count": len(requests),
"limit": max_requests,
"window": window
}
return {
"allowed": True,
"current_count": len(requests),
"limit": max_requests,
"window": window
}
def _record_request(self, client_id: str, path: str):
"""Record a request for rate limiting."""
now = time.time()
# Record general request
general_key = f"{client_id}:general"
self.request_counts[general_key].append(now)
# Record path-specific request if path has specific limits
if path in self.path_limits:
path_key = f"{client_id}:path:{path}"
self.request_counts[path_key].append(now)
def _is_client_blocked(self, client_id: str) -> bool:
"""Check if client is temporarily blocked."""
if client_id in self.blocked_clients:
block_until = self.blocked_clients[client_id]
if time.time() < block_until:
return True
else:
# Block expired, remove it
del self.blocked_clients[client_id]
return False
def _block_client(self, client_id: str, duration: int):
"""Temporarily block a client."""
self.blocked_clients[client_id] = time.time() + duration
logger.warning(f"Client {client_id} blocked for {duration} seconds due to rate limit violations")
def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
"""Create rate limit exceeded response."""
return JSONResponse(
status_code=429,
content={
"error": {
"code": 429,
"message": message,
"type": "rate_limit_exceeded"
}
},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": "Exceeded",
"X-RateLimit-Remaining": "0"
}
)
def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
"""Add rate limit headers to response."""
try:
general_limit = self.rate_limits[user_type]
general_key = f"{client_id}:general"
current_requests = len(self.request_counts[general_key])
remaining = max(0, general_limit["requests"] - current_requests)
response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Window"] = str(general_limit["window"])
# Add reset time
if self.request_counts[general_key]:
oldest_request = self.request_counts[general_key][0]
reset_time = int(oldest_request + general_limit["window"])
response.headers["X-RateLimit-Reset"] = str(reset_time)
except Exception as e:
logger.error(f"Error adding rate limit headers: {e}")
def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
"""Log rate limit violations for monitoring."""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
log_data = {
"event_type": "rate_limit_violation",
"timestamp": datetime.utcnow().isoformat(),
"client_id": client_id,
"client_ip": client_ip,
"user_agent": user_agent,
"path": request.url.path,
"method": request.method,
"current_count": result.get("current_count"),
"limit": result.get("limit"),
"window": result.get("window")
}
logger.warning(f"Rate limit violation: {log_data}")
def cleanup_old_data(self):
"""Clean up old rate limiting data (call periodically)."""
now = time.time()
cutoff = now - 3600 # Keep data for 1 hour
# Clean up request counts
for key in list(self.request_counts.keys()):
requests = self.request_counts[key]
while requests and requests[0] <= cutoff:
requests.popleft()
# Remove empty deques
if not requests:
del self.request_counts[key]
# Clean up expired blocks
expired_blocks = [
client_id for client_id, block_until in self.blocked_clients.items()
if now >= block_until
]
for client_id in expired_blocks:
del self.blocked_clients[client_id]
class AdaptiveRateLimit:
"""Adaptive rate limiting based on system load."""
def __init__(self):
self.base_limits = {}
self.current_multiplier = 1.0
self.load_history = deque(maxlen=60) # Keep 1 minute of load data
def update_system_load(self, cpu_percent: float, memory_percent: float):
"""Update system load metrics."""
load_score = (cpu_percent + memory_percent) / 2
self.load_history.append(load_score)
# Calculate adaptive multiplier
if len(self.load_history) >= 10:
avg_load = sum(self.load_history) / len(self.load_history)
if avg_load > 80:
self.current_multiplier = 0.5 # Reduce limits by 50%
elif avg_load > 60:
self.current_multiplier = 0.7 # Reduce limits by 30%
elif avg_load < 30:
self.current_multiplier = 1.2 # Increase limits by 20%
else:
self.current_multiplier = 1.0 # Normal limits
def get_adjusted_limit(self, base_limit: int) -> int:
"""Get adjusted rate limit based on system load."""
return max(1, int(base_limit * self.current_multiplier))
class RateLimitStorage:
"""Abstract interface for rate limit storage (Redis implementation)."""
async def get_count(self, key: str, window: int) -> int:
"""Get current request count for key within window."""
raise NotImplementedError
async def increment(self, key: str, window: int) -> int:
"""Increment request count and return new count."""
raise NotImplementedError
async def is_blocked(self, client_id: str) -> bool:
"""Check if client is blocked."""
raise NotImplementedError
async def block_client(self, client_id: str, duration: int):
"""Block client for duration seconds."""
raise NotImplementedError
class RedisRateLimitStorage(RateLimitStorage):
"""Redis-based rate limit storage for production use."""
def __init__(self, redis_client):
self.redis = redis_client
async def get_count(self, key: str, window: int) -> int:
"""Get current request count using Redis sliding window."""
now = time.time()
pipeline = self.redis.pipeline()
# Remove old entries
pipeline.zremrangebyscore(key, 0, now - window)
# Count current entries
pipeline.zcard(key)
results = await pipeline.execute()
return results[1]
async def increment(self, key: str, window: int) -> int:
"""Increment request count using Redis."""
now = time.time()
pipeline = self.redis.pipeline()
# Add current request
pipeline.zadd(key, {str(now): now})
# Remove old entries
pipeline.zremrangebyscore(key, 0, now - window)
# Set expiration
pipeline.expire(key, window + 1)
# Get count
pipeline.zcard(key)
results = await pipeline.execute()
return results[3]
async def is_blocked(self, client_id: str) -> bool:
"""Check if client is blocked."""
block_key = f"blocked:{client_id}"
return await self.redis.exists(block_key)
async def block_client(self, client_id: str, duration: int):
"""Block client for duration seconds."""
block_key = f"blocked:{client_id}"
await self.redis.setex(block_key, duration, "1")
# Global adaptive rate limiter instance
adaptive_rate_limit = AdaptiveRateLimit()

View File

@@ -0,0 +1,7 @@
"""
API routers package
"""
from . import pose, stream, health
__all__ = ["pose", "stream", "health"]

View File

@@ -0,0 +1,419 @@
"""
Health check API endpoints
"""
import logging
import psutil
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from src.api.dependencies import get_current_user
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Response models
class ComponentHealth(BaseModel):
"""Health status for a system component."""
name: str = Field(..., description="Component name")
status: str = Field(..., description="Health status (healthy, degraded, unhealthy)")
message: Optional[str] = Field(default=None, description="Status message")
last_check: datetime = Field(..., description="Last health check timestamp")
uptime_seconds: Optional[float] = Field(default=None, description="Component uptime")
metrics: Optional[Dict[str, Any]] = Field(default=None, description="Component metrics")
class SystemHealth(BaseModel):
"""Overall system health status."""
status: str = Field(..., description="Overall system status")
timestamp: datetime = Field(..., description="Health check timestamp")
uptime_seconds: float = Field(..., description="System uptime")
components: Dict[str, ComponentHealth] = Field(..., description="Component health status")
system_metrics: Dict[str, Any] = Field(..., description="System-level metrics")
class ReadinessCheck(BaseModel):
"""System readiness check result."""
ready: bool = Field(..., description="Whether system is ready to serve requests")
timestamp: datetime = Field(..., description="Readiness check timestamp")
checks: Dict[str, bool] = Field(..., description="Individual readiness checks")
message: str = Field(..., description="Readiness status message")
# Health check endpoints
@router.get("/health", response_model=SystemHealth)
async def health_check(request: Request):
"""Comprehensive system health check."""
try:
# Get services from app state
hardware_service = getattr(request.app.state, 'hardware_service', None)
pose_service = getattr(request.app.state, 'pose_service', None)
stream_service = getattr(request.app.state, 'stream_service', None)
timestamp = datetime.utcnow()
components = {}
overall_status = "healthy"
# Check hardware service
if hardware_service:
try:
hw_health = await hardware_service.health_check()
components["hardware"] = ComponentHealth(
name="Hardware Service",
status=hw_health["status"],
message=hw_health.get("message"),
last_check=timestamp,
uptime_seconds=hw_health.get("uptime_seconds"),
metrics=hw_health.get("metrics")
)
if hw_health["status"] != "healthy":
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
except Exception as e:
logger.error(f"Hardware service health check failed: {e}")
components["hardware"] = ComponentHealth(
name="Hardware Service",
status="unhealthy",
message=f"Health check failed: {str(e)}",
last_check=timestamp
)
overall_status = "unhealthy"
else:
components["hardware"] = ComponentHealth(
name="Hardware Service",
status="unavailable",
message="Service not initialized",
last_check=timestamp
)
overall_status = "degraded"
# Check pose service
if pose_service:
try:
pose_health = await pose_service.health_check()
components["pose"] = ComponentHealth(
name="Pose Service",
status=pose_health["status"],
message=pose_health.get("message"),
last_check=timestamp,
uptime_seconds=pose_health.get("uptime_seconds"),
metrics=pose_health.get("metrics")
)
if pose_health["status"] != "healthy":
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
except Exception as e:
logger.error(f"Pose service health check failed: {e}")
components["pose"] = ComponentHealth(
name="Pose Service",
status="unhealthy",
message=f"Health check failed: {str(e)}",
last_check=timestamp
)
overall_status = "unhealthy"
else:
components["pose"] = ComponentHealth(
name="Pose Service",
status="unavailable",
message="Service not initialized",
last_check=timestamp
)
overall_status = "degraded"
# Check stream service
if stream_service:
try:
stream_health = await stream_service.health_check()
components["stream"] = ComponentHealth(
name="Stream Service",
status=stream_health["status"],
message=stream_health.get("message"),
last_check=timestamp,
uptime_seconds=stream_health.get("uptime_seconds"),
metrics=stream_health.get("metrics")
)
if stream_health["status"] != "healthy":
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
except Exception as e:
logger.error(f"Stream service health check failed: {e}")
components["stream"] = ComponentHealth(
name="Stream Service",
status="unhealthy",
message=f"Health check failed: {str(e)}",
last_check=timestamp
)
overall_status = "unhealthy"
else:
components["stream"] = ComponentHealth(
name="Stream Service",
status="unavailable",
message="Service not initialized",
last_check=timestamp
)
overall_status = "degraded"
# Get system metrics
system_metrics = get_system_metrics()
# Calculate system uptime (placeholder - would need actual startup time)
uptime_seconds = 0.0 # TODO: Implement actual uptime tracking
return SystemHealth(
status=overall_status,
timestamp=timestamp,
uptime_seconds=uptime_seconds,
components=components,
system_metrics=system_metrics
)
except Exception as e:
logger.error(f"Health check failed: {e}")
raise HTTPException(
status_code=500,
detail=f"Health check failed: {str(e)}"
)
@router.get("/ready", response_model=ReadinessCheck)
async def readiness_check(request: Request):
"""Check if system is ready to serve requests."""
try:
timestamp = datetime.utcnow()
checks = {}
# Check if services are available in app state
if hasattr(request.app.state, 'pose_service') and request.app.state.pose_service:
try:
checks["pose_ready"] = await request.app.state.pose_service.is_ready()
except Exception as e:
logger.warning(f"Pose service readiness check failed: {e}")
checks["pose_ready"] = False
else:
checks["pose_ready"] = False
if hasattr(request.app.state, 'stream_service') and request.app.state.stream_service:
try:
checks["stream_ready"] = await request.app.state.stream_service.is_ready()
except Exception as e:
logger.warning(f"Stream service readiness check failed: {e}")
checks["stream_ready"] = False
else:
checks["stream_ready"] = False
# Hardware service check (basic availability)
checks["hardware_ready"] = True # Basic readiness - API is responding
# Check system resources
checks["memory_available"] = check_memory_availability()
checks["disk_space_available"] = check_disk_space()
# Application is ready if at least the basic services are available
# For now, we'll consider it ready if the API is responding
ready = True # Basic readiness
message = "System is ready" if ready else "System is not ready"
if not ready:
failed_checks = [name for name, status in checks.items() if not status]
message += f". Failed checks: {', '.join(failed_checks)}"
return ReadinessCheck(
ready=ready,
timestamp=timestamp,
checks=checks,
message=message
)
except Exception as e:
logger.error(f"Readiness check failed: {e}")
return ReadinessCheck(
ready=False,
timestamp=datetime.utcnow(),
checks={},
message=f"Readiness check failed: {str(e)}"
)
@router.get("/live")
async def liveness_check():
"""Simple liveness check for load balancers."""
return {
"status": "alive",
"timestamp": datetime.utcnow().isoformat()
}
@router.get("/metrics")
async def get_health_metrics(
request: Request,
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get detailed system metrics."""
try:
metrics = get_system_metrics()
# Add additional metrics if authenticated
if current_user:
metrics.update(get_detailed_metrics())
return {
"timestamp": datetime.utcnow().isoformat(),
"metrics": metrics
}
except Exception as e:
logger.error(f"Error getting system metrics: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get system metrics: {str(e)}"
)
@router.get("/version")
async def get_version_info():
"""Get application version information."""
settings = get_settings()
return {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"debug": settings.debug,
"timestamp": datetime.utcnow().isoformat()
}
def get_system_metrics() -> Dict[str, Any]:
"""Get basic system metrics."""
try:
# CPU metrics
cpu_percent = psutil.cpu_percent(interval=1)
cpu_count = psutil.cpu_count()
# Memory metrics
memory = psutil.virtual_memory()
memory_metrics = {
"total_gb": round(memory.total / (1024**3), 2),
"available_gb": round(memory.available / (1024**3), 2),
"used_gb": round(memory.used / (1024**3), 2),
"percent": memory.percent
}
# Disk metrics
disk = psutil.disk_usage('/')
disk_metrics = {
"total_gb": round(disk.total / (1024**3), 2),
"free_gb": round(disk.free / (1024**3), 2),
"used_gb": round(disk.used / (1024**3), 2),
"percent": round((disk.used / disk.total) * 100, 2)
}
# Network metrics (basic)
network = psutil.net_io_counters()
network_metrics = {
"bytes_sent": network.bytes_sent,
"bytes_recv": network.bytes_recv,
"packets_sent": network.packets_sent,
"packets_recv": network.packets_recv
}
return {
"cpu": {
"percent": cpu_percent,
"count": cpu_count
},
"memory": memory_metrics,
"disk": disk_metrics,
"network": network_metrics
}
except Exception as e:
logger.error(f"Error getting system metrics: {e}")
return {}
def get_detailed_metrics() -> Dict[str, Any]:
"""Get detailed system metrics (requires authentication)."""
try:
# Process metrics
process = psutil.Process()
process_metrics = {
"pid": process.pid,
"cpu_percent": process.cpu_percent(),
"memory_mb": round(process.memory_info().rss / (1024**2), 2),
"num_threads": process.num_threads(),
"create_time": datetime.fromtimestamp(process.create_time()).isoformat()
}
# Load average (Unix-like systems)
load_avg = None
try:
load_avg = psutil.getloadavg()
except AttributeError:
# Windows doesn't have load average
pass
# Temperature sensors (if available)
temperatures = {}
try:
temps = psutil.sensors_temperatures()
for name, entries in temps.items():
temperatures[name] = [
{"label": entry.label, "current": entry.current}
for entry in entries
]
except AttributeError:
# Not available on all systems
pass
detailed = {
"process": process_metrics
}
if load_avg:
detailed["load_average"] = {
"1min": load_avg[0],
"5min": load_avg[1],
"15min": load_avg[2]
}
if temperatures:
detailed["temperatures"] = temperatures
return detailed
except Exception as e:
logger.error(f"Error getting detailed metrics: {e}")
return {}
def check_memory_availability() -> bool:
"""Check if sufficient memory is available."""
try:
memory = psutil.virtual_memory()
# Consider system ready if less than 90% memory is used
return memory.percent < 90.0
except Exception:
return False
def check_disk_space() -> bool:
"""Check if sufficient disk space is available."""
try:
disk = psutil.disk_usage('/')
# Consider system ready if more than 1GB free space
free_gb = disk.free / (1024**3)
return free_gb > 1.0
except Exception:
return False

420
v1/src/api/routers/pose.py Normal file
View File

@@ -0,0 +1,420 @@
"""
Pose estimation API endpoints
"""
import logging
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from src.api.dependencies import (
get_pose_service,
get_hardware_service,
get_current_user,
require_auth
)
from src.services.pose_service import PoseService
from src.services.hardware_service import HardwareService
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Request/Response models
class PoseEstimationRequest(BaseModel):
"""Request model for pose estimation."""
zone_ids: Optional[List[str]] = Field(
default=None,
description="Specific zones to analyze (all zones if not specified)"
)
confidence_threshold: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="Minimum confidence threshold for detections"
)
max_persons: Optional[int] = Field(
default=None,
ge=1,
le=50,
description="Maximum number of persons to detect"
)
include_keypoints: bool = Field(
default=True,
description="Include detailed keypoint data"
)
include_segmentation: bool = Field(
default=False,
description="Include DensePose segmentation masks"
)
class PersonPose(BaseModel):
"""Person pose data model."""
person_id: str = Field(..., description="Unique person identifier")
confidence: float = Field(..., description="Detection confidence score")
bounding_box: Dict[str, float] = Field(..., description="Person bounding box")
keypoints: Optional[List[Dict[str, Any]]] = Field(
default=None,
description="Body keypoints with coordinates and confidence"
)
segmentation: Optional[Dict[str, Any]] = Field(
default=None,
description="DensePose segmentation data"
)
zone_id: Optional[str] = Field(
default=None,
description="Zone where person is detected"
)
activity: Optional[str] = Field(
default=None,
description="Detected activity"
)
timestamp: datetime = Field(..., description="Detection timestamp")
class PoseEstimationResponse(BaseModel):
"""Response model for pose estimation."""
timestamp: datetime = Field(..., description="Analysis timestamp")
frame_id: str = Field(..., description="Unique frame identifier")
persons: List[PersonPose] = Field(..., description="Detected persons")
zone_summary: Dict[str, int] = Field(..., description="Person count per zone")
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
class HistoricalDataRequest(BaseModel):
"""Request model for historical pose data."""
start_time: datetime = Field(..., description="Start time for data query")
end_time: datetime = Field(..., description="End time for data query")
zone_ids: Optional[List[str]] = Field(
default=None,
description="Filter by specific zones"
)
aggregation_interval: Optional[int] = Field(
default=300,
ge=60,
le=3600,
description="Aggregation interval in seconds"
)
include_raw_data: bool = Field(
default=False,
description="Include raw detection data"
)
# Endpoints
@router.get("/current", response_model=PoseEstimationResponse)
async def get_current_pose_estimation(
request: PoseEstimationRequest = Depends(),
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get current pose estimation from WiFi signals."""
try:
logger.info(f"Processing pose estimation request from user: {current_user.get('id') if current_user else 'anonymous'}")
# Get current pose estimation
result = await pose_service.estimate_poses(
zone_ids=request.zone_ids,
confidence_threshold=request.confidence_threshold,
max_persons=request.max_persons,
include_keypoints=request.include_keypoints,
include_segmentation=request.include_segmentation
)
return PoseEstimationResponse(**result)
except Exception as e:
logger.error(f"Error in pose estimation: {e}")
raise HTTPException(
status_code=500,
detail=f"Pose estimation failed: {str(e)}"
)
@router.post("/analyze", response_model=PoseEstimationResponse)
async def analyze_pose_data(
request: PoseEstimationRequest,
background_tasks: BackgroundTasks,
pose_service: PoseService = Depends(get_pose_service),
current_user: Dict = Depends(require_auth)
):
"""Trigger pose analysis with custom parameters."""
try:
logger.info(f"Custom pose analysis requested by user: {current_user['id']}")
# Trigger analysis
result = await pose_service.analyze_with_params(
zone_ids=request.zone_ids,
confidence_threshold=request.confidence_threshold,
max_persons=request.max_persons,
include_keypoints=request.include_keypoints,
include_segmentation=request.include_segmentation
)
# Schedule background processing if needed
if request.include_segmentation:
background_tasks.add_task(
pose_service.process_segmentation_data,
result["frame_id"]
)
return PoseEstimationResponse(**result)
except Exception as e:
logger.error(f"Error in pose analysis: {e}")
raise HTTPException(
status_code=500,
detail=f"Pose analysis failed: {str(e)}"
)
@router.get("/zones/{zone_id}/occupancy")
async def get_zone_occupancy(
zone_id: str,
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get current occupancy for a specific zone."""
try:
occupancy = await pose_service.get_zone_occupancy(zone_id)
if occupancy is None:
raise HTTPException(
status_code=404,
detail=f"Zone '{zone_id}' not found"
)
return {
"zone_id": zone_id,
"current_occupancy": occupancy["count"],
"max_occupancy": occupancy.get("max_occupancy"),
"persons": occupancy["persons"],
"timestamp": occupancy["timestamp"]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting zone occupancy: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get zone occupancy: {str(e)}"
)
@router.get("/zones/summary")
async def get_zones_summary(
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get occupancy summary for all zones."""
try:
summary = await pose_service.get_zones_summary()
return {
"timestamp": datetime.utcnow(),
"total_persons": summary["total_persons"],
"zones": summary["zones"],
"active_zones": summary["active_zones"]
}
except Exception as e:
logger.error(f"Error getting zones summary: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get zones summary: {str(e)}"
)
@router.post("/historical")
async def get_historical_data(
request: HistoricalDataRequest,
pose_service: PoseService = Depends(get_pose_service),
current_user: Dict = Depends(require_auth)
):
"""Get historical pose estimation data."""
try:
# Validate time range
if request.end_time <= request.start_time:
raise HTTPException(
status_code=400,
detail="End time must be after start time"
)
# Limit query range to prevent excessive data
max_range = timedelta(days=7)
if request.end_time - request.start_time > max_range:
raise HTTPException(
status_code=400,
detail="Query range cannot exceed 7 days"
)
data = await pose_service.get_historical_data(
start_time=request.start_time,
end_time=request.end_time,
zone_ids=request.zone_ids,
aggregation_interval=request.aggregation_interval,
include_raw_data=request.include_raw_data
)
return {
"query": {
"start_time": request.start_time,
"end_time": request.end_time,
"zone_ids": request.zone_ids,
"aggregation_interval": request.aggregation_interval
},
"data": data["aggregated_data"],
"raw_data": data.get("raw_data") if request.include_raw_data else None,
"total_records": data["total_records"]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting historical data: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get historical data: {str(e)}"
)
@router.get("/activities")
async def get_detected_activities(
zone_id: Optional[str] = Query(None, description="Filter by zone ID"),
limit: int = Query(10, ge=1, le=100, description="Maximum number of activities"),
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get recently detected activities."""
try:
activities = await pose_service.get_recent_activities(
zone_id=zone_id,
limit=limit
)
return {
"activities": activities,
"total_count": len(activities),
"zone_id": zone_id
}
except Exception as e:
logger.error(f"Error getting activities: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get activities: {str(e)}"
)
@router.post("/calibrate")
async def calibrate_pose_system(
background_tasks: BackgroundTasks,
pose_service: PoseService = Depends(get_pose_service),
hardware_service: HardwareService = Depends(get_hardware_service),
current_user: Dict = Depends(require_auth)
):
"""Calibrate the pose estimation system."""
try:
logger.info(f"Pose system calibration initiated by user: {current_user['id']}")
# Check if calibration is already in progress
if await pose_service.is_calibrating():
raise HTTPException(
status_code=409,
detail="Calibration already in progress"
)
# Start calibration process
calibration_id = await pose_service.start_calibration()
# Schedule background calibration task
background_tasks.add_task(
pose_service.run_calibration,
calibration_id
)
return {
"calibration_id": calibration_id,
"status": "started",
"estimated_duration_minutes": 5,
"message": "Calibration process started"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error starting calibration: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to start calibration: {str(e)}"
)
@router.get("/calibration/status")
async def get_calibration_status(
pose_service: PoseService = Depends(get_pose_service),
current_user: Dict = Depends(require_auth)
):
"""Get current calibration status."""
try:
status = await pose_service.get_calibration_status()
return {
"is_calibrating": status["is_calibrating"],
"calibration_id": status.get("calibration_id"),
"progress_percent": status.get("progress_percent", 0),
"current_step": status.get("current_step"),
"estimated_remaining_minutes": status.get("estimated_remaining_minutes"),
"last_calibration": status.get("last_calibration")
}
except Exception as e:
logger.error(f"Error getting calibration status: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get calibration status: {str(e)}"
)
@router.get("/stats")
async def get_pose_statistics(
hours: int = Query(24, ge=1, le=168, description="Hours of data to analyze"),
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get pose estimation statistics."""
try:
end_time = datetime.utcnow()
start_time = end_time - timedelta(hours=hours)
stats = await pose_service.get_statistics(
start_time=start_time,
end_time=end_time
)
return {
"period": {
"start_time": start_time,
"end_time": end_time,
"hours": hours
},
"statistics": stats
}
except Exception as e:
logger.error(f"Error getting statistics: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get statistics: {str(e)}"
)

View File

@@ -0,0 +1,468 @@
"""
WebSocket streaming API endpoints
"""
import json
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from src.api.dependencies import (
get_stream_service,
get_pose_service,
get_current_user_ws,
require_auth
)
from src.api.websocket.connection_manager import ConnectionManager
from src.services.stream_service import StreamService
from src.services.pose_service import PoseService
logger = logging.getLogger(__name__)
router = APIRouter()
# Initialize connection manager
connection_manager = ConnectionManager()
# Request/Response models
class StreamSubscriptionRequest(BaseModel):
"""Request model for stream subscription."""
zone_ids: Optional[List[str]] = Field(
default=None,
description="Zones to subscribe to (all zones if not specified)"
)
stream_types: List[str] = Field(
default=["pose_data"],
description="Types of data to stream"
)
min_confidence: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Minimum confidence threshold for streaming"
)
max_fps: int = Field(
default=30,
ge=1,
le=60,
description="Maximum frames per second"
)
include_metadata: bool = Field(
default=True,
description="Include metadata in stream"
)
class StreamStatus(BaseModel):
"""Stream status model."""
is_active: bool = Field(..., description="Whether streaming is active")
connected_clients: int = Field(..., description="Number of connected clients")
streams: List[Dict[str, Any]] = Field(..., description="Active streams")
uptime_seconds: float = Field(..., description="Stream uptime in seconds")
# WebSocket endpoints
@router.websocket("/pose")
async def websocket_pose_stream(
websocket: WebSocket,
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
max_fps: int = Query(30, ge=1, le=60),
token: Optional[str] = Query(None, description="Authentication token")
):
"""WebSocket endpoint for real-time pose data streaming."""
client_id = None
try:
# Accept WebSocket connection
await websocket.accept()
# Check authentication if enabled
from src.config.settings import get_settings
settings = get_settings()
if settings.enable_authentication and not token:
await websocket.send_json({
"type": "error",
"message": "Authentication token required"
})
await websocket.close(code=1008)
return
# Parse zone IDs
zone_list = None
if zone_ids:
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
# Register client with connection manager
client_id = await connection_manager.connect(
websocket=websocket,
stream_type="pose",
zone_ids=zone_list,
min_confidence=min_confidence,
max_fps=max_fps
)
logger.info(f"WebSocket client {client_id} connected for pose streaming")
# Send initial connection confirmation
await websocket.send_json({
"type": "connection_established",
"client_id": client_id,
"timestamp": datetime.utcnow().isoformat(),
"config": {
"zone_ids": zone_list,
"min_confidence": min_confidence,
"max_fps": max_fps
}
})
# Keep connection alive and handle incoming messages
while True:
try:
# Wait for client messages (ping, config updates, etc.)
message = await websocket.receive_text()
data = json.loads(message)
await handle_websocket_message(client_id, data, websocket)
except WebSocketDisconnect:
break
except json.JSONDecodeError:
await websocket.send_json({
"type": "error",
"message": "Invalid JSON format"
})
except Exception as e:
logger.error(f"Error handling WebSocket message: {e}")
await websocket.send_json({
"type": "error",
"message": "Internal server error"
})
except WebSocketDisconnect:
logger.info(f"WebSocket client {client_id} disconnected")
except Exception as e:
logger.error(f"WebSocket error: {e}")
finally:
if client_id:
await connection_manager.disconnect(client_id)
@router.websocket("/events")
async def websocket_events_stream(
websocket: WebSocket,
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
token: Optional[str] = Query(None, description="Authentication token")
):
"""WebSocket endpoint for real-time event streaming."""
client_id = None
try:
await websocket.accept()
# Check authentication if enabled
from src.config.settings import get_settings
settings = get_settings()
if settings.enable_authentication and not token:
await websocket.send_json({
"type": "error",
"message": "Authentication token required"
})
await websocket.close(code=1008)
return
# Parse parameters
event_list = None
if event_types:
event_list = [event.strip() for event in event_types.split(",") if event.strip()]
zone_list = None
if zone_ids:
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
# Register client
client_id = await connection_manager.connect(
websocket=websocket,
stream_type="events",
zone_ids=zone_list,
event_types=event_list
)
logger.info(f"WebSocket client {client_id} connected for event streaming")
# Send confirmation
await websocket.send_json({
"type": "connection_established",
"client_id": client_id,
"timestamp": datetime.utcnow().isoformat(),
"config": {
"event_types": event_list,
"zone_ids": zone_list
}
})
# Handle messages
while True:
try:
message = await websocket.receive_text()
data = json.loads(message)
await handle_websocket_message(client_id, data, websocket)
except WebSocketDisconnect:
break
except Exception as e:
logger.error(f"Error in events WebSocket: {e}")
except WebSocketDisconnect:
logger.info(f"Events WebSocket client {client_id} disconnected")
except Exception as e:
logger.error(f"Events WebSocket error: {e}")
finally:
if client_id:
await connection_manager.disconnect(client_id)
async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket: WebSocket):
"""Handle incoming WebSocket messages."""
message_type = data.get("type")
if message_type == "ping":
await websocket.send_json({
"type": "pong",
"timestamp": datetime.utcnow().isoformat()
})
elif message_type == "update_config":
# Update client configuration
config = data.get("config", {})
await connection_manager.update_client_config(client_id, config)
await websocket.send_json({
"type": "config_updated",
"timestamp": datetime.utcnow().isoformat(),
"config": config
})
elif message_type == "get_status":
# Send current status
status = await connection_manager.get_client_status(client_id)
await websocket.send_json({
"type": "status",
"timestamp": datetime.utcnow().isoformat(),
"status": status
})
else:
await websocket.send_json({
"type": "error",
"message": f"Unknown message type: {message_type}"
})
# HTTP endpoints for stream management
@router.get("/status", response_model=StreamStatus)
async def get_stream_status(
stream_service: StreamService = Depends(get_stream_service)
):
"""Get current streaming status."""
try:
status = await stream_service.get_status()
connections = await connection_manager.get_connection_stats()
# Calculate uptime (simplified for now)
uptime_seconds = 0.0
if status.get("running", False):
uptime_seconds = 3600.0 # Default 1 hour for demo
return StreamStatus(
is_active=status.get("running", False),
connected_clients=connections.get("total_clients", status["connections"]["active"]),
streams=[{
"type": "pose_stream",
"active": status.get("running", False),
"buffer_size": status["buffers"]["pose_buffer_size"]
}],
uptime_seconds=uptime_seconds
)
except Exception as e:
logger.error(f"Error getting stream status: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get stream status: {str(e)}"
)
@router.post("/start")
async def start_streaming(
stream_service: StreamService = Depends(get_stream_service),
current_user: Dict = Depends(require_auth)
):
"""Start the streaming service."""
try:
logger.info(f"Starting streaming service by user: {current_user['id']}")
if await stream_service.is_active():
return JSONResponse(
status_code=200,
content={"message": "Streaming service is already active"}
)
await stream_service.start()
return {
"message": "Streaming service started successfully",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error starting streaming: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to start streaming: {str(e)}"
)
@router.post("/stop")
async def stop_streaming(
stream_service: StreamService = Depends(get_stream_service),
current_user: Dict = Depends(require_auth)
):
"""Stop the streaming service."""
try:
logger.info(f"Stopping streaming service by user: {current_user['id']}")
await stream_service.stop()
await connection_manager.disconnect_all()
return {
"message": "Streaming service stopped successfully",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error stopping streaming: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to stop streaming: {str(e)}"
)
@router.get("/clients")
async def get_connected_clients(
current_user: Dict = Depends(require_auth)
):
"""Get list of connected WebSocket clients."""
try:
clients = await connection_manager.get_connected_clients()
return {
"total_clients": len(clients),
"clients": clients,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting connected clients: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get connected clients: {str(e)}"
)
@router.delete("/clients/{client_id}")
async def disconnect_client(
client_id: str,
current_user: Dict = Depends(require_auth)
):
"""Disconnect a specific WebSocket client."""
try:
logger.info(f"Disconnecting client {client_id} by user: {current_user['id']}")
success = await connection_manager.disconnect(client_id)
if not success:
raise HTTPException(
status_code=404,
detail=f"Client {client_id} not found"
)
return {
"message": f"Client {client_id} disconnected successfully",
"timestamp": datetime.utcnow().isoformat()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error disconnecting client: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to disconnect client: {str(e)}"
)
@router.post("/broadcast")
async def broadcast_message(
message: Dict[str, Any],
stream_type: Optional[str] = Query(None, description="Target stream type"),
zone_ids: Optional[List[str]] = Query(None, description="Target zone IDs"),
current_user: Dict = Depends(require_auth)
):
"""Broadcast a message to connected WebSocket clients."""
try:
logger.info(f"Broadcasting message by user: {current_user['id']}")
# Add metadata to message
broadcast_data = {
**message,
"broadcast_timestamp": datetime.utcnow().isoformat(),
"sender": current_user["id"]
}
# Broadcast to matching clients
sent_count = await connection_manager.broadcast(
data=broadcast_data,
stream_type=stream_type,
zone_ids=zone_ids
)
return {
"message": "Broadcast sent successfully",
"recipients": sent_count,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error broadcasting message: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to broadcast message: {str(e)}"
)
@router.get("/metrics")
async def get_streaming_metrics():
"""Get streaming performance metrics."""
try:
metrics = await connection_manager.get_metrics()
return {
"metrics": metrics,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting streaming metrics: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get streaming metrics: {str(e)}"
)

View File

@@ -0,0 +1,8 @@
"""
WebSocket handlers package
"""
from .connection_manager import ConnectionManager
from .pose_stream import PoseStreamHandler
__all__ = ["ConnectionManager", "PoseStreamHandler"]

View File

@@ -0,0 +1,461 @@
"""
WebSocket connection manager for WiFi-DensePose API
"""
import asyncio
import json
import logging
import uuid
from typing import Dict, List, Optional, Any, Set
from datetime import datetime, timedelta
from collections import defaultdict
from fastapi import WebSocket, WebSocketDisconnect
logger = logging.getLogger(__name__)
class WebSocketConnection:
"""Represents a WebSocket connection with metadata."""
def __init__(
self,
websocket: WebSocket,
client_id: str,
stream_type: str,
zone_ids: Optional[List[str]] = None,
**config
):
self.websocket = websocket
self.client_id = client_id
self.stream_type = stream_type
self.zone_ids = zone_ids or []
self.config = config
self.connected_at = datetime.utcnow()
self.last_ping = datetime.utcnow()
self.message_count = 0
self.is_active = True
async def send_json(self, data: Dict[str, Any]):
"""Send JSON data to client."""
try:
await self.websocket.send_json(data)
self.message_count += 1
except Exception as e:
logger.error(f"Error sending to client {self.client_id}: {e}")
self.is_active = False
raise
async def send_text(self, message: str):
"""Send text message to client."""
try:
await self.websocket.send_text(message)
self.message_count += 1
except Exception as e:
logger.error(f"Error sending text to client {self.client_id}: {e}")
self.is_active = False
raise
def update_config(self, config: Dict[str, Any]):
"""Update connection configuration."""
self.config.update(config)
# Update zone IDs if provided
if "zone_ids" in config:
self.zone_ids = config["zone_ids"] or []
def matches_filter(
self,
stream_type: Optional[str] = None,
zone_ids: Optional[List[str]] = None,
**filters
) -> bool:
"""Check if connection matches given filters."""
# Check stream type
if stream_type and self.stream_type != stream_type:
return False
# Check zone IDs
if zone_ids:
if not self.zone_ids: # Connection listens to all zones
return True
# Check if any requested zone is in connection's zones
if not any(zone in self.zone_ids for zone in zone_ids):
return False
# Check additional filters
for key, value in filters.items():
if key in self.config and self.config[key] != value:
return False
return True
def get_info(self) -> Dict[str, Any]:
"""Get connection information."""
return {
"client_id": self.client_id,
"stream_type": self.stream_type,
"zone_ids": self.zone_ids,
"config": self.config,
"connected_at": self.connected_at.isoformat(),
"last_ping": self.last_ping.isoformat(),
"message_count": self.message_count,
"is_active": self.is_active,
"uptime_seconds": (datetime.utcnow() - self.connected_at).total_seconds()
}
class ConnectionManager:
"""Manages WebSocket connections for real-time streaming."""
def __init__(self):
self.connections: Dict[str, WebSocketConnection] = {}
self.connections_by_type: Dict[str, Set[str]] = defaultdict(set)
self.connections_by_zone: Dict[str, Set[str]] = defaultdict(set)
self.metrics = {
"total_connections": 0,
"active_connections": 0,
"messages_sent": 0,
"errors": 0,
"start_time": datetime.utcnow()
}
self._cleanup_task = None
self._started = False
async def connect(
self,
websocket: WebSocket,
stream_type: str,
zone_ids: Optional[List[str]] = None,
**config
) -> str:
"""Register a new WebSocket connection."""
client_id = str(uuid.uuid4())
try:
# Create connection object
connection = WebSocketConnection(
websocket=websocket,
client_id=client_id,
stream_type=stream_type,
zone_ids=zone_ids,
**config
)
# Store connection
self.connections[client_id] = connection
self.connections_by_type[stream_type].add(client_id)
# Index by zones
if zone_ids:
for zone_id in zone_ids:
self.connections_by_zone[zone_id].add(client_id)
# Update metrics
self.metrics["total_connections"] += 1
self.metrics["active_connections"] = len(self.connections)
logger.info(f"WebSocket client {client_id} connected for {stream_type}")
return client_id
except Exception as e:
logger.error(f"Error connecting WebSocket client: {e}")
raise
async def disconnect(self, client_id: str) -> bool:
"""Disconnect a WebSocket client."""
if client_id not in self.connections:
return False
try:
connection = self.connections[client_id]
# Remove from indexes
self.connections_by_type[connection.stream_type].discard(client_id)
for zone_id in connection.zone_ids:
self.connections_by_zone[zone_id].discard(client_id)
# Close WebSocket if still active
if connection.is_active:
try:
await connection.websocket.close()
except:
pass # Connection might already be closed
# Remove connection
del self.connections[client_id]
# Update metrics
self.metrics["active_connections"] = len(self.connections)
logger.info(f"WebSocket client {client_id} disconnected")
return True
except Exception as e:
logger.error(f"Error disconnecting client {client_id}: {e}")
return False
async def disconnect_all(self):
"""Disconnect all WebSocket clients."""
client_ids = list(self.connections.keys())
for client_id in client_ids:
await self.disconnect(client_id)
logger.info("All WebSocket clients disconnected")
async def send_to_client(self, client_id: str, data: Dict[str, Any]) -> bool:
"""Send data to a specific client."""
if client_id not in self.connections:
return False
connection = self.connections[client_id]
try:
await connection.send_json(data)
self.metrics["messages_sent"] += 1
return True
except Exception as e:
logger.error(f"Error sending to client {client_id}: {e}")
self.metrics["errors"] += 1
# Mark connection as inactive and schedule for cleanup
connection.is_active = False
return False
async def broadcast(
self,
data: Dict[str, Any],
stream_type: Optional[str] = None,
zone_ids: Optional[List[str]] = None,
**filters
) -> int:
"""Broadcast data to matching clients."""
sent_count = 0
failed_clients = []
# Get matching connections
matching_clients = self._get_matching_clients(
stream_type=stream_type,
zone_ids=zone_ids,
**filters
)
# Send to all matching clients
for client_id in matching_clients:
try:
success = await self.send_to_client(client_id, data)
if success:
sent_count += 1
else:
failed_clients.append(client_id)
except Exception as e:
logger.error(f"Error broadcasting to client {client_id}: {e}")
failed_clients.append(client_id)
# Clean up failed connections
for client_id in failed_clients:
await self.disconnect(client_id)
return sent_count
async def update_client_config(self, client_id: str, config: Dict[str, Any]) -> bool:
"""Update client configuration."""
if client_id not in self.connections:
return False
connection = self.connections[client_id]
old_zones = set(connection.zone_ids)
# Update configuration
connection.update_config(config)
# Update zone indexes if zones changed
new_zones = set(connection.zone_ids)
# Remove from old zones
for zone_id in old_zones - new_zones:
self.connections_by_zone[zone_id].discard(client_id)
# Add to new zones
for zone_id in new_zones - old_zones:
self.connections_by_zone[zone_id].add(client_id)
return True
async def get_client_status(self, client_id: str) -> Optional[Dict[str, Any]]:
"""Get status of a specific client."""
if client_id not in self.connections:
return None
return self.connections[client_id].get_info()
async def get_connected_clients(self) -> List[Dict[str, Any]]:
"""Get list of all connected clients."""
return [conn.get_info() for conn in self.connections.values()]
async def get_connection_stats(self) -> Dict[str, Any]:
"""Get connection statistics."""
stats = {
"total_clients": len(self.connections),
"clients_by_type": {
stream_type: len(clients)
for stream_type, clients in self.connections_by_type.items()
},
"clients_by_zone": {
zone_id: len(clients)
for zone_id, clients in self.connections_by_zone.items()
if clients # Only include zones with active clients
},
"active_clients": sum(1 for conn in self.connections.values() if conn.is_active),
"inactive_clients": sum(1 for conn in self.connections.values() if not conn.is_active)
}
return stats
async def get_metrics(self) -> Dict[str, Any]:
"""Get detailed metrics."""
uptime = (datetime.utcnow() - self.metrics["start_time"]).total_seconds()
return {
**self.metrics,
"active_connections": len(self.connections),
"uptime_seconds": uptime,
"messages_per_second": self.metrics["messages_sent"] / max(uptime, 1),
"error_rate": self.metrics["errors"] / max(self.metrics["messages_sent"], 1)
}
def _get_matching_clients(
self,
stream_type: Optional[str] = None,
zone_ids: Optional[List[str]] = None,
**filters
) -> List[str]:
"""Get client IDs that match the given filters."""
candidates = set(self.connections.keys())
# Filter by stream type
if stream_type:
type_clients = self.connections_by_type.get(stream_type, set())
candidates &= type_clients
# Filter by zones
if zone_ids:
zone_clients = set()
for zone_id in zone_ids:
zone_clients.update(self.connections_by_zone.get(zone_id, set()))
# Also include clients listening to all zones (empty zone list)
all_zone_clients = {
client_id for client_id, conn in self.connections.items()
if not conn.zone_ids
}
zone_clients.update(all_zone_clients)
candidates &= zone_clients
# Apply additional filters
matching_clients = []
for client_id in candidates:
connection = self.connections[client_id]
if connection.is_active and connection.matches_filter(**filters):
matching_clients.append(client_id)
return matching_clients
async def ping_clients(self):
"""Send ping to all connected clients."""
ping_data = {
"type": "ping",
"timestamp": datetime.utcnow().isoformat()
}
failed_clients = []
for client_id, connection in self.connections.items():
try:
await connection.send_json(ping_data)
connection.last_ping = datetime.utcnow()
except Exception as e:
logger.warning(f"Ping failed for client {client_id}: {e}")
failed_clients.append(client_id)
# Clean up failed connections
for client_id in failed_clients:
await self.disconnect(client_id)
async def cleanup_inactive_connections(self):
"""Clean up inactive or stale connections."""
now = datetime.utcnow()
stale_threshold = timedelta(minutes=5) # 5 minutes without ping
stale_clients = []
for client_id, connection in self.connections.items():
# Check if connection is inactive
if not connection.is_active:
stale_clients.append(client_id)
continue
# Check if connection is stale (no ping response)
if now - connection.last_ping > stale_threshold:
logger.warning(f"Client {client_id} appears stale, disconnecting")
stale_clients.append(client_id)
# Clean up stale connections
for client_id in stale_clients:
await self.disconnect(client_id)
if stale_clients:
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
async def start(self):
"""Start the connection manager."""
if not self._started:
self._start_cleanup_task()
self._started = True
logger.info("Connection manager started")
def _start_cleanup_task(self):
"""Start background cleanup task."""
async def cleanup_loop():
while True:
try:
await asyncio.sleep(60) # Run every minute
await self.cleanup_inactive_connections()
# Send periodic ping every 2 minutes
if datetime.utcnow().minute % 2 == 0:
await self.ping_clients()
except Exception as e:
logger.error(f"Error in cleanup task: {e}")
try:
self._cleanup_task = asyncio.create_task(cleanup_loop())
except RuntimeError:
# No event loop running, will start later
logger.debug("No event loop running, cleanup task will start later")
async def shutdown(self):
"""Shutdown connection manager."""
# Cancel cleanup task
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# Disconnect all clients
await self.disconnect_all()
logger.info("Connection manager shutdown complete")
# Global connection manager instance
connection_manager = ConnectionManager()

View File

@@ -0,0 +1,384 @@
"""
Pose streaming WebSocket handler
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from fastapi import WebSocket
from pydantic import BaseModel, Field
from src.api.websocket.connection_manager import ConnectionManager
from src.services.pose_service import PoseService
from src.services.stream_service import StreamService
logger = logging.getLogger(__name__)
class PoseStreamData(BaseModel):
"""Pose stream data model."""
timestamp: datetime = Field(..., description="Data timestamp")
zone_id: str = Field(..., description="Zone identifier")
pose_data: Dict[str, Any] = Field(..., description="Pose estimation data")
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
activity: Optional[str] = Field(default=None, description="Detected activity")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
class PoseStreamHandler:
"""Handles pose data streaming to WebSocket clients."""
def __init__(
self,
connection_manager: ConnectionManager,
pose_service: PoseService,
stream_service: StreamService
):
self.connection_manager = connection_manager
self.pose_service = pose_service
self.stream_service = stream_service
self.is_streaming = False
self.stream_task = None
self.subscribers = {}
self.stream_config = {
"fps": 30,
"min_confidence": 0.5,
"include_metadata": True,
"buffer_size": 100
}
async def start_streaming(self):
"""Start pose data streaming."""
if self.is_streaming:
logger.warning("Pose streaming already active")
return
self.is_streaming = True
self.stream_task = asyncio.create_task(self._stream_loop())
logger.info("Pose streaming started")
async def stop_streaming(self):
"""Stop pose data streaming."""
if not self.is_streaming:
return
self.is_streaming = False
if self.stream_task:
self.stream_task.cancel()
try:
await self.stream_task
except asyncio.CancelledError:
pass
logger.info("Pose streaming stopped")
async def _stream_loop(self):
"""Main streaming loop."""
try:
logger.info("🚀 Starting pose streaming loop")
while self.is_streaming:
try:
# Get current pose data from all zones
logger.debug("📡 Getting current pose data...")
pose_data = await self.pose_service.get_current_pose_data()
logger.debug(f"📊 Received pose data: {pose_data}")
if pose_data:
logger.debug("📤 Broadcasting pose data...")
await self._process_and_broadcast_pose_data(pose_data)
else:
logger.debug("⚠️ No pose data received")
# Control streaming rate
await asyncio.sleep(1.0 / self.stream_config["fps"])
except Exception as e:
logger.error(f"Error in pose streaming loop: {e}")
await asyncio.sleep(1.0) # Brief pause on error
except asyncio.CancelledError:
logger.info("Pose streaming loop cancelled")
except Exception as e:
logger.error(f"Fatal error in pose streaming loop: {e}")
finally:
logger.info("🛑 Pose streaming loop stopped")
self.is_streaming = False
async def _process_and_broadcast_pose_data(self, raw_pose_data: Dict[str, Any]):
"""Process and broadcast pose data to subscribers."""
try:
# Process data for each zone
for zone_id, zone_data in raw_pose_data.items():
if not zone_data:
continue
# Create structured pose data
pose_stream_data = PoseStreamData(
timestamp=datetime.utcnow(),
zone_id=zone_id,
pose_data=zone_data.get("pose", {}),
confidence=zone_data.get("confidence", 0.0),
activity=zone_data.get("activity"),
metadata=zone_data.get("metadata") if self.stream_config["include_metadata"] else None
)
# Filter by minimum confidence
if pose_stream_data.confidence < self.stream_config["min_confidence"]:
continue
# Broadcast to subscribers
await self._broadcast_pose_data(pose_stream_data)
except Exception as e:
logger.error(f"Error processing pose data: {e}")
async def _broadcast_pose_data(self, pose_data: PoseStreamData):
"""Broadcast pose data to matching WebSocket clients."""
try:
logger.debug(f"📡 Preparing to broadcast pose data for zone {pose_data.zone_id}")
# Prepare broadcast data
broadcast_data = {
"type": "pose_data",
"timestamp": pose_data.timestamp.isoformat(),
"zone_id": pose_data.zone_id,
"data": {
"pose": pose_data.pose_data,
"confidence": pose_data.confidence,
"activity": pose_data.activity
}
}
# Add metadata if enabled
if pose_data.metadata and self.stream_config["include_metadata"]:
broadcast_data["metadata"] = pose_data.metadata
logger.debug(f"📤 Broadcasting data: {broadcast_data}")
# Broadcast to pose stream subscribers
sent_count = await self.connection_manager.broadcast(
data=broadcast_data,
stream_type="pose",
zone_ids=[pose_data.zone_id]
)
logger.info(f"✅ Broadcasted pose data for zone {pose_data.zone_id} to {sent_count} clients")
except Exception as e:
logger.error(f"Error broadcasting pose data: {e}")
async def handle_client_subscription(
self,
client_id: str,
subscription_config: Dict[str, Any]
):
"""Handle client subscription configuration."""
try:
# Store client subscription config
self.subscribers[client_id] = {
"zone_ids": subscription_config.get("zone_ids", []),
"min_confidence": subscription_config.get("min_confidence", 0.5),
"max_fps": subscription_config.get("max_fps", 30),
"include_metadata": subscription_config.get("include_metadata", True),
"stream_types": subscription_config.get("stream_types", ["pose_data"]),
"subscribed_at": datetime.utcnow()
}
logger.info(f"Updated subscription for client {client_id}")
# Send confirmation
confirmation = {
"type": "subscription_updated",
"client_id": client_id,
"config": self.subscribers[client_id],
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, confirmation)
except Exception as e:
logger.error(f"Error handling client subscription: {e}")
async def handle_client_disconnect(self, client_id: str):
"""Handle client disconnection."""
if client_id in self.subscribers:
del self.subscribers[client_id]
logger.info(f"Removed subscription for disconnected client {client_id}")
async def send_historical_data(
self,
client_id: str,
zone_id: str,
start_time: datetime,
end_time: datetime,
limit: int = 100
):
"""Send historical pose data to client."""
try:
# Get historical data from pose service
historical_data = await self.pose_service.get_historical_data(
zone_id=zone_id,
start_time=start_time,
end_time=end_time,
limit=limit
)
# Send data in chunks to avoid overwhelming the client
chunk_size = 10
for i in range(0, len(historical_data), chunk_size):
chunk = historical_data[i:i + chunk_size]
message = {
"type": "historical_data",
"zone_id": zone_id,
"chunk_index": i // chunk_size,
"total_chunks": (len(historical_data) + chunk_size - 1) // chunk_size,
"data": chunk,
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, message)
# Small delay between chunks
await asyncio.sleep(0.1)
# Send completion message
completion_message = {
"type": "historical_data_complete",
"zone_id": zone_id,
"total_records": len(historical_data),
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, completion_message)
except Exception as e:
logger.error(f"Error sending historical data: {e}")
# Send error message to client
error_message = {
"type": "error",
"message": f"Failed to retrieve historical data: {str(e)}",
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, error_message)
async def send_zone_statistics(self, client_id: str, zone_id: str):
"""Send zone statistics to client."""
try:
# Get zone statistics
stats = await self.pose_service.get_zone_statistics(zone_id)
message = {
"type": "zone_statistics",
"zone_id": zone_id,
"statistics": stats,
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, message)
except Exception as e:
logger.error(f"Error sending zone statistics: {e}")
async def broadcast_system_event(self, event_type: str, event_data: Dict[str, Any]):
"""Broadcast system events to all connected clients."""
try:
message = {
"type": "system_event",
"event_type": event_type,
"data": event_data,
"timestamp": datetime.utcnow().isoformat()
}
# Broadcast to all pose stream clients
sent_count = await self.connection_manager.broadcast(
data=message,
stream_type="pose"
)
logger.info(f"Broadcasted system event '{event_type}' to {sent_count} clients")
except Exception as e:
logger.error(f"Error broadcasting system event: {e}")
async def update_stream_config(self, config: Dict[str, Any]):
"""Update streaming configuration."""
try:
# Validate and update configuration
if "fps" in config:
fps = max(1, min(60, config["fps"]))
self.stream_config["fps"] = fps
if "min_confidence" in config:
confidence = max(0.0, min(1.0, config["min_confidence"]))
self.stream_config["min_confidence"] = confidence
if "include_metadata" in config:
self.stream_config["include_metadata"] = bool(config["include_metadata"])
if "buffer_size" in config:
buffer_size = max(10, min(1000, config["buffer_size"]))
self.stream_config["buffer_size"] = buffer_size
logger.info(f"Updated stream configuration: {self.stream_config}")
# Broadcast configuration update to clients
await self.broadcast_system_event("stream_config_updated", {
"new_config": self.stream_config
})
except Exception as e:
logger.error(f"Error updating stream configuration: {e}")
def get_stream_status(self) -> Dict[str, Any]:
"""Get current streaming status."""
return {
"is_streaming": self.is_streaming,
"config": self.stream_config,
"subscriber_count": len(self.subscribers),
"subscribers": {
client_id: {
"zone_ids": sub["zone_ids"],
"min_confidence": sub["min_confidence"],
"subscribed_at": sub["subscribed_at"].isoformat()
}
for client_id, sub in self.subscribers.items()
}
}
async def get_performance_metrics(self) -> Dict[str, Any]:
"""Get streaming performance metrics."""
try:
# Get connection manager metrics
conn_metrics = await self.connection_manager.get_metrics()
# Get pose service metrics
pose_metrics = await self.pose_service.get_performance_metrics()
return {
"streaming": {
"is_active": self.is_streaming,
"fps": self.stream_config["fps"],
"subscriber_count": len(self.subscribers)
},
"connections": conn_metrics,
"pose_service": pose_metrics,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting performance metrics: {e}")
return {}
async def shutdown(self):
"""Shutdown pose stream handler."""
await self.stop_streaming()
self.subscribers.clear()
logger.info("Pose stream handler shutdown complete")

328
v1/src/app.py Normal file
View File

@@ -0,0 +1,328 @@
"""
FastAPI application factory and configuration
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from src.config.settings import Settings
from src.services.orchestrator import ServiceOrchestrator
from src.middleware.auth import AuthenticationMiddleware
from fastapi.middleware.cors import CORSMiddleware
from src.middleware.rate_limit import RateLimitMiddleware
from src.middleware.error_handler import ErrorHandlingMiddleware
from src.api.routers import pose, stream, health
from src.api.websocket.connection_manager import connection_manager
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager."""
logger.info("Starting WiFi-DensePose API...")
try:
# Get orchestrator from app state
orchestrator: ServiceOrchestrator = app.state.orchestrator
# Start connection manager
await connection_manager.start()
# Start all services
await orchestrator.start()
logger.info("WiFi-DensePose API started successfully")
yield
except Exception as e:
logger.error(f"Failed to start application: {e}")
raise
finally:
# Cleanup on shutdown
logger.info("Shutting down WiFi-DensePose API...")
# Shutdown connection manager
await connection_manager.shutdown()
if hasattr(app.state, 'orchestrator'):
await app.state.orchestrator.shutdown()
logger.info("WiFi-DensePose API shutdown complete")
def create_app(settings: Settings, orchestrator: ServiceOrchestrator) -> FastAPI:
"""Create and configure FastAPI application."""
# Create FastAPI application
app = FastAPI(
title=settings.app_name,
version=settings.version,
description="WiFi-based human pose estimation and activity recognition API",
docs_url=settings.docs_url if not settings.is_production else None,
redoc_url=settings.redoc_url if not settings.is_production else None,
openapi_url=settings.openapi_url if not settings.is_production else None,
lifespan=lifespan
)
# Store orchestrator in app state
app.state.orchestrator = orchestrator
app.state.settings = settings
# Add middleware in reverse order (last added = first executed)
setup_middleware(app, settings)
# Add exception handlers
setup_exception_handlers(app)
# Include routers
setup_routers(app, settings)
# Add root endpoints
setup_root_endpoints(app, settings)
return app
def setup_middleware(app: FastAPI, settings: Settings):
"""Setup application middleware."""
# Rate limiting middleware
if settings.enable_rate_limiting:
app.add_middleware(RateLimitMiddleware, settings=settings)
# Authentication middleware
if settings.enable_authentication:
app.add_middleware(AuthenticationMiddleware, settings=settings)
# CORS middleware
if settings.cors_enabled:
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=settings.cors_allow_credentials,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
allow_headers=["*"],
)
# Trusted host middleware for production
if settings.is_production:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=settings.allowed_hosts
)
def setup_exception_handlers(app: FastAPI):
"""Setup global exception handlers."""
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""Handle HTTP exceptions."""
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": exc.status_code,
"message": exc.detail,
"type": "http_error",
"path": str(request.url.path)
}
}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle request validation errors."""
return JSONResponse(
status_code=422,
content={
"error": {
"code": 422,
"message": "Validation error",
"type": "validation_error",
"path": str(request.url.path),
"details": exc.errors()
}
}
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""Handle general exceptions."""
logger.error(f"Unhandled exception on {request.url.path}: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": {
"code": 500,
"message": "Internal server error",
"type": "internal_error",
"path": str(request.url.path)
}
}
)
def setup_routers(app: FastAPI, settings: Settings):
"""Setup API routers."""
# Health check router (no prefix)
app.include_router(
health.router,
prefix="/health",
tags=["Health"]
)
# API routers with prefix
app.include_router(
pose.router,
prefix=f"{settings.api_prefix}/pose",
tags=["Pose Estimation"]
)
app.include_router(
stream.router,
prefix=f"{settings.api_prefix}/stream",
tags=["Streaming"]
)
def setup_root_endpoints(app: FastAPI, settings: Settings):
"""Setup root application endpoints."""
@app.get("/")
async def root():
"""Root endpoint with API information."""
return {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"docs_url": settings.docs_url,
"api_prefix": settings.api_prefix,
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"real_time_processing": settings.enable_real_time_processing
}
}
@app.get(f"{settings.api_prefix}/info")
async def api_info(request: Request):
"""Get detailed API information."""
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
return {
"api": {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"prefix": settings.api_prefix
},
"services": await orchestrator.get_service_info(),
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"real_time_processing": settings.enable_real_time_processing,
"historical_data": settings.enable_historical_data
},
"limits": {
"rate_limit_requests": settings.rate_limit_requests,
"rate_limit_window": settings.rate_limit_window
}
}
@app.get(f"{settings.api_prefix}/status")
async def api_status(request: Request):
"""Get current API status."""
try:
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
status = {
"api": {
"status": "healthy",
"version": settings.version,
"environment": settings.environment
},
"services": await orchestrator.get_service_status(),
"connections": await connection_manager.get_connection_stats()
}
return status
except Exception as e:
logger.error(f"Error getting API status: {e}")
return {
"api": {
"status": "error",
"error": str(e)
}
}
# Metrics endpoint (if enabled)
if settings.metrics_enabled:
@app.get(f"{settings.api_prefix}/metrics")
async def api_metrics(request: Request):
"""Get API metrics."""
try:
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
metrics = {
"connections": await connection_manager.get_metrics(),
"services": await orchestrator.get_service_metrics()
}
return metrics
except Exception as e:
logger.error(f"Error getting metrics: {e}")
return {"error": str(e)}
# Development endpoints (only in development)
if settings.is_development and settings.enable_test_endpoints:
@app.get(f"{settings.api_prefix}/dev/config")
async def dev_config():
"""Get current configuration (development only)."""
return {
"settings": settings.dict(),
"environment_variables": dict(os.environ)
}
@app.post(f"{settings.api_prefix}/dev/reset")
async def dev_reset(request: Request):
"""Reset services (development only)."""
try:
orchestrator: ServiceOrchestrator = request.app.state.orchestrator
await orchestrator.reset_services()
return {"message": "Services reset successfully"}
except Exception as e:
logger.error(f"Error resetting services: {e}")
return {"error": str(e)}
# Create default app instance for uvicorn
def get_app() -> FastAPI:
"""Get the default application instance."""
from src.config.settings import get_settings
from src.services.orchestrator import ServiceOrchestrator
settings = get_settings()
orchestrator = ServiceOrchestrator(settings)
return create_app(settings, orchestrator)
# Default app instance for uvicorn
app = get_app()

621
v1/src/cli.py Normal file
View File

@@ -0,0 +1,621 @@
"""
Command-line interface for WiFi-DensePose API
"""
import asyncio
import click
import sys
from pathlib import Path
from typing import Optional
from src.config.settings import get_settings, load_settings_from_file
from src.logger import setup_logging, get_logger
from src.commands.start import start_command
from src.commands.stop import stop_command
from src.commands.status import status_command
# Get default settings and setup logging for CLI
settings = get_settings()
setup_logging(settings)
logger = get_logger(__name__)
def get_settings_with_config(config_file: Optional[str] = None):
"""Get settings with optional config file."""
if config_file:
return load_settings_from_file(config_file)
else:
return get_settings()
@click.group()
@click.option(
'--config',
'-c',
type=click.Path(exists=True),
help='Path to configuration file'
)
@click.option(
'--verbose',
'-v',
is_flag=True,
help='Enable verbose logging'
)
@click.option(
'--debug',
is_flag=True,
help='Enable debug mode'
)
@click.pass_context
def cli(ctx, config: Optional[str], verbose: bool, debug: bool):
"""WiFi-DensePose API Command Line Interface."""
# Ensure context object exists
ctx.ensure_object(dict)
# Store CLI options in context
ctx.obj['config_file'] = config
ctx.obj['verbose'] = verbose
ctx.obj['debug'] = debug
# Setup logging level
if debug:
import logging
logging.getLogger().setLevel(logging.DEBUG)
logger.debug("Debug mode enabled")
elif verbose:
import logging
logging.getLogger().setLevel(logging.INFO)
logger.info("Verbose mode enabled")
@cli.command()
@click.option(
'--host',
default='0.0.0.0',
help='Host to bind to (default: 0.0.0.0)'
)
@click.option(
'--port',
default=8000,
type=int,
help='Port to bind to (default: 8000)'
)
@click.option(
'--workers',
default=1,
type=int,
help='Number of worker processes (default: 1)'
)
@click.option(
'--reload',
is_flag=True,
help='Enable auto-reload for development'
)
@click.option(
'--daemon',
'-d',
is_flag=True,
help='Run as daemon (background process)'
)
@click.pass_context
def start(ctx, host: str, port: int, workers: int, reload: bool, daemon: bool):
"""Start the WiFi-DensePose API server."""
try:
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
# Override settings with CLI options
if ctx.obj.get('debug'):
settings.debug = True
# Run start command
asyncio.run(start_command(
settings=settings,
host=host,
port=port,
workers=workers,
reload=reload,
daemon=daemon
))
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")
sys.exit(0)
except Exception as e:
logger.error(f"Failed to start server: {e}")
sys.exit(1)
@cli.command()
@click.option(
'--force',
'-f',
is_flag=True,
help='Force stop without graceful shutdown'
)
@click.option(
'--timeout',
default=30,
type=int,
help='Timeout for graceful shutdown (default: 30 seconds)'
)
@click.pass_context
def stop(ctx, force: bool, timeout: int):
"""Stop the WiFi-DensePose API server."""
try:
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
# Run stop command
asyncio.run(stop_command(
settings=settings,
force=force,
timeout=timeout
))
except Exception as e:
logger.error(f"Failed to stop server: {e}")
sys.exit(1)
@cli.command()
@click.option(
'--format',
type=click.Choice(['text', 'json']),
default='text',
help='Output format (default: text)'
)
@click.option(
'--detailed',
is_flag=True,
help='Show detailed status information'
)
@click.pass_context
def status(ctx, format: str, detailed: bool):
"""Show the status of the WiFi-DensePose API server."""
try:
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
# Run status command
asyncio.run(status_command(
settings=settings,
output_format=format,
detailed=detailed
))
except Exception as e:
logger.error(f"Failed to get status: {e}")
sys.exit(1)
@cli.group()
def db():
"""Database management commands."""
pass
@db.command()
@click.option(
'--url',
help='Database URL (overrides config)'
)
@click.pass_context
def init(ctx, url: Optional[str]):
"""Initialize the database schema."""
try:
from src.database.connection import get_database_manager
from alembic.config import Config
from alembic import command
import os
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
if url:
settings.database_url = url
# Initialize database
db_manager = get_database_manager(settings)
async def init_db():
await db_manager.initialize()
logger.info("Database initialized successfully")
asyncio.run(init_db())
# Run migrations if alembic.ini exists
alembic_ini_path = "alembic.ini"
if os.path.exists(alembic_ini_path):
try:
alembic_cfg = Config(alembic_ini_path)
# Set the database URL in the config
alembic_cfg.set_main_option("sqlalchemy.url", settings.get_database_url())
command.upgrade(alembic_cfg, "head")
logger.info("Database migrations applied successfully")
except Exception as migration_error:
logger.warning(f"Migration failed, but database is initialized: {migration_error}")
else:
logger.info("No alembic.ini found, skipping migrations")
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
sys.exit(1)
@db.command()
@click.option(
'--revision',
default='head',
help='Target revision (default: head)'
)
@click.pass_context
def migrate(ctx, revision: str):
"""Run database migrations."""
try:
from alembic.config import Config
from alembic import command
# Run migrations
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, revision)
logger.info(f"Database migrated to revision: {revision}")
except Exception as e:
logger.error(f"Failed to run migrations: {e}")
sys.exit(1)
@db.command()
@click.option(
'--steps',
default=1,
type=int,
help='Number of steps to rollback (default: 1)'
)
@click.pass_context
def rollback(ctx, steps: int):
"""Rollback database migrations."""
try:
from alembic.config import Config
from alembic import command
# Rollback migrations
alembic_cfg = Config("alembic.ini")
command.downgrade(alembic_cfg, f"-{steps}")
logger.info(f"Database rolled back {steps} step(s)")
except Exception as e:
logger.error(f"Failed to rollback database: {e}")
sys.exit(1)
@cli.group()
def tasks():
"""Background task management commands."""
pass
@tasks.command()
@click.option(
'--task',
type=click.Choice(['cleanup', 'monitoring', 'backup']),
help='Specific task to run'
)
@click.pass_context
def run(ctx, task: Optional[str]):
"""Run background tasks."""
try:
from src.tasks.cleanup import get_cleanup_manager
from src.tasks.monitoring import get_monitoring_manager
from src.tasks.backup import get_backup_manager
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
async def run_tasks():
if task == 'cleanup' or task is None:
cleanup_manager = get_cleanup_manager(settings)
result = await cleanup_manager.run_all_tasks()
logger.info(f"Cleanup result: {result}")
if task == 'monitoring' or task is None:
monitoring_manager = get_monitoring_manager(settings)
result = await monitoring_manager.run_all_tasks()
logger.info(f"Monitoring result: {result}")
if task == 'backup' or task is None:
backup_manager = get_backup_manager(settings)
result = await backup_manager.run_all_tasks()
logger.info(f"Backup result: {result}")
asyncio.run(run_tasks())
except Exception as e:
logger.error(f"Failed to run tasks: {e}")
sys.exit(1)
@tasks.command()
@click.pass_context
def status(ctx):
"""Show background task status."""
try:
from src.tasks.cleanup import get_cleanup_manager
from src.tasks.monitoring import get_monitoring_manager
from src.tasks.backup import get_backup_manager
import json
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
# Get task managers
cleanup_manager = get_cleanup_manager(settings)
monitoring_manager = get_monitoring_manager(settings)
backup_manager = get_backup_manager(settings)
# Collect status
status_data = {
"cleanup": cleanup_manager.get_stats(),
"monitoring": monitoring_manager.get_stats(),
"backup": backup_manager.get_stats(),
}
# Print status
click.echo(json.dumps(status_data, indent=2))
except Exception as e:
logger.error(f"Failed to get task status: {e}")
sys.exit(1)
@cli.group()
def config():
"""Configuration management commands."""
pass
@config.command()
@click.pass_context
def show(ctx):
"""Show current configuration."""
try:
import json
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
# Convert settings to dict (excluding sensitive data)
config_dict = {
"app_name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"debug": settings.debug,
"host": settings.host,
"port": settings.port,
"api_prefix": settings.api_prefix,
"docs_url": settings.docs_url,
"redoc_url": settings.redoc_url,
"log_level": settings.log_level,
"log_file": settings.log_file,
"data_storage_path": settings.data_storage_path,
"model_storage_path": settings.model_storage_path,
"temp_storage_path": settings.temp_storage_path,
"wifi_interface": settings.wifi_interface,
"csi_buffer_size": settings.csi_buffer_size,
"pose_confidence_threshold": settings.pose_confidence_threshold,
"stream_fps": settings.stream_fps,
"websocket_ping_interval": settings.websocket_ping_interval,
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"historical_data": settings.enable_historical_data,
"real_time_processing": settings.enable_real_time_processing,
"cors": settings.cors_enabled,
}
}
click.echo(json.dumps(config_dict, indent=2))
except Exception as e:
logger.error(f"Failed to show configuration: {e}")
sys.exit(1)
@config.command()
@click.pass_context
def validate(ctx):
"""Validate configuration."""
try:
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
# Validate database connection
from src.database.connection import get_database_manager
async def validate_config():
db_manager = get_database_manager(settings)
try:
await db_manager.test_connection()
click.echo("✓ Database connection: OK")
except Exception as e:
click.echo(f"✗ Database connection: FAILED - {e}")
return False
# Validate Redis connection (if configured)
redis_url = settings.get_redis_url()
if redis_url:
try:
import redis.asyncio as redis
redis_client = redis.from_url(redis_url)
await redis_client.ping()
click.echo("✓ Redis connection: OK")
await redis_client.close()
except Exception as e:
click.echo(f"✗ Redis connection: FAILED - {e}")
return False
else:
click.echo("- Redis connection: NOT CONFIGURED")
# Validate directories
from pathlib import Path
directories = [
("Data storage", settings.data_storage_path),
("Model storage", settings.model_storage_path),
("Temp storage", settings.temp_storage_path),
]
for name, directory in directories:
path = Path(directory)
if path.exists() and path.is_dir():
click.echo(f"{name}: OK")
else:
try:
path.mkdir(parents=True, exist_ok=True)
click.echo(f"{name}: CREATED - {directory}")
except Exception as e:
click.echo(f"{name}: FAILED TO CREATE - {directory} ({e})")
return False
click.echo("\n✓ Configuration validation passed")
return True
result = asyncio.run(validate_config())
if not result:
sys.exit(1)
except Exception as e:
logger.error(f"Failed to validate configuration: {e}")
sys.exit(1)
@config.command()
@click.option(
'--format',
type=click.Choice(['text', 'json']),
default='text',
help='Output format (default: text)'
)
@click.pass_context
def failsafe(ctx, format: str):
"""Show failsafe status and configuration."""
try:
import json
from src.database.connection import get_database_manager
# Get settings
settings = get_settings_with_config(ctx.obj.get('config_file'))
async def check_failsafe_status():
db_manager = get_database_manager(settings)
# Initialize database to check current state
try:
await db_manager.initialize()
except Exception as e:
logger.warning(f"Database initialization failed: {e}")
# Collect failsafe status
failsafe_status = {
"database": {
"failsafe_enabled": settings.enable_database_failsafe,
"using_sqlite_fallback": db_manager.is_using_sqlite_fallback(),
"sqlite_fallback_path": settings.sqlite_fallback_path,
"primary_database_url": settings.get_database_url() if not db_manager.is_using_sqlite_fallback() else None,
},
"redis": {
"failsafe_enabled": settings.enable_redis_failsafe,
"redis_enabled": settings.redis_enabled,
"redis_required": settings.redis_required,
"redis_available": db_manager.is_redis_available(),
"redis_url": settings.get_redis_url() if settings.redis_enabled else None,
},
"overall_status": "healthy"
}
# Determine overall status
if failsafe_status["database"]["using_sqlite_fallback"] or not failsafe_status["redis"]["redis_available"]:
failsafe_status["overall_status"] = "degraded"
# Output results
if format == 'json':
click.echo(json.dumps(failsafe_status, indent=2))
else:
click.echo("=== Failsafe Status ===\n")
# Database status
click.echo("Database:")
if failsafe_status["database"]["using_sqlite_fallback"]:
click.echo(" ⚠️ Using SQLite fallback database")
click.echo(f" Path: {failsafe_status['database']['sqlite_fallback_path']}")
else:
click.echo(" ✓ Using primary database (PostgreSQL)")
click.echo(f" Failsafe enabled: {'Yes' if failsafe_status['database']['failsafe_enabled'] else 'No'}")
# Redis status
click.echo("\nRedis:")
if not failsafe_status["redis"]["redis_enabled"]:
click.echo(" - Redis disabled")
elif not failsafe_status["redis"]["redis_available"]:
click.echo(" ⚠️ Redis unavailable (failsafe active)")
else:
click.echo(" ✓ Redis available")
click.echo(f" Failsafe enabled: {'Yes' if failsafe_status['redis']['failsafe_enabled'] else 'No'}")
click.echo(f" Required: {'Yes' if failsafe_status['redis']['redis_required'] else 'No'}")
# Overall status
status_icon = "" if failsafe_status["overall_status"] == "healthy" else "⚠️"
click.echo(f"\nOverall Status: {status_icon} {failsafe_status['overall_status'].upper()}")
if failsafe_status["overall_status"] == "degraded":
click.echo("\nNote: System is running in degraded mode using failsafe configurations.")
asyncio.run(check_failsafe_status())
except Exception as e:
logger.error(f"Failed to check failsafe status: {e}")
sys.exit(1)
@cli.command()
def version():
"""Show version information."""
try:
from src.config.settings import get_settings
settings = get_settings()
click.echo(f"WiFi-DensePose API v{settings.version}")
click.echo(f"Environment: {settings.environment}")
click.echo(f"Python: {sys.version}")
except Exception as e:
logger.error(f"Failed to get version: {e}")
sys.exit(1)
def create_cli(orchestrator=None):
"""Create CLI interface for the application."""
return cli
if __name__ == '__main__':
cli()

359
v1/src/commands/start.py Normal file
View File

@@ -0,0 +1,359 @@
"""
Start command implementation for WiFi-DensePose API
"""
import asyncio
import os
import signal
import sys
import uvicorn
from pathlib import Path
from typing import Optional
from src.config.settings import Settings
from src.logger import get_logger
logger = get_logger(__name__)
async def start_command(
settings: Settings,
host: str = "0.0.0.0",
port: int = 8000,
workers: int = 1,
reload: bool = False,
daemon: bool = False
) -> None:
"""Start the WiFi-DensePose API server."""
logger.info(f"Starting WiFi-DensePose API server...")
logger.info(f"Environment: {settings.environment}")
logger.info(f"Debug mode: {settings.debug}")
logger.info(f"Host: {host}")
logger.info(f"Port: {port}")
logger.info(f"Workers: {workers}")
# Validate settings
await _validate_startup_requirements(settings)
# Setup signal handlers
_setup_signal_handlers()
# Create PID file if running as daemon
pid_file = None
if daemon:
pid_file = _create_pid_file(settings)
try:
# Initialize database
await _initialize_database(settings)
# Start background tasks
background_tasks = await _start_background_tasks(settings)
# Configure uvicorn
uvicorn_config = {
"app": "src.app:app",
"host": host,
"port": port,
"reload": reload,
"workers": workers if not reload else 1, # Reload doesn't work with multiple workers
"log_level": "debug" if settings.debug else "info",
"access_log": True,
"use_colors": not daemon,
}
if daemon:
# Run as daemon
await _run_as_daemon(uvicorn_config, pid_file)
else:
# Run in foreground
await _run_server(uvicorn_config)
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down...")
except Exception as e:
logger.error(f"Server startup failed: {e}")
raise
finally:
# Cleanup
if pid_file and pid_file.exists():
pid_file.unlink()
# Stop background tasks
if 'background_tasks' in locals():
await _stop_background_tasks(background_tasks)
async def _validate_startup_requirements(settings: Settings) -> None:
"""Validate that all startup requirements are met."""
logger.info("Validating startup requirements...")
# Check database connection
try:
from src.database.connection import get_database_manager
db_manager = get_database_manager(settings)
await db_manager.test_connection()
logger.info("✓ Database connection validated")
except Exception as e:
logger.error(f"✗ Database connection failed: {e}")
raise
# Check Redis connection (if enabled)
if settings.redis_enabled:
try:
redis_stats = await db_manager.get_connection_stats()
if "redis" in redis_stats and not redis_stats["redis"].get("error"):
logger.info("✓ Redis connection validated")
else:
logger.warning("⚠ Redis connection failed, continuing without Redis")
except Exception as e:
logger.warning(f"⚠ Redis connection failed: {e}, continuing without Redis")
# Check required directories
directories = [
("Log directory", settings.log_directory),
("Backup directory", settings.backup_directory),
]
for name, directory in directories:
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
logger.info(f"{name} ready: {directory}")
logger.info("All startup requirements validated")
async def _initialize_database(settings: Settings) -> None:
"""Initialize database connection and run migrations if needed."""
logger.info("Initializing database...")
try:
from src.database.connection import get_database_manager
db_manager = get_database_manager(settings)
await db_manager.initialize()
logger.info("Database initialized successfully")
except Exception as e:
logger.error(f"Database initialization failed: {e}")
raise
async def _start_background_tasks(settings: Settings) -> dict:
"""Start background tasks."""
logger.info("Starting background tasks...")
tasks = {}
try:
# Start cleanup task
if settings.cleanup_interval_seconds > 0:
from src.tasks.cleanup import run_periodic_cleanup
cleanup_task = asyncio.create_task(run_periodic_cleanup(settings))
tasks['cleanup'] = cleanup_task
logger.info("✓ Cleanup task started")
# Start monitoring task
if settings.monitoring_interval_seconds > 0:
from src.tasks.monitoring import run_periodic_monitoring
monitoring_task = asyncio.create_task(run_periodic_monitoring(settings))
tasks['monitoring'] = monitoring_task
logger.info("✓ Monitoring task started")
# Start backup task
if settings.backup_interval_seconds > 0:
from src.tasks.backup import run_periodic_backup
backup_task = asyncio.create_task(run_periodic_backup(settings))
tasks['backup'] = backup_task
logger.info("✓ Backup task started")
logger.info(f"Started {len(tasks)} background tasks")
return tasks
except Exception as e:
logger.error(f"Failed to start background tasks: {e}")
# Cancel any started tasks
for task in tasks.values():
task.cancel()
raise
async def _stop_background_tasks(tasks: dict) -> None:
"""Stop background tasks gracefully."""
logger.info("Stopping background tasks...")
# Cancel all tasks
for name, task in tasks.items():
if not task.done():
logger.info(f"Stopping {name} task...")
task.cancel()
# Wait for tasks to complete
if tasks:
await asyncio.gather(*tasks.values(), return_exceptions=True)
logger.info("Background tasks stopped")
def _setup_signal_handlers() -> None:
"""Setup signal handlers for graceful shutdown."""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
# The actual shutdown will be handled by the main loop
sys.exit(0)
# Setup signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
if hasattr(signal, 'SIGHUP'):
signal.signal(signal.SIGHUP, signal_handler)
def _create_pid_file(settings: Settings) -> Path:
"""Create PID file for daemon mode."""
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
# Check if PID file already exists
if pid_file.exists():
try:
with open(pid_file, 'r') as f:
old_pid = int(f.read().strip())
# Check if process is still running
try:
os.kill(old_pid, 0) # Signal 0 just checks if process exists
logger.error(f"Server already running with PID {old_pid}")
sys.exit(1)
except OSError:
# Process doesn't exist, remove stale PID file
pid_file.unlink()
logger.info("Removed stale PID file")
except (ValueError, IOError):
# Invalid PID file, remove it
pid_file.unlink()
logger.info("Removed invalid PID file")
# Write current PID
with open(pid_file, 'w') as f:
f.write(str(os.getpid()))
logger.info(f"Created PID file: {pid_file}")
return pid_file
async def _run_server(config: dict) -> None:
"""Run the server in foreground mode."""
logger.info("Starting server in foreground mode...")
# Create uvicorn server
server = uvicorn.Server(uvicorn.Config(**config))
# Run server
await server.serve()
async def _run_as_daemon(config: dict, pid_file: Path) -> None:
"""Run the server as a daemon."""
logger.info("Starting server in daemon mode...")
# Fork process
try:
pid = os.fork()
if pid > 0:
# Parent process
logger.info(f"Server started as daemon with PID {pid}")
sys.exit(0)
except OSError as e:
logger.error(f"Fork failed: {e}")
sys.exit(1)
# Child process continues
# Decouple from parent environment
os.chdir("/")
os.setsid()
os.umask(0)
# Second fork
try:
pid = os.fork()
if pid > 0:
# Exit second parent
sys.exit(0)
except OSError as e:
logger.error(f"Second fork failed: {e}")
sys.exit(1)
# Update PID file with daemon PID
with open(pid_file, 'w') as f:
f.write(str(os.getpid()))
# Redirect standard file descriptors
sys.stdout.flush()
sys.stderr.flush()
# Redirect stdin, stdout, stderr to /dev/null
with open('/dev/null', 'r') as f:
os.dup2(f.fileno(), sys.stdin.fileno())
with open('/dev/null', 'w') as f:
os.dup2(f.fileno(), sys.stdout.fileno())
os.dup2(f.fileno(), sys.stderr.fileno())
# Create uvicorn server
server = uvicorn.Server(uvicorn.Config(**config))
# Run server
await server.serve()
def get_server_status(settings: Settings) -> dict:
"""Get current server status."""
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
status = {
"running": False,
"pid": None,
"pid_file": str(pid_file),
"pid_file_exists": pid_file.exists(),
}
if pid_file.exists():
try:
with open(pid_file, 'r') as f:
pid = int(f.read().strip())
status["pid"] = pid
# Check if process is running
try:
os.kill(pid, 0) # Signal 0 just checks if process exists
status["running"] = True
except OSError:
# Process doesn't exist
status["running"] = False
except (ValueError, IOError):
# Invalid PID file
status["running"] = False
return status

501
v1/src/commands/status.py Normal file
View File

@@ -0,0 +1,501 @@
"""
Status command implementation for WiFi-DensePose API
"""
import asyncio
import json
import psutil
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Any, Optional
from src.config.settings import Settings
from src.logger import get_logger
logger = get_logger(__name__)
async def status_command(
settings: Settings,
output_format: str = "text",
detailed: bool = False
) -> None:
"""Show the status of the WiFi-DensePose API server."""
logger.debug("Gathering server status information...")
try:
# Collect status information
status_data = await _collect_status_data(settings, detailed)
# Output status
if output_format == "json":
print(json.dumps(status_data, indent=2, default=str))
else:
_print_text_status(status_data, detailed)
except Exception as e:
logger.error(f"Failed to get status: {e}")
raise
async def _collect_status_data(settings: Settings, detailed: bool) -> Dict[str, Any]:
"""Collect comprehensive status data."""
status_data = {
"timestamp": datetime.utcnow().isoformat(),
"server": await _get_server_status(settings),
"system": _get_system_status(),
"configuration": _get_configuration_status(settings),
}
if detailed:
status_data.update({
"database": await _get_database_status(settings),
"background_tasks": await _get_background_tasks_status(settings),
"resources": _get_resource_usage(),
"health": await _get_health_status(settings),
})
return status_data
async def _get_server_status(settings: Settings) -> Dict[str, Any]:
"""Get server process status."""
from src.commands.stop import get_server_status
status = get_server_status(settings)
server_info = {
"running": status["running"],
"pid": status["pid"],
"pid_file": status["pid_file"],
"pid_file_exists": status["pid_file_exists"],
}
if status["running"] and status["pid"]:
try:
# Get process information
process = psutil.Process(status["pid"])
server_info.update({
"start_time": datetime.fromtimestamp(process.create_time()).isoformat(),
"uptime_seconds": time.time() - process.create_time(),
"memory_usage_mb": process.memory_info().rss / (1024 * 1024),
"cpu_percent": process.cpu_percent(),
"status": process.status(),
"num_threads": process.num_threads(),
"connections": len(process.connections()) if hasattr(process, 'connections') else None,
})
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
server_info["error"] = f"Cannot access process info: {e}"
return server_info
def _get_system_status() -> Dict[str, Any]:
"""Get system status information."""
uname_info = psutil.os.uname()
return {
"hostname": uname_info.nodename,
"platform": uname_info.sysname,
"architecture": uname_info.machine,
"python_version": f"{psutil.sys.version_info.major}.{psutil.sys.version_info.minor}.{psutil.sys.version_info.micro}",
"boot_time": datetime.fromtimestamp(psutil.boot_time()).isoformat(),
"uptime_seconds": time.time() - psutil.boot_time(),
}
def _get_configuration_status(settings: Settings) -> Dict[str, Any]:
"""Get configuration status."""
return {
"environment": settings.environment,
"debug": settings.debug,
"version": settings.version,
"host": settings.host,
"port": settings.port,
"database_configured": bool(settings.database_url or (settings.db_host and settings.db_name)),
"redis_enabled": settings.redis_enabled,
"monitoring_enabled": settings.monitoring_interval_seconds > 0,
"cleanup_enabled": settings.cleanup_interval_seconds > 0,
"backup_enabled": settings.backup_interval_seconds > 0,
}
async def _get_database_status(settings: Settings) -> Dict[str, Any]:
"""Get database status."""
db_status = {
"connected": False,
"connection_pool": None,
"tables": {},
"error": None,
}
try:
from src.database.connection import get_database_manager
db_manager = get_database_manager(settings)
# Test connection
await db_manager.test_connection()
db_status["connected"] = True
# Get connection stats
connection_stats = await db_manager.get_connection_stats()
db_status["connection_pool"] = connection_stats
# Get table counts
async with db_manager.get_async_session() as session:
from sqlalchemy import text, func
from src.database.models import Device, Session, CSIData, PoseDetection, SystemMetric, AuditLog
tables = {
"devices": Device,
"sessions": Session,
"csi_data": CSIData,
"pose_detections": PoseDetection,
"system_metrics": SystemMetric,
"audit_logs": AuditLog,
}
for table_name, model in tables.items():
try:
result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
count = result.scalar()
db_status["tables"][table_name] = {"count": count}
except Exception as e:
db_status["tables"][table_name] = {"error": str(e)}
except Exception as e:
db_status["error"] = str(e)
return db_status
async def _get_background_tasks_status(settings: Settings) -> Dict[str, Any]:
"""Get background tasks status."""
tasks_status = {}
try:
# Cleanup tasks
from src.tasks.cleanup import get_cleanup_manager
cleanup_manager = get_cleanup_manager(settings)
tasks_status["cleanup"] = cleanup_manager.get_stats()
except Exception as e:
tasks_status["cleanup"] = {"error": str(e)}
try:
# Monitoring tasks
from src.tasks.monitoring import get_monitoring_manager
monitoring_manager = get_monitoring_manager(settings)
tasks_status["monitoring"] = monitoring_manager.get_stats()
except Exception as e:
tasks_status["monitoring"] = {"error": str(e)}
try:
# Backup tasks
from src.tasks.backup import get_backup_manager
backup_manager = get_backup_manager(settings)
tasks_status["backup"] = backup_manager.get_stats()
except Exception as e:
tasks_status["backup"] = {"error": str(e)}
return tasks_status
def _get_resource_usage() -> Dict[str, Any]:
"""Get system resource usage."""
# CPU usage
cpu_percent = psutil.cpu_percent(interval=1)
cpu_count = psutil.cpu_count()
# Memory usage
memory = psutil.virtual_memory()
swap = psutil.swap_memory()
# Disk usage
disk = psutil.disk_usage('/')
# Network I/O
network = psutil.net_io_counters()
return {
"cpu": {
"usage_percent": cpu_percent,
"count": cpu_count,
},
"memory": {
"total_mb": memory.total / (1024 * 1024),
"used_mb": memory.used / (1024 * 1024),
"available_mb": memory.available / (1024 * 1024),
"usage_percent": memory.percent,
},
"swap": {
"total_mb": swap.total / (1024 * 1024),
"used_mb": swap.used / (1024 * 1024),
"usage_percent": swap.percent,
},
"disk": {
"total_gb": disk.total / (1024 * 1024 * 1024),
"used_gb": disk.used / (1024 * 1024 * 1024),
"free_gb": disk.free / (1024 * 1024 * 1024),
"usage_percent": (disk.used / disk.total) * 100,
},
"network": {
"bytes_sent": network.bytes_sent,
"bytes_recv": network.bytes_recv,
"packets_sent": network.packets_sent,
"packets_recv": network.packets_recv,
} if network else None,
}
async def _get_health_status(settings: Settings) -> Dict[str, Any]:
"""Get overall health status."""
health = {
"status": "healthy",
"checks": {},
"issues": [],
}
# Check database health
try:
from src.database.connection import get_database_manager
db_manager = get_database_manager(settings)
await db_manager.test_connection()
health["checks"]["database"] = "healthy"
except Exception as e:
health["checks"]["database"] = "unhealthy"
health["issues"].append(f"Database connection failed: {e}")
health["status"] = "unhealthy"
# Check disk space
disk = psutil.disk_usage('/')
disk_usage_percent = (disk.used / disk.total) * 100
if disk_usage_percent > 90:
health["checks"]["disk_space"] = "critical"
health["issues"].append(f"Disk usage critical: {disk_usage_percent:.1f}%")
health["status"] = "critical"
elif disk_usage_percent > 80:
health["checks"]["disk_space"] = "warning"
health["issues"].append(f"Disk usage high: {disk_usage_percent:.1f}%")
if health["status"] == "healthy":
health["status"] = "warning"
else:
health["checks"]["disk_space"] = "healthy"
# Check memory usage
memory = psutil.virtual_memory()
if memory.percent > 90:
health["checks"]["memory"] = "critical"
health["issues"].append(f"Memory usage critical: {memory.percent:.1f}%")
health["status"] = "critical"
elif memory.percent > 80:
health["checks"]["memory"] = "warning"
health["issues"].append(f"Memory usage high: {memory.percent:.1f}%")
if health["status"] == "healthy":
health["status"] = "warning"
else:
health["checks"]["memory"] = "healthy"
# Check log directory
log_dir = Path(settings.log_directory)
if log_dir.exists() and log_dir.is_dir():
health["checks"]["log_directory"] = "healthy"
else:
health["checks"]["log_directory"] = "unhealthy"
health["issues"].append(f"Log directory not accessible: {log_dir}")
health["status"] = "unhealthy"
# Check backup directory
backup_dir = Path(settings.backup_directory)
if backup_dir.exists() and backup_dir.is_dir():
health["checks"]["backup_directory"] = "healthy"
else:
health["checks"]["backup_directory"] = "unhealthy"
health["issues"].append(f"Backup directory not accessible: {backup_dir}")
health["status"] = "unhealthy"
return health
def _print_text_status(status_data: Dict[str, Any], detailed: bool) -> None:
"""Print status in human-readable text format."""
print("=" * 60)
print("WiFi-DensePose API Server Status")
print("=" * 60)
print(f"Timestamp: {status_data['timestamp']}")
print()
# Server status
server = status_data["server"]
print("🖥️ Server Status:")
if server["running"]:
print(f" ✅ Running (PID: {server['pid']})")
if "start_time" in server:
uptime = timedelta(seconds=int(server["uptime_seconds"]))
print(f" ⏱️ Uptime: {uptime}")
print(f" 💾 Memory: {server['memory_usage_mb']:.1f} MB")
print(f" 🔧 CPU: {server['cpu_percent']:.1f}%")
print(f" 🧵 Threads: {server['num_threads']}")
else:
print(" ❌ Not running")
if server["pid_file_exists"]:
print(" ⚠️ Stale PID file exists")
print()
# System status
system = status_data["system"]
print("🖥️ System:")
print(f" Hostname: {system['hostname']}")
print(f" Platform: {system['platform']} ({system['architecture']})")
print(f" Python: {system['python_version']}")
uptime = timedelta(seconds=int(system["uptime_seconds"]))
print(f" Uptime: {uptime}")
print()
# Configuration
config = status_data["configuration"]
print("⚙️ Configuration:")
print(f" Environment: {config['environment']}")
print(f" Debug: {config['debug']}")
print(f" API Version: {config['version']}")
print(f" Listen: {config['host']}:{config['port']}")
print(f" Database: {'' if config['database_configured'] else ''}")
print(f" Redis: {'' if config['redis_enabled'] else ''}")
print(f" Monitoring: {'' if config['monitoring_enabled'] else ''}")
print(f" Cleanup: {'' if config['cleanup_enabled'] else ''}")
print(f" Backup: {'' if config['backup_enabled'] else ''}")
print()
if detailed:
# Database status
if "database" in status_data:
db = status_data["database"]
print("🗄️ Database:")
if db["connected"]:
print(" ✅ Connected")
if "tables" in db:
print(" 📊 Table counts:")
for table, info in db["tables"].items():
if "count" in info:
print(f" {table}: {info['count']:,}")
else:
print(f" {table}: Error - {info.get('error', 'Unknown')}")
else:
print(f" ❌ Not connected: {db.get('error', 'Unknown error')}")
print()
# Background tasks
if "background_tasks" in status_data:
tasks = status_data["background_tasks"]
print("🔄 Background Tasks:")
for task_name, task_info in tasks.items():
if "error" in task_info:
print(f"{task_name}: {task_info['error']}")
else:
manager_info = task_info.get("manager", {})
print(f" 📋 {task_name}:")
print(f" Running: {manager_info.get('running', 'Unknown')}")
print(f" Last run: {manager_info.get('last_run', 'Never')}")
print(f" Run count: {manager_info.get('run_count', 0)}")
print()
# Resource usage
if "resources" in status_data:
resources = status_data["resources"]
print("📊 Resource Usage:")
cpu = resources["cpu"]
print(f" 🔧 CPU: {cpu['usage_percent']:.1f}% ({cpu['count']} cores)")
memory = resources["memory"]
print(f" 💾 Memory: {memory['usage_percent']:.1f}% "
f"({memory['used_mb']:.0f}/{memory['total_mb']:.0f} MB)")
disk = resources["disk"]
print(f" 💿 Disk: {disk['usage_percent']:.1f}% "
f"({disk['used_gb']:.1f}/{disk['total_gb']:.1f} GB)")
print()
# Health status
if "health" in status_data:
health = status_data["health"]
print("🏥 Health Status:")
status_emoji = {
"healthy": "",
"warning": "⚠️",
"critical": "",
"unhealthy": ""
}
print(f" Overall: {status_emoji.get(health['status'], '')} {health['status'].upper()}")
if health["issues"]:
print(" Issues:")
for issue in health["issues"]:
print(f"{issue}")
print(" Checks:")
for check, status in health["checks"].items():
emoji = status_emoji.get(status, "")
print(f" {emoji} {check}: {status}")
print()
print("=" * 60)
def get_quick_status(settings: Settings) -> str:
"""Get a quick one-line status."""
from src.commands.stop import get_server_status
status = get_server_status(settings)
if status["running"]:
return f"✅ Running (PID: {status['pid']})"
elif status["pid_file_exists"]:
return "⚠️ Not running (stale PID file)"
else:
return "❌ Not running"
async def check_health(settings: Settings) -> bool:
"""Quick health check - returns True if healthy."""
try:
status_data = await _collect_status_data(settings, detailed=True)
# Check if server is running
if not status_data["server"]["running"]:
return False
# Check health status
if "health" in status_data:
health_status = status_data["health"]["status"]
return health_status in ["healthy", "warning"]
return True
except Exception:
return False

294
v1/src/commands/stop.py Normal file
View File

@@ -0,0 +1,294 @@
"""
Stop command implementation for WiFi-DensePose API
"""
import asyncio
import os
import signal
import time
from pathlib import Path
from typing import Optional
from src.config.settings import Settings
from src.logger import get_logger
logger = get_logger(__name__)
async def stop_command(
settings: Settings,
force: bool = False,
timeout: int = 30
) -> None:
"""Stop the WiFi-DensePose API server."""
logger.info("Stopping WiFi-DensePose API server...")
# Get server status
status = get_server_status(settings)
if not status["running"]:
if status["pid_file_exists"]:
logger.info("Server is not running, but PID file exists. Cleaning up...")
_cleanup_pid_file(settings)
else:
logger.info("Server is not running")
return
pid = status["pid"]
logger.info(f"Found running server with PID {pid}")
try:
if force:
await _force_stop_server(pid, settings)
else:
await _graceful_stop_server(pid, timeout, settings)
except Exception as e:
logger.error(f"Failed to stop server: {e}")
raise
async def _graceful_stop_server(pid: int, timeout: int, settings: Settings) -> None:
"""Stop server gracefully with timeout."""
logger.info(f"Attempting graceful shutdown (timeout: {timeout}s)...")
try:
# Send SIGTERM for graceful shutdown
os.kill(pid, signal.SIGTERM)
logger.info("Sent SIGTERM signal")
# Wait for process to terminate
start_time = time.time()
while time.time() - start_time < timeout:
try:
# Check if process is still running
os.kill(pid, 0)
await asyncio.sleep(1)
except OSError:
# Process has terminated
logger.info("Server stopped gracefully")
_cleanup_pid_file(settings)
return
# Timeout reached, force kill
logger.warning(f"Graceful shutdown timeout ({timeout}s) reached, forcing stop...")
await _force_stop_server(pid, settings)
except OSError as e:
if e.errno == 3: # No such process
logger.info("Process already terminated")
_cleanup_pid_file(settings)
else:
logger.error(f"Failed to send signal to process {pid}: {e}")
raise
async def _force_stop_server(pid: int, settings: Settings) -> None:
"""Force stop server immediately."""
logger.info("Force stopping server...")
try:
# Send SIGKILL for immediate termination
os.kill(pid, signal.SIGKILL)
logger.info("Sent SIGKILL signal")
# Wait a moment for process to die
await asyncio.sleep(2)
# Verify process is dead
try:
os.kill(pid, 0)
logger.error(f"Process {pid} still running after SIGKILL")
except OSError:
logger.info("Server force stopped")
except OSError as e:
if e.errno == 3: # No such process
logger.info("Process already terminated")
else:
logger.error(f"Failed to force kill process {pid}: {e}")
raise
finally:
_cleanup_pid_file(settings)
def _cleanup_pid_file(settings: Settings) -> None:
"""Clean up PID file."""
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
if pid_file.exists():
try:
pid_file.unlink()
logger.info("Cleaned up PID file")
except Exception as e:
logger.warning(f"Failed to remove PID file: {e}")
def get_server_status(settings: Settings) -> dict:
"""Get current server status."""
pid_file = Path(settings.log_directory) / "wifi-densepose-api.pid"
status = {
"running": False,
"pid": None,
"pid_file": str(pid_file),
"pid_file_exists": pid_file.exists(),
}
if pid_file.exists():
try:
with open(pid_file, 'r') as f:
pid = int(f.read().strip())
status["pid"] = pid
# Check if process is running
try:
os.kill(pid, 0) # Signal 0 just checks if process exists
status["running"] = True
except OSError:
# Process doesn't exist
status["running"] = False
except (ValueError, IOError):
# Invalid PID file
status["running"] = False
return status
async def stop_all_background_tasks(settings: Settings) -> None:
"""Stop all background tasks if they're running."""
logger.info("Stopping background tasks...")
try:
# This would typically involve connecting to a task queue or
# sending signals to background processes
# For now, we'll just log the action
logger.info("Background tasks stop signal sent")
except Exception as e:
logger.error(f"Failed to stop background tasks: {e}")
async def cleanup_resources(settings: Settings) -> None:
"""Clean up system resources."""
logger.info("Cleaning up resources...")
try:
# Close database connections
from src.database.connection import get_database_manager
db_manager = get_database_manager(settings)
await db_manager.close_all_connections()
logger.info("Database connections closed")
except Exception as e:
logger.warning(f"Failed to close database connections: {e}")
try:
# Clean up temporary files
temp_files = [
Path(settings.log_directory) / "temp",
Path(settings.backup_directory) / "temp",
]
for temp_path in temp_files:
if temp_path.exists() and temp_path.is_dir():
import shutil
shutil.rmtree(temp_path)
logger.info(f"Cleaned up temporary directory: {temp_path}")
except Exception as e:
logger.warning(f"Failed to clean up temporary files: {e}")
logger.info("Resource cleanup completed")
def is_server_running(settings: Settings) -> bool:
"""Check if server is currently running."""
status = get_server_status(settings)
return status["running"]
def get_server_pid(settings: Settings) -> Optional[int]:
"""Get server PID if running."""
status = get_server_status(settings)
return status["pid"] if status["running"] else None
async def wait_for_server_stop(settings: Settings, timeout: int = 30) -> bool:
"""Wait for server to stop with timeout."""
start_time = time.time()
while time.time() - start_time < timeout:
if not is_server_running(settings):
return True
await asyncio.sleep(1)
return False
def send_reload_signal(settings: Settings) -> bool:
"""Send reload signal to running server."""
status = get_server_status(settings)
if not status["running"]:
logger.error("Server is not running")
return False
try:
# Send SIGHUP for reload
os.kill(status["pid"], signal.SIGHUP)
logger.info("Sent reload signal to server")
return True
except OSError as e:
logger.error(f"Failed to send reload signal: {e}")
return False
async def restart_server(settings: Settings, timeout: int = 30) -> None:
"""Restart the server (stop then start)."""
logger.info("Restarting server...")
# Stop server if running
if is_server_running(settings):
await stop_command(settings, timeout=timeout)
# Wait for server to stop
if not await wait_for_server_stop(settings, timeout):
logger.error("Server did not stop within timeout, forcing restart")
await stop_command(settings, force=True)
# Start server
from src.commands.start import start_command
await start_command(settings)
def get_stop_status_summary(settings: Settings) -> dict:
"""Get a summary of stop operation status."""
status = get_server_status(settings)
return {
"server_running": status["running"],
"pid": status["pid"],
"pid_file_exists": status["pid_file_exists"],
"can_stop": status["running"],
"cleanup_needed": status["pid_file_exists"] and not status["running"],
}

310
v1/src/config.py Normal file
View File

@@ -0,0 +1,310 @@
"""
Centralized configuration management for WiFi-DensePose API
"""
import os
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List
from functools import lru_cache
from src.config.settings import Settings, get_settings
from src.config.domains import DomainConfig, get_domain_config
logger = logging.getLogger(__name__)
class ConfigManager:
"""Centralized configuration manager."""
def __init__(self):
self._settings: Optional[Settings] = None
self._domain_config: Optional[DomainConfig] = None
self._environment_overrides: Dict[str, Any] = {}
@property
def settings(self) -> Settings:
"""Get application settings."""
if self._settings is None:
self._settings = get_settings()
return self._settings
@property
def domain_config(self) -> DomainConfig:
"""Get domain configuration."""
if self._domain_config is None:
self._domain_config = get_domain_config()
return self._domain_config
def reload_settings(self) -> Settings:
"""Reload settings from environment."""
self._settings = None
return self.settings
def reload_domain_config(self) -> DomainConfig:
"""Reload domain configuration."""
self._domain_config = None
return self.domain_config
def set_environment_override(self, key: str, value: Any):
"""Set environment variable override."""
self._environment_overrides[key] = value
os.environ[key] = str(value)
def get_environment_override(self, key: str, default: Any = None) -> Any:
"""Get environment variable override."""
return self._environment_overrides.get(key, os.environ.get(key, default))
def clear_environment_overrides(self):
"""Clear all environment overrides."""
for key in self._environment_overrides:
if key in os.environ:
del os.environ[key]
self._environment_overrides.clear()
def get_database_config(self) -> Dict[str, Any]:
"""Get database configuration."""
settings = self.settings
config = {
"url": settings.get_database_url(),
"pool_size": settings.database_pool_size,
"max_overflow": settings.database_max_overflow,
"echo": settings.is_development and settings.debug,
"pool_pre_ping": True,
"pool_recycle": 3600, # 1 hour
}
return config
def get_redis_config(self) -> Optional[Dict[str, Any]]:
"""Get Redis configuration."""
settings = self.settings
redis_url = settings.get_redis_url()
if not redis_url:
return None
config = {
"url": redis_url,
"password": settings.redis_password,
"db": settings.redis_db,
"decode_responses": True,
"socket_connect_timeout": 5,
"socket_timeout": 5,
"retry_on_timeout": True,
"health_check_interval": 30,
}
return config
def get_logging_config(self) -> Dict[str, Any]:
"""Get logging configuration."""
return self.settings.get_logging_config()
def get_cors_config(self) -> Dict[str, Any]:
"""Get CORS configuration."""
return self.settings.get_cors_config()
def get_security_config(self) -> Dict[str, Any]:
"""Get security configuration."""
settings = self.settings
config = {
"secret_key": settings.secret_key,
"jwt_algorithm": settings.jwt_algorithm,
"jwt_expire_hours": settings.jwt_expire_hours,
"allowed_hosts": settings.allowed_hosts,
"enable_authentication": settings.enable_authentication,
}
return config
def get_hardware_config(self) -> Dict[str, Any]:
"""Get hardware configuration."""
settings = self.settings
domain_config = self.domain_config
config = {
"wifi_interface": settings.wifi_interface,
"csi_buffer_size": settings.csi_buffer_size,
"polling_interval": settings.hardware_polling_interval,
"mock_hardware": settings.mock_hardware,
"routers": [router.dict() for router in domain_config.routers],
}
return config
def get_pose_config(self) -> Dict[str, Any]:
"""Get pose estimation configuration."""
settings = self.settings
domain_config = self.domain_config
config = {
"model_path": settings.pose_model_path,
"confidence_threshold": settings.pose_confidence_threshold,
"batch_size": settings.pose_processing_batch_size,
"max_persons": settings.pose_max_persons,
"mock_pose_data": settings.mock_pose_data,
"models": [model.dict() for model in domain_config.pose_models],
}
return config
def get_streaming_config(self) -> Dict[str, Any]:
"""Get streaming configuration."""
settings = self.settings
domain_config = self.domain_config
config = {
"fps": settings.stream_fps,
"buffer_size": settings.stream_buffer_size,
"websocket_ping_interval": settings.websocket_ping_interval,
"websocket_timeout": settings.websocket_timeout,
"enable_websockets": settings.enable_websockets,
"enable_real_time_processing": settings.enable_real_time_processing,
"max_connections": domain_config.streaming.max_connections,
"compression": domain_config.streaming.compression,
}
return config
def get_storage_config(self) -> Dict[str, Any]:
"""Get storage configuration."""
settings = self.settings
config = {
"data_path": Path(settings.data_storage_path),
"model_path": Path(settings.model_storage_path),
"temp_path": Path(settings.temp_storage_path),
"max_size_gb": settings.max_storage_size_gb,
"enable_historical_data": settings.enable_historical_data,
}
# Ensure directories exist
for path in [config["data_path"], config["model_path"], config["temp_path"]]:
path.mkdir(parents=True, exist_ok=True)
return config
def get_monitoring_config(self) -> Dict[str, Any]:
"""Get monitoring configuration."""
settings = self.settings
config = {
"metrics_enabled": settings.metrics_enabled,
"health_check_interval": settings.health_check_interval,
"performance_monitoring": settings.performance_monitoring,
"log_level": settings.log_level,
"log_file": settings.log_file,
}
return config
def get_rate_limiting_config(self) -> Dict[str, Any]:
"""Get rate limiting configuration."""
settings = self.settings
config = {
"enabled": settings.enable_rate_limiting,
"requests": settings.rate_limit_requests,
"authenticated_requests": settings.rate_limit_authenticated_requests,
"window": settings.rate_limit_window,
}
return config
def validate_configuration(self) -> List[str]:
"""Validate complete configuration and return issues."""
issues = []
try:
# Validate settings
from src.config.settings import validate_settings
settings_issues = validate_settings(self.settings)
issues.extend(settings_issues)
# Validate database configuration
try:
db_config = self.get_database_config()
if not db_config["url"]:
issues.append("Database URL is not configured")
except Exception as e:
issues.append(f"Database configuration error: {e}")
# Validate storage paths
try:
storage_config = self.get_storage_config()
for name, path in storage_config.items():
if name.endswith("_path") and not path.exists():
issues.append(f"Storage path does not exist: {path}")
except Exception as e:
issues.append(f"Storage configuration error: {e}")
# Validate hardware configuration
try:
hw_config = self.get_hardware_config()
if not hw_config["routers"]:
issues.append("No routers configured")
except Exception as e:
issues.append(f"Hardware configuration error: {e}")
# Validate pose configuration
try:
pose_config = self.get_pose_config()
if not pose_config["models"]:
issues.append("No pose models configured")
except Exception as e:
issues.append(f"Pose configuration error: {e}")
except Exception as e:
issues.append(f"Configuration validation error: {e}")
return issues
def get_full_config(self) -> Dict[str, Any]:
"""Get complete configuration dictionary."""
return {
"settings": self.settings.dict(),
"domain_config": self.domain_config.to_dict(),
"database": self.get_database_config(),
"redis": self.get_redis_config(),
"security": self.get_security_config(),
"hardware": self.get_hardware_config(),
"pose": self.get_pose_config(),
"streaming": self.get_streaming_config(),
"storage": self.get_storage_config(),
"monitoring": self.get_monitoring_config(),
"rate_limiting": self.get_rate_limiting_config(),
}
# Global configuration manager instance
@lru_cache()
def get_config_manager() -> ConfigManager:
"""Get cached configuration manager instance."""
return ConfigManager()
# Convenience functions
def get_app_settings() -> Settings:
"""Get application settings."""
return get_config_manager().settings
def get_app_domain_config() -> DomainConfig:
"""Get domain configuration."""
return get_config_manager().domain_config
def validate_app_configuration() -> List[str]:
"""Validate application configuration."""
return get_config_manager().validate_configuration()
def reload_configuration():
"""Reload all configuration."""
config_manager = get_config_manager()
config_manager.reload_settings()
config_manager.reload_domain_config()
logger.info("Configuration reloaded")

View File

@@ -0,0 +1,8 @@
"""
Configuration management package
"""
from .settings import get_settings, Settings
from .domains import DomainConfig, get_domain_config
__all__ = ["get_settings", "Settings", "DomainConfig", "get_domain_config"]

481
v1/src/config/domains.py Normal file
View File

@@ -0,0 +1,481 @@
"""
Domain-specific configuration for WiFi-DensePose
"""
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from pydantic import BaseModel, Field, validator
class ZoneType(str, Enum):
"""Zone types for pose detection."""
ROOM = "room"
HALLWAY = "hallway"
ENTRANCE = "entrance"
OUTDOOR = "outdoor"
OFFICE = "office"
MEETING_ROOM = "meeting_room"
KITCHEN = "kitchen"
BATHROOM = "bathroom"
BEDROOM = "bedroom"
LIVING_ROOM = "living_room"
class ActivityType(str, Enum):
"""Activity types for pose classification."""
STANDING = "standing"
SITTING = "sitting"
WALKING = "walking"
LYING = "lying"
RUNNING = "running"
JUMPING = "jumping"
FALLING = "falling"
UNKNOWN = "unknown"
class HardwareType(str, Enum):
"""Hardware types for WiFi devices."""
ROUTER = "router"
ACCESS_POINT = "access_point"
REPEATER = "repeater"
MESH_NODE = "mesh_node"
CUSTOM = "custom"
@dataclass
class ZoneConfig:
"""Configuration for a detection zone."""
zone_id: str
name: str
zone_type: ZoneType
description: Optional[str] = None
# Physical boundaries (in meters)
x_min: float = 0.0
x_max: float = 10.0
y_min: float = 0.0
y_max: float = 10.0
z_min: float = 0.0
z_max: float = 3.0
# Detection settings
enabled: bool = True
confidence_threshold: float = 0.5
max_persons: int = 5
activity_detection: bool = True
# Hardware assignments
primary_router: Optional[str] = None
secondary_routers: List[str] = field(default_factory=list)
# Processing settings
processing_interval: float = 0.1 # seconds
data_retention_hours: int = 24
# Alert settings
enable_alerts: bool = False
alert_threshold: float = 0.8
alert_activities: List[ActivityType] = field(default_factory=list)
@dataclass
class RouterConfig:
"""Configuration for a WiFi router/device."""
router_id: str
name: str
hardware_type: HardwareType
# Network settings
ip_address: str
mac_address: str
interface: str = "wlan0"
channel: int = 6
frequency: float = 2.4 # GHz
# CSI settings
csi_enabled: bool = True
csi_rate: int = 100 # Hz
csi_subcarriers: int = 56
antenna_count: int = 3
# Position (in meters)
x_position: float = 0.0
y_position: float = 0.0
z_position: float = 2.5 # typical ceiling mount
# Calibration
calibrated: bool = False
calibration_data: Optional[Dict[str, Any]] = None
# Status
enabled: bool = True
last_seen: Optional[str] = None
# Performance settings
max_connections: int = 50
power_level: int = 20 # dBm
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"router_id": self.router_id,
"name": self.name,
"hardware_type": self.hardware_type.value,
"ip_address": self.ip_address,
"mac_address": self.mac_address,
"interface": self.interface,
"channel": self.channel,
"frequency": self.frequency,
"csi_enabled": self.csi_enabled,
"csi_rate": self.csi_rate,
"csi_subcarriers": self.csi_subcarriers,
"antenna_count": self.antenna_count,
"position": {
"x": self.x_position,
"y": self.y_position,
"z": self.z_position
},
"calibrated": self.calibrated,
"calibration_data": self.calibration_data,
"enabled": self.enabled,
"last_seen": self.last_seen,
"max_connections": self.max_connections,
"power_level": self.power_level
}
class PoseModelConfig(BaseModel):
"""Configuration for pose estimation models."""
model_name: str = Field(..., description="Model name")
model_path: str = Field(..., description="Path to model file")
model_type: str = Field(default="densepose", description="Model type")
# Input settings
input_width: int = Field(default=256, description="Input image width")
input_height: int = Field(default=256, description="Input image height")
input_channels: int = Field(default=3, description="Input channels")
# Processing settings
batch_size: int = Field(default=1, description="Batch size for inference")
confidence_threshold: float = Field(default=0.5, description="Confidence threshold")
nms_threshold: float = Field(default=0.4, description="NMS threshold")
# Output settings
max_detections: int = Field(default=10, description="Maximum detections per frame")
keypoint_count: int = Field(default=17, description="Number of keypoints")
# Performance settings
use_gpu: bool = Field(default=True, description="Use GPU acceleration")
gpu_memory_fraction: float = Field(default=0.5, description="GPU memory fraction")
num_threads: int = Field(default=4, description="Number of CPU threads")
@validator("confidence_threshold", "nms_threshold", "gpu_memory_fraction")
def validate_thresholds(cls, v):
"""Validate threshold values."""
if not 0.0 <= v <= 1.0:
raise ValueError("Threshold must be between 0.0 and 1.0")
return v
class StreamingConfig(BaseModel):
"""Configuration for real-time streaming."""
# Stream settings
fps: int = Field(default=30, description="Frames per second")
resolution: str = Field(default="720p", description="Stream resolution")
quality: str = Field(default="medium", description="Stream quality")
# Buffer settings
buffer_size: int = Field(default=100, description="Buffer size")
max_latency_ms: int = Field(default=100, description="Maximum latency in milliseconds")
# Compression settings
compression_enabled: bool = Field(default=True, description="Enable compression")
compression_level: int = Field(default=5, description="Compression level (1-9)")
# WebSocket settings
ping_interval: int = Field(default=60, description="Ping interval in seconds")
timeout: int = Field(default=300, description="Connection timeout in seconds")
max_connections: int = Field(default=100, description="Maximum concurrent connections")
# Data filtering
min_confidence: float = Field(default=0.5, description="Minimum confidence for streaming")
include_metadata: bool = Field(default=True, description="Include metadata in stream")
@validator("fps")
def validate_fps(cls, v):
"""Validate FPS value."""
if not 1 <= v <= 60:
raise ValueError("FPS must be between 1 and 60")
return v
@validator("compression_level")
def validate_compression_level(cls, v):
"""Validate compression level."""
if not 1 <= v <= 9:
raise ValueError("Compression level must be between 1 and 9")
return v
class AlertConfig(BaseModel):
"""Configuration for alerts and notifications."""
# Alert types
enable_pose_alerts: bool = Field(default=False, description="Enable pose-based alerts")
enable_activity_alerts: bool = Field(default=False, description="Enable activity-based alerts")
enable_zone_alerts: bool = Field(default=False, description="Enable zone-based alerts")
enable_system_alerts: bool = Field(default=True, description="Enable system alerts")
# Thresholds
confidence_threshold: float = Field(default=0.8, description="Alert confidence threshold")
duration_threshold: int = Field(default=5, description="Alert duration threshold in seconds")
# Activities that trigger alerts
alert_activities: List[ActivityType] = Field(
default=[ActivityType.FALLING],
description="Activities that trigger alerts"
)
# Notification settings
email_enabled: bool = Field(default=False, description="Enable email notifications")
webhook_enabled: bool = Field(default=False, description="Enable webhook notifications")
sms_enabled: bool = Field(default=False, description="Enable SMS notifications")
# Rate limiting
max_alerts_per_hour: int = Field(default=10, description="Maximum alerts per hour")
cooldown_minutes: int = Field(default=5, description="Cooldown between similar alerts")
class DomainConfig:
"""Main domain configuration container."""
def __init__(self):
self.zones: Dict[str, ZoneConfig] = {}
self.routers: Dict[str, RouterConfig] = {}
self.pose_models: Dict[str, PoseModelConfig] = {}
self.streaming = StreamingConfig()
self.alerts = AlertConfig()
# Load default configurations
self._load_defaults()
def _load_defaults(self):
"""Load default configurations."""
# Default pose model
self.pose_models["default"] = PoseModelConfig(
model_name="densepose_rcnn_R_50_FPN_s1x",
model_path="./models/densepose_rcnn_R_50_FPN_s1x.pkl",
model_type="densepose"
)
# Example zone
self.zones["living_room"] = ZoneConfig(
zone_id="living_room",
name="Living Room",
zone_type=ZoneType.LIVING_ROOM,
description="Main living area",
x_max=5.0,
y_max=4.0,
z_max=3.0
)
# Example router
self.routers["main_router"] = RouterConfig(
router_id="main_router",
name="Main Router",
hardware_type=HardwareType.ROUTER,
ip_address="192.168.1.1",
mac_address="00:11:22:33:44:55",
x_position=2.5,
y_position=2.0,
z_position=2.5
)
def add_zone(self, zone: ZoneConfig):
"""Add a zone configuration."""
self.zones[zone.zone_id] = zone
def add_router(self, router: RouterConfig):
"""Add a router configuration."""
self.routers[router.router_id] = router
def add_pose_model(self, model: PoseModelConfig):
"""Add a pose model configuration."""
self.pose_models[model.model_name] = model
def get_zone(self, zone_id: str) -> Optional[ZoneConfig]:
"""Get zone configuration by ID."""
return self.zones.get(zone_id)
def get_router(self, router_id: str) -> Optional[RouterConfig]:
"""Get router configuration by ID."""
return self.routers.get(router_id)
def get_pose_model(self, model_name: str) -> Optional[PoseModelConfig]:
"""Get pose model configuration by name."""
return self.pose_models.get(model_name)
def get_zones_for_router(self, router_id: str) -> List[ZoneConfig]:
"""Get zones that use a specific router."""
zones = []
for zone in self.zones.values():
if (zone.primary_router == router_id or
router_id in zone.secondary_routers):
zones.append(zone)
return zones
def get_routers_for_zone(self, zone_id: str) -> List[RouterConfig]:
"""Get routers assigned to a specific zone."""
zone = self.get_zone(zone_id)
if not zone:
return []
routers = []
# Add primary router
if zone.primary_router and zone.primary_router in self.routers:
routers.append(self.routers[zone.primary_router])
# Add secondary routers
for router_id in zone.secondary_routers:
if router_id in self.routers:
routers.append(self.routers[router_id])
return routers
def get_all_routers(self) -> List[RouterConfig]:
"""Get all router configurations."""
return list(self.routers.values())
def validate_configuration(self) -> List[str]:
"""Validate the entire configuration."""
issues = []
# Validate zones
for zone_id, zone in self.zones.items():
if zone.primary_router and zone.primary_router not in self.routers:
issues.append(f"Zone {zone_id} references unknown primary router: {zone.primary_router}")
for router_id in zone.secondary_routers:
if router_id not in self.routers:
issues.append(f"Zone {zone_id} references unknown secondary router: {router_id}")
# Validate routers
for router_id, router in self.routers.items():
if not router.ip_address:
issues.append(f"Router {router_id} missing IP address")
if not router.mac_address:
issues.append(f"Router {router_id} missing MAC address")
# Validate pose models
for model_name, model in self.pose_models.items():
import os
if not os.path.exists(model.model_path):
issues.append(f"Pose model {model_name} file not found: {model.model_path}")
return issues
def to_dict(self) -> Dict[str, Any]:
"""Convert configuration to dictionary."""
return {
"zones": {
zone_id: {
"zone_id": zone.zone_id,
"name": zone.name,
"zone_type": zone.zone_type.value,
"description": zone.description,
"boundaries": {
"x_min": zone.x_min,
"x_max": zone.x_max,
"y_min": zone.y_min,
"y_max": zone.y_max,
"z_min": zone.z_min,
"z_max": zone.z_max
},
"settings": {
"enabled": zone.enabled,
"confidence_threshold": zone.confidence_threshold,
"max_persons": zone.max_persons,
"activity_detection": zone.activity_detection
},
"hardware": {
"primary_router": zone.primary_router,
"secondary_routers": zone.secondary_routers
}
}
for zone_id, zone in self.zones.items()
},
"routers": {
router_id: router.to_dict()
for router_id, router in self.routers.items()
},
"pose_models": {
model_name: model.dict()
for model_name, model in self.pose_models.items()
},
"streaming": self.streaming.dict(),
"alerts": self.alerts.dict()
}
@lru_cache()
def get_domain_config() -> DomainConfig:
"""Get cached domain configuration instance."""
return DomainConfig()
def load_domain_config_from_file(file_path: str) -> DomainConfig:
"""Load domain configuration from file."""
import json
config = DomainConfig()
try:
with open(file_path, 'r') as f:
data = json.load(f)
# Load zones
for zone_data in data.get("zones", []):
zone = ZoneConfig(**zone_data)
config.add_zone(zone)
# Load routers
for router_data in data.get("routers", []):
router = RouterConfig(**router_data)
config.add_router(router)
# Load pose models
for model_data in data.get("pose_models", []):
model = PoseModelConfig(**model_data)
config.add_pose_model(model)
# Load streaming config
if "streaming" in data:
config.streaming = StreamingConfig(**data["streaming"])
# Load alerts config
if "alerts" in data:
config.alerts = AlertConfig(**data["alerts"])
except Exception as e:
raise ValueError(f"Failed to load domain configuration: {e}")
return config
def save_domain_config_to_file(config: DomainConfig, file_path: str):
"""Save domain configuration to file."""
import json
try:
with open(file_path, 'w') as f:
json.dump(config.to_dict(), f, indent=2)
except Exception as e:
raise ValueError(f"Failed to save domain configuration: {e}")

435
v1/src/config/settings.py Normal file
View File

@@ -0,0 +1,435 @@
"""
Pydantic settings for WiFi-DensePose API
"""
import os
from typing import List, Optional, Dict, Any
from functools import lru_cache
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings with environment variable support."""
# Application settings
app_name: str = Field(default="WiFi-DensePose API", description="Application name")
version: str = Field(default="1.0.0", description="Application version")
environment: str = Field(default="development", description="Environment (development, staging, production)")
debug: bool = Field(default=False, description="Debug mode")
# Server settings
host: str = Field(default="0.0.0.0", description="Server host")
port: int = Field(default=8000, description="Server port")
reload: bool = Field(default=False, description="Auto-reload on code changes")
workers: int = Field(default=1, description="Number of worker processes")
# Security settings
secret_key: str = Field(..., description="Secret key for JWT tokens")
jwt_algorithm: str = Field(default="HS256", description="JWT algorithm")
jwt_expire_hours: int = Field(default=24, description="JWT token expiration in hours")
allowed_hosts: List[str] = Field(default=["*"], description="Allowed hosts")
cors_origins: List[str] = Field(default=["*"], description="CORS allowed origins")
# Rate limiting settings
rate_limit_requests: int = Field(default=100, description="Rate limit requests per window")
rate_limit_authenticated_requests: int = Field(default=1000, description="Rate limit for authenticated users")
rate_limit_window: int = Field(default=3600, description="Rate limit window in seconds")
# Database settings
database_url: Optional[str] = Field(default=None, description="Database connection URL")
database_pool_size: int = Field(default=10, description="Database connection pool size")
database_max_overflow: int = Field(default=20, description="Database max overflow connections")
# Database connection pool settings (alternative naming for compatibility)
db_pool_size: int = Field(default=10, description="Database connection pool size")
db_max_overflow: int = Field(default=20, description="Database max overflow connections")
db_pool_timeout: int = Field(default=30, description="Database pool timeout in seconds")
db_pool_recycle: int = Field(default=3600, description="Database pool recycle time in seconds")
# Database connection settings
db_host: Optional[str] = Field(default=None, description="Database host")
db_port: int = Field(default=5432, description="Database port")
db_name: Optional[str] = Field(default=None, description="Database name")
db_user: Optional[str] = Field(default=None, description="Database user")
db_password: Optional[str] = Field(default=None, description="Database password")
db_echo: bool = Field(default=False, description="Enable database query logging")
# Redis settings (for caching and rate limiting)
redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
redis_password: Optional[str] = Field(default=None, description="Redis password")
redis_db: int = Field(default=0, description="Redis database number")
redis_enabled: bool = Field(default=True, description="Enable Redis")
redis_host: str = Field(default="localhost", description="Redis host")
redis_port: int = Field(default=6379, description="Redis port")
redis_required: bool = Field(default=False, description="Require Redis connection (fail if unavailable)")
redis_max_connections: int = Field(default=10, description="Maximum Redis connections")
redis_socket_timeout: int = Field(default=5, description="Redis socket timeout in seconds")
redis_connect_timeout: int = Field(default=5, description="Redis connection timeout in seconds")
# Failsafe settings
enable_database_failsafe: bool = Field(default=True, description="Enable automatic SQLite failsafe when PostgreSQL unavailable")
enable_redis_failsafe: bool = Field(default=True, description="Enable automatic Redis failsafe (disable when unavailable)")
sqlite_fallback_path: str = Field(default="./data/wifi_densepose_fallback.db", description="SQLite fallback database path")
# Hardware settings
wifi_interface: str = Field(default="wlan0", description="WiFi interface name")
csi_buffer_size: int = Field(default=1000, description="CSI data buffer size")
hardware_polling_interval: float = Field(default=0.1, description="Hardware polling interval in seconds")
# CSI Processing settings
csi_sampling_rate: int = Field(default=1000, description="CSI sampling rate")
csi_window_size: int = Field(default=512, description="CSI window size")
csi_overlap: float = Field(default=0.5, description="CSI window overlap")
csi_noise_threshold: float = Field(default=0.1, description="CSI noise threshold")
csi_human_detection_threshold: float = Field(default=0.8, description="CSI human detection threshold")
csi_smoothing_factor: float = Field(default=0.9, description="CSI smoothing factor")
csi_max_history_size: int = Field(default=500, description="CSI max history size")
# Pose estimation settings
pose_model_path: Optional[str] = Field(default=None, description="Path to pose estimation model")
pose_confidence_threshold: float = Field(default=0.5, description="Minimum confidence threshold")
pose_processing_batch_size: int = Field(default=32, description="Batch size for pose processing")
pose_max_persons: int = Field(default=10, description="Maximum persons to detect per frame")
# Streaming settings
stream_fps: int = Field(default=30, description="Streaming frames per second")
stream_buffer_size: int = Field(default=100, description="Stream buffer size")
websocket_ping_interval: int = Field(default=60, description="WebSocket ping interval in seconds")
websocket_timeout: int = Field(default=300, description="WebSocket timeout in seconds")
# Logging settings
log_level: str = Field(default="INFO", description="Logging level")
log_format: str = Field(
default="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
description="Log format"
)
log_file: Optional[str] = Field(default=None, description="Log file path")
log_directory: str = Field(default="./logs", description="Log directory path")
log_max_size: int = Field(default=10485760, description="Max log file size in bytes (10MB)")
log_backup_count: int = Field(default=5, description="Number of log backup files")
# Monitoring settings
metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
health_check_interval: int = Field(default=30, description="Health check interval in seconds")
performance_monitoring: bool = Field(default=True, description="Enable performance monitoring")
monitoring_interval_seconds: int = Field(default=60, description="Monitoring task interval in seconds")
cleanup_interval_seconds: int = Field(default=3600, description="Cleanup task interval in seconds")
backup_interval_seconds: int = Field(default=86400, description="Backup task interval in seconds")
# Storage settings
data_storage_path: str = Field(default="./data", description="Data storage directory")
model_storage_path: str = Field(default="./models", description="Model storage directory")
temp_storage_path: str = Field(default="./temp", description="Temporary storage directory")
backup_directory: str = Field(default="./backups", description="Backup storage directory")
max_storage_size_gb: int = Field(default=100, description="Maximum storage size in GB")
# API settings
api_prefix: str = Field(default="/api/v1", description="API prefix")
docs_url: str = Field(default="/docs", description="API documentation URL")
redoc_url: str = Field(default="/redoc", description="ReDoc documentation URL")
openapi_url: str = Field(default="/openapi.json", description="OpenAPI schema URL")
# Feature flags
enable_authentication: bool = Field(default=True, description="Enable authentication")
enable_rate_limiting: bool = Field(default=True, description="Enable rate limiting")
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
# Development settings
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
mock_pose_data: bool = Field(default=False, description="Use mock pose data for development")
enable_test_endpoints: bool = Field(default=False, description="Enable test endpoints")
# Cleanup settings
csi_data_retention_days: int = Field(default=30, description="CSI data retention in days")
pose_detection_retention_days: int = Field(default=30, description="Pose detection retention in days")
metrics_retention_days: int = Field(default=7, description="Metrics retention in days")
audit_log_retention_days: int = Field(default=90, description="Audit log retention in days")
orphaned_session_threshold_days: int = Field(default=7, description="Orphaned session threshold in days")
cleanup_batch_size: int = Field(default=1000, description="Cleanup batch size")
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False
)
@field_validator("environment")
@classmethod
def validate_environment(cls, v):
"""Validate environment setting."""
allowed_environments = ["development", "staging", "production"]
if v not in allowed_environments:
raise ValueError(f"Environment must be one of: {allowed_environments}")
return v
@field_validator("log_level")
@classmethod
def validate_log_level(cls, v):
"""Validate log level setting."""
allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
if v.upper() not in allowed_levels:
raise ValueError(f"Log level must be one of: {allowed_levels}")
return v.upper()
@field_validator("pose_confidence_threshold")
@classmethod
def validate_confidence_threshold(cls, v):
"""Validate confidence threshold."""
if not 0.0 <= v <= 1.0:
raise ValueError("Confidence threshold must be between 0.0 and 1.0")
return v
@field_validator("stream_fps")
@classmethod
def validate_stream_fps(cls, v):
"""Validate streaming FPS."""
if not 1 <= v <= 60:
raise ValueError("Stream FPS must be between 1 and 60")
return v
@field_validator("port")
@classmethod
def validate_port(cls, v):
"""Validate port number."""
if not 1 <= v <= 65535:
raise ValueError("Port must be between 1 and 65535")
return v
@field_validator("workers")
@classmethod
def validate_workers(cls, v):
"""Validate worker count."""
if v < 1:
raise ValueError("Workers must be at least 1")
return v
@field_validator("db_port")
@classmethod
def validate_db_port(cls, v):
"""Validate database port."""
if not 1 <= v <= 65535:
raise ValueError("Database port must be between 1 and 65535")
return v
@field_validator("redis_port")
@classmethod
def validate_redis_port(cls, v):
"""Validate Redis port."""
if not 1 <= v <= 65535:
raise ValueError("Redis port must be between 1 and 65535")
return v
@field_validator("db_pool_size")
@classmethod
def validate_db_pool_size(cls, v):
"""Validate database pool size."""
if v < 1:
raise ValueError("Database pool size must be at least 1")
return v
@field_validator("monitoring_interval_seconds", "cleanup_interval_seconds", "backup_interval_seconds")
@classmethod
def validate_interval_seconds(cls, v):
"""Validate interval settings."""
if v < 0:
raise ValueError("Interval seconds must be non-negative")
return v
@property
def is_development(self) -> bool:
"""Check if running in development environment."""
return self.environment == "development"
@property
def is_production(self) -> bool:
"""Check if running in production environment."""
return self.environment == "production"
@property
def is_testing(self) -> bool:
"""Check if running in testing environment."""
return self.environment == "testing"
def get_database_url(self) -> str:
"""Get database URL with fallback."""
if self.database_url:
return self.database_url
# Build URL from individual components if available
if self.db_host and self.db_name and self.db_user:
password_part = f":{self.db_password}" if self.db_password else ""
return f"postgresql://{self.db_user}{password_part}@{self.db_host}:{self.db_port}/{self.db_name}"
# Default SQLite database for development
if self.is_development:
return f"sqlite:///{self.data_storage_path}/wifi_densepose.db"
# SQLite failsafe for production if enabled
if self.enable_database_failsafe:
return f"sqlite:///{self.sqlite_fallback_path}"
raise ValueError("Database URL must be configured for non-development environments")
def get_sqlite_fallback_url(self) -> str:
"""Get SQLite fallback database URL."""
return f"sqlite:///{self.sqlite_fallback_path}"
def get_redis_url(self) -> Optional[str]:
"""Get Redis URL with fallback."""
if not self.redis_enabled:
return None
if self.redis_url:
return self.redis_url
# Build URL from individual components
password_part = f":{self.redis_password}@" if self.redis_password else ""
return f"redis://{password_part}{self.redis_host}:{self.redis_port}/{self.redis_db}"
def get_cors_config(self) -> Dict[str, Any]:
"""Get CORS configuration."""
if self.is_development:
return {
"allow_origins": ["*"],
"allow_credentials": True,
"allow_methods": ["*"],
"allow_headers": ["*"],
}
return {
"allow_origins": self.cors_origins,
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": ["Authorization", "Content-Type"],
}
def get_logging_config(self) -> Dict[str, Any]:
"""Get logging configuration."""
config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": self.log_format,
},
"detailed": {
"format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s",
},
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": self.log_level,
"formatter": "default",
"stream": "ext://sys.stdout",
},
},
"loggers": {
"": {
"level": self.log_level,
"handlers": ["console"],
},
"uvicorn": {
"level": "INFO",
"handlers": ["console"],
"propagate": False,
},
"fastapi": {
"level": "INFO",
"handlers": ["console"],
"propagate": False,
},
},
}
# Add file handler if log file is specified
if self.log_file:
config["handlers"]["file"] = {
"class": "logging.handlers.RotatingFileHandler",
"level": self.log_level,
"formatter": "detailed",
"filename": self.log_file,
"maxBytes": self.log_max_size,
"backupCount": self.log_backup_count,
}
# Add file handler to all loggers
for logger_config in config["loggers"].values():
logger_config["handlers"].append("file")
return config
def create_directories(self):
"""Create necessary directories."""
directories = [
self.data_storage_path,
self.model_storage_path,
self.temp_storage_path,
self.log_directory,
self.backup_directory,
]
for directory in directories:
os.makedirs(directory, exist_ok=True)
@lru_cache()
def get_settings() -> Settings:
"""Get cached settings instance."""
settings = Settings()
settings.create_directories()
return settings
def get_test_settings() -> Settings:
"""Get settings for testing."""
return Settings(
environment="testing",
debug=True,
secret_key="test-secret-key",
database_url="sqlite:///:memory:",
mock_hardware=True,
mock_pose_data=True,
enable_test_endpoints=True,
log_level="DEBUG"
)
def load_settings_from_file(file_path: str) -> Settings:
"""Load settings from a specific file."""
return Settings(_env_file=file_path)
def validate_settings(settings: Settings) -> List[str]:
"""Validate settings and return list of issues."""
issues = []
# Check required settings for production
if settings.is_production:
if not settings.secret_key or settings.secret_key == "change-me":
issues.append("Secret key must be set for production")
if not settings.database_url and not (settings.db_host and settings.db_name and settings.db_user):
issues.append("Database URL or database connection parameters must be set for production")
if settings.debug:
issues.append("Debug mode should be disabled in production")
if "*" in settings.allowed_hosts:
issues.append("Allowed hosts should be restricted in production")
if "*" in settings.cors_origins:
issues.append("CORS origins should be restricted in production")
# Check storage paths exist
try:
settings.create_directories()
except Exception as e:
issues.append(f"Cannot create storage directories: {e}")
return issues

13
v1/src/core/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""
Core package for WiFi-DensePose API
"""
from .csi_processor import CSIProcessor
from .phase_sanitizer import PhaseSanitizer
from .router_interface import RouterInterface
__all__ = [
'CSIProcessor',
'PhaseSanitizer',
'RouterInterface'
]

View File

@@ -0,0 +1,425 @@
"""CSI data processor for WiFi-DensePose system using TDD approach."""
import asyncio
import logging
import numpy as np
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from collections import deque
import scipy.signal
import scipy.fft
try:
from ..hardware.csi_extractor import CSIData
except ImportError:
# Handle import for testing
from src.hardware.csi_extractor import CSIData
class CSIProcessingError(Exception):
"""Exception raised for CSI processing errors."""
pass
@dataclass
class CSIFeatures:
"""Data structure for extracted CSI features."""
amplitude_mean: np.ndarray
amplitude_variance: np.ndarray
phase_difference: np.ndarray
correlation_matrix: np.ndarray
doppler_shift: np.ndarray
power_spectral_density: np.ndarray
timestamp: datetime
metadata: Dict[str, Any]
@dataclass
class HumanDetectionResult:
"""Data structure for human detection results."""
human_detected: bool
confidence: float
motion_score: float
timestamp: datetime
features: CSIFeatures
metadata: Dict[str, Any]
class CSIProcessor:
"""Processes CSI data for human detection and pose estimation."""
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
"""Initialize CSI processor.
Args:
config: Configuration dictionary
logger: Optional logger instance
Raises:
ValueError: If configuration is invalid
"""
self._validate_config(config)
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Processing parameters
self.sampling_rate = config['sampling_rate']
self.window_size = config['window_size']
self.overlap = config['overlap']
self.noise_threshold = config['noise_threshold']
self.human_detection_threshold = config.get('human_detection_threshold', 0.8)
self.smoothing_factor = config.get('smoothing_factor', 0.9)
self.max_history_size = config.get('max_history_size', 500)
# Feature extraction flags
self.enable_preprocessing = config.get('enable_preprocessing', True)
self.enable_feature_extraction = config.get('enable_feature_extraction', True)
self.enable_human_detection = config.get('enable_human_detection', True)
# Processing state
self.csi_history = deque(maxlen=self.max_history_size)
self.previous_detection_confidence = 0.0
# Statistics tracking
self._total_processed = 0
self._processing_errors = 0
self._human_detections = 0
def _validate_config(self, config: Dict[str, Any]) -> None:
"""Validate configuration parameters.
Args:
config: Configuration to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['sampling_rate', 'window_size', 'overlap', 'noise_threshold']
missing_fields = [field for field in required_fields if field not in config]
if missing_fields:
raise ValueError(f"Missing required configuration: {missing_fields}")
if config['sampling_rate'] <= 0:
raise ValueError("sampling_rate must be positive")
if config['window_size'] <= 0:
raise ValueError("window_size must be positive")
if not 0 <= config['overlap'] < 1:
raise ValueError("overlap must be between 0 and 1")
def preprocess_csi_data(self, csi_data: CSIData) -> CSIData:
"""Preprocess CSI data for feature extraction.
Args:
csi_data: Raw CSI data
Returns:
Preprocessed CSI data
Raises:
CSIProcessingError: If preprocessing fails
"""
if not self.enable_preprocessing:
return csi_data
try:
# Remove noise from the signal
cleaned_data = self._remove_noise(csi_data)
# Apply windowing function
windowed_data = self._apply_windowing(cleaned_data)
# Normalize amplitude values
normalized_data = self._normalize_amplitude(windowed_data)
return normalized_data
except Exception as e:
raise CSIProcessingError(f"Failed to preprocess CSI data: {e}")
def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]:
"""Extract features from CSI data.
Args:
csi_data: Preprocessed CSI data
Returns:
Extracted features or None if disabled
Raises:
CSIProcessingError: If feature extraction fails
"""
if not self.enable_feature_extraction:
return None
try:
# Extract amplitude-based features
amplitude_mean, amplitude_variance = self._extract_amplitude_features(csi_data)
# Extract phase-based features
phase_difference = self._extract_phase_features(csi_data)
# Extract correlation features
correlation_matrix = self._extract_correlation_features(csi_data)
# Extract Doppler and frequency features
doppler_shift, power_spectral_density = self._extract_doppler_features(csi_data)
return CSIFeatures(
amplitude_mean=amplitude_mean,
amplitude_variance=amplitude_variance,
phase_difference=phase_difference,
correlation_matrix=correlation_matrix,
doppler_shift=doppler_shift,
power_spectral_density=power_spectral_density,
timestamp=datetime.now(timezone.utc),
metadata={'processing_params': self.config}
)
except Exception as e:
raise CSIProcessingError(f"Failed to extract features: {e}")
def detect_human_presence(self, features: CSIFeatures) -> Optional[HumanDetectionResult]:
"""Detect human presence from CSI features.
Args:
features: Extracted CSI features
Returns:
Detection result or None if disabled
Raises:
CSIProcessingError: If detection fails
"""
if not self.enable_human_detection:
return None
try:
# Analyze motion patterns
motion_score = self._analyze_motion_patterns(features)
# Calculate detection confidence
raw_confidence = self._calculate_detection_confidence(features, motion_score)
# Apply temporal smoothing
smoothed_confidence = self._apply_temporal_smoothing(raw_confidence)
# Determine if human is detected
human_detected = smoothed_confidence >= self.human_detection_threshold
if human_detected:
self._human_detections += 1
return HumanDetectionResult(
human_detected=human_detected,
confidence=smoothed_confidence,
motion_score=motion_score,
timestamp=datetime.now(timezone.utc),
features=features,
metadata={'threshold': self.human_detection_threshold}
)
except Exception as e:
raise CSIProcessingError(f"Failed to detect human presence: {e}")
async def process_csi_data(self, csi_data: CSIData) -> HumanDetectionResult:
"""Process CSI data through the complete pipeline.
Args:
csi_data: Raw CSI data
Returns:
Human detection result
Raises:
CSIProcessingError: If processing fails
"""
try:
self._total_processed += 1
# Preprocess the data
preprocessed_data = self.preprocess_csi_data(csi_data)
# Extract features
features = self.extract_features(preprocessed_data)
# Detect human presence
detection_result = self.detect_human_presence(features)
# Add to history
self.add_to_history(csi_data)
return detection_result
except Exception as e:
self._processing_errors += 1
raise CSIProcessingError(f"Pipeline processing failed: {e}")
def add_to_history(self, csi_data: CSIData) -> None:
"""Add CSI data to processing history.
Args:
csi_data: CSI data to add to history
"""
self.csi_history.append(csi_data)
def clear_history(self) -> None:
"""Clear the CSI data history."""
self.csi_history.clear()
def get_recent_history(self, count: int) -> List[CSIData]:
"""Get recent CSI data from history.
Args:
count: Number of recent entries to return
Returns:
List of recent CSI data entries
"""
if count >= len(self.csi_history):
return list(self.csi_history)
else:
return list(self.csi_history)[-count:]
def get_processing_statistics(self) -> Dict[str, Any]:
"""Get processing statistics.
Returns:
Dictionary containing processing statistics
"""
error_rate = self._processing_errors / self._total_processed if self._total_processed > 0 else 0
detection_rate = self._human_detections / self._total_processed if self._total_processed > 0 else 0
return {
'total_processed': self._total_processed,
'processing_errors': self._processing_errors,
'human_detections': self._human_detections,
'error_rate': error_rate,
'detection_rate': detection_rate,
'history_size': len(self.csi_history)
}
def reset_statistics(self) -> None:
"""Reset processing statistics."""
self._total_processed = 0
self._processing_errors = 0
self._human_detections = 0
# Private processing methods
def _remove_noise(self, csi_data: CSIData) -> CSIData:
"""Remove noise from CSI data."""
# Apply noise filtering based on threshold
amplitude_db = 20 * np.log10(np.abs(csi_data.amplitude) + 1e-12)
noise_mask = amplitude_db > self.noise_threshold
filtered_amplitude = csi_data.amplitude.copy()
filtered_amplitude[~noise_mask] = 0
return CSIData(
timestamp=csi_data.timestamp,
amplitude=filtered_amplitude,
phase=csi_data.phase,
frequency=csi_data.frequency,
bandwidth=csi_data.bandwidth,
num_subcarriers=csi_data.num_subcarriers,
num_antennas=csi_data.num_antennas,
snr=csi_data.snr,
metadata={**csi_data.metadata, 'noise_filtered': True}
)
def _apply_windowing(self, csi_data: CSIData) -> CSIData:
"""Apply windowing function to CSI data."""
# Apply Hamming window to reduce spectral leakage
window = scipy.signal.windows.hamming(csi_data.num_subcarriers)
windowed_amplitude = csi_data.amplitude * window[np.newaxis, :]
return CSIData(
timestamp=csi_data.timestamp,
amplitude=windowed_amplitude,
phase=csi_data.phase,
frequency=csi_data.frequency,
bandwidth=csi_data.bandwidth,
num_subcarriers=csi_data.num_subcarriers,
num_antennas=csi_data.num_antennas,
snr=csi_data.snr,
metadata={**csi_data.metadata, 'windowed': True}
)
def _normalize_amplitude(self, csi_data: CSIData) -> CSIData:
"""Normalize amplitude values."""
# Normalize to unit variance
normalized_amplitude = csi_data.amplitude / (np.std(csi_data.amplitude) + 1e-12)
return CSIData(
timestamp=csi_data.timestamp,
amplitude=normalized_amplitude,
phase=csi_data.phase,
frequency=csi_data.frequency,
bandwidth=csi_data.bandwidth,
num_subcarriers=csi_data.num_subcarriers,
num_antennas=csi_data.num_antennas,
snr=csi_data.snr,
metadata={**csi_data.metadata, 'normalized': True}
)
def _extract_amplitude_features(self, csi_data: CSIData) -> tuple:
"""Extract amplitude-based features."""
amplitude_mean = np.mean(csi_data.amplitude, axis=0)
amplitude_variance = np.var(csi_data.amplitude, axis=0)
return amplitude_mean, amplitude_variance
def _extract_phase_features(self, csi_data: CSIData) -> np.ndarray:
"""Extract phase-based features."""
# Calculate phase differences between adjacent subcarriers
phase_diff = np.diff(csi_data.phase, axis=1)
return np.mean(phase_diff, axis=0)
def _extract_correlation_features(self, csi_data: CSIData) -> np.ndarray:
"""Extract correlation features between antennas."""
# Calculate correlation matrix between antennas
correlation_matrix = np.corrcoef(csi_data.amplitude)
return correlation_matrix
def _extract_doppler_features(self, csi_data: CSIData) -> tuple:
"""Extract Doppler and frequency domain features."""
# Simple Doppler estimation (would use history in real implementation)
doppler_shift = np.random.rand(10) # Placeholder
# Power spectral density
psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2
return doppler_shift, psd
def _analyze_motion_patterns(self, features: CSIFeatures) -> float:
"""Analyze motion patterns from features."""
# Analyze variance and correlation patterns to detect motion
variance_score = np.mean(features.amplitude_variance)
correlation_score = np.mean(np.abs(features.correlation_matrix - np.eye(features.correlation_matrix.shape[0])))
# Combine scores (simplified approach)
motion_score = 0.6 * variance_score + 0.4 * correlation_score
return np.clip(motion_score, 0.0, 1.0)
def _calculate_detection_confidence(self, features: CSIFeatures, motion_score: float) -> float:
"""Calculate detection confidence based on features."""
# Combine multiple feature indicators
amplitude_indicator = np.mean(features.amplitude_mean) > 0.1
phase_indicator = np.std(features.phase_difference) > 0.05
motion_indicator = motion_score > 0.3
# Weight the indicators
confidence = (0.4 * amplitude_indicator + 0.3 * phase_indicator + 0.3 * motion_indicator)
return np.clip(confidence, 0.0, 1.0)
def _apply_temporal_smoothing(self, raw_confidence: float) -> float:
"""Apply temporal smoothing to detection confidence."""
# Exponential moving average
smoothed_confidence = (self.smoothing_factor * self.previous_detection_confidence +
(1 - self.smoothing_factor) * raw_confidence)
self.previous_detection_confidence = smoothed_confidence
return smoothed_confidence

View File

@@ -0,0 +1,347 @@
"""Phase sanitization module for WiFi-DensePose system using TDD approach."""
import numpy as np
import logging
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from scipy import signal
class PhaseSanitizationError(Exception):
"""Exception raised for phase sanitization errors."""
pass
class PhaseSanitizer:
"""Sanitizes phase data from CSI signals for reliable processing."""
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
"""Initialize phase sanitizer.
Args:
config: Configuration dictionary
logger: Optional logger instance
Raises:
ValueError: If configuration is invalid
"""
self._validate_config(config)
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Processing parameters
self.unwrapping_method = config['unwrapping_method']
self.outlier_threshold = config['outlier_threshold']
self.smoothing_window = config['smoothing_window']
# Optional parameters with defaults
self.enable_outlier_removal = config.get('enable_outlier_removal', True)
self.enable_smoothing = config.get('enable_smoothing', True)
self.enable_noise_filtering = config.get('enable_noise_filtering', False)
self.noise_threshold = config.get('noise_threshold', 0.05)
self.phase_range = config.get('phase_range', (-np.pi, np.pi))
# Statistics tracking
self._total_processed = 0
self._outliers_removed = 0
self._sanitization_errors = 0
def _validate_config(self, config: Dict[str, Any]) -> None:
"""Validate configuration parameters.
Args:
config: Configuration to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['unwrapping_method', 'outlier_threshold', 'smoothing_window']
missing_fields = [field for field in required_fields if field not in config]
if missing_fields:
raise ValueError(f"Missing required configuration: {missing_fields}")
# Validate unwrapping method
valid_methods = ['numpy', 'scipy', 'custom']
if config['unwrapping_method'] not in valid_methods:
raise ValueError(f"Invalid unwrapping method: {config['unwrapping_method']}. Must be one of {valid_methods}")
# Validate thresholds
if config['outlier_threshold'] <= 0:
raise ValueError("outlier_threshold must be positive")
if config['smoothing_window'] <= 0:
raise ValueError("smoothing_window must be positive")
def unwrap_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase data to remove discontinuities.
Args:
phase_data: Wrapped phase data (2D array)
Returns:
Unwrapped phase data
Raises:
PhaseSanitizationError: If unwrapping fails
"""
try:
if self.unwrapping_method == 'numpy':
return self._unwrap_numpy(phase_data)
elif self.unwrapping_method == 'scipy':
return self._unwrap_scipy(phase_data)
elif self.unwrapping_method == 'custom':
return self._unwrap_custom(phase_data)
else:
raise ValueError(f"Unknown unwrapping method: {self.unwrapping_method}")
except Exception as e:
raise PhaseSanitizationError(f"Failed to unwrap phase: {e}")
def _unwrap_numpy(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase using numpy's unwrap function."""
if phase_data.size == 0:
raise ValueError("Cannot unwrap empty phase data")
return np.unwrap(phase_data, axis=1)
def _unwrap_scipy(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase using scipy's unwrap function."""
if phase_data.size == 0:
raise ValueError("Cannot unwrap empty phase data")
return np.unwrap(phase_data, axis=1)
def _unwrap_custom(self, phase_data: np.ndarray) -> np.ndarray:
"""Unwrap phase using custom algorithm."""
if phase_data.size == 0:
raise ValueError("Cannot unwrap empty phase data")
# Simple custom unwrapping algorithm
unwrapped = phase_data.copy()
for i in range(phase_data.shape[0]):
unwrapped[i, :] = np.unwrap(phase_data[i, :])
return unwrapped
def remove_outliers(self, phase_data: np.ndarray) -> np.ndarray:
"""Remove outliers from phase data.
Args:
phase_data: Phase data (2D array)
Returns:
Phase data with outliers removed
Raises:
PhaseSanitizationError: If outlier removal fails
"""
if not self.enable_outlier_removal:
return phase_data
try:
# Detect outliers
outlier_mask = self._detect_outliers(phase_data)
# Interpolate outliers
clean_data = self._interpolate_outliers(phase_data, outlier_mask)
return clean_data
except Exception as e:
raise PhaseSanitizationError(f"Failed to remove outliers: {e}")
def _detect_outliers(self, phase_data: np.ndarray) -> np.ndarray:
"""Detect outliers using statistical methods."""
# Use Z-score method to detect outliers
z_scores = np.abs((phase_data - np.mean(phase_data, axis=1, keepdims=True)) /
(np.std(phase_data, axis=1, keepdims=True) + 1e-8))
outlier_mask = z_scores > self.outlier_threshold
# Update statistics
self._outliers_removed += np.sum(outlier_mask)
return outlier_mask
def _interpolate_outliers(self, phase_data: np.ndarray, outlier_mask: np.ndarray) -> np.ndarray:
"""Interpolate outlier values."""
clean_data = phase_data.copy()
for i in range(phase_data.shape[0]):
outliers = outlier_mask[i, :]
if np.any(outliers):
# Linear interpolation for outliers
valid_indices = np.where(~outliers)[0]
outlier_indices = np.where(outliers)[0]
if len(valid_indices) > 1:
clean_data[i, outlier_indices] = np.interp(
outlier_indices, valid_indices, phase_data[i, valid_indices]
)
return clean_data
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Smooth phase data to reduce noise.
Args:
phase_data: Phase data (2D array)
Returns:
Smoothed phase data
Raises:
PhaseSanitizationError: If smoothing fails
"""
if not self.enable_smoothing:
return phase_data
try:
smoothed_data = self._apply_moving_average(phase_data, self.smoothing_window)
return smoothed_data
except Exception as e:
raise PhaseSanitizationError(f"Failed to smooth phase: {e}")
def _apply_moving_average(self, phase_data: np.ndarray, window_size: int) -> np.ndarray:
"""Apply moving average smoothing."""
smoothed_data = phase_data.copy()
# Ensure window size is odd
if window_size % 2 == 0:
window_size += 1
half_window = window_size // 2
for i in range(phase_data.shape[0]):
for j in range(half_window, phase_data.shape[1] - half_window):
start_idx = j - half_window
end_idx = j + half_window + 1
smoothed_data[i, j] = np.mean(phase_data[i, start_idx:end_idx])
return smoothed_data
def filter_noise(self, phase_data: np.ndarray) -> np.ndarray:
"""Filter noise from phase data.
Args:
phase_data: Phase data (2D array)
Returns:
Filtered phase data
Raises:
PhaseSanitizationError: If noise filtering fails
"""
if not self.enable_noise_filtering:
return phase_data
try:
filtered_data = self._apply_low_pass_filter(phase_data, self.noise_threshold)
return filtered_data
except Exception as e:
raise PhaseSanitizationError(f"Failed to filter noise: {e}")
def _apply_low_pass_filter(self, phase_data: np.ndarray, threshold: float) -> np.ndarray:
"""Apply low-pass filter to remove high-frequency noise."""
filtered_data = phase_data.copy()
# Check if data is large enough for filtering
min_filter_length = 18 # Minimum length required for filtfilt with order 4
if phase_data.shape[1] < min_filter_length:
# Skip filtering for small arrays
return filtered_data
# Apply Butterworth low-pass filter
nyquist = 0.5
cutoff = threshold * nyquist
# Design filter
b, a = signal.butter(4, cutoff, btype='low')
# Apply filter to each antenna
for i in range(phase_data.shape[0]):
filtered_data[i, :] = signal.filtfilt(b, a, phase_data[i, :])
return filtered_data
def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Sanitize phase data through complete pipeline.
Args:
phase_data: Raw phase data (2D array)
Returns:
Sanitized phase data
Raises:
PhaseSanitizationError: If sanitization fails
"""
try:
self._total_processed += 1
# Validate input data
self.validate_phase_data(phase_data)
# Apply complete sanitization pipeline
sanitized_data = self.unwrap_phase(phase_data)
sanitized_data = self.remove_outliers(sanitized_data)
sanitized_data = self.smooth_phase(sanitized_data)
sanitized_data = self.filter_noise(sanitized_data)
return sanitized_data
except PhaseSanitizationError:
self._sanitization_errors += 1
raise
except Exception as e:
self._sanitization_errors += 1
raise PhaseSanitizationError(f"Sanitization pipeline failed: {e}")
def validate_phase_data(self, phase_data: np.ndarray) -> bool:
"""Validate phase data format and values.
Args:
phase_data: Phase data to validate
Returns:
True if valid
Raises:
PhaseSanitizationError: If validation fails
"""
# Check if data is 2D
if phase_data.ndim != 2:
raise PhaseSanitizationError("Phase data must be 2D array")
# Check if data is not empty
if phase_data.size == 0:
raise PhaseSanitizationError("Phase data cannot be empty")
# Check if values are within valid range
min_val, max_val = self.phase_range
if np.any(phase_data < min_val) or np.any(phase_data > max_val):
raise PhaseSanitizationError(f"Phase values outside valid range [{min_val}, {max_val}]")
return True
def get_sanitization_statistics(self) -> Dict[str, Any]:
"""Get sanitization statistics.
Returns:
Dictionary containing sanitization statistics
"""
outlier_rate = self._outliers_removed / self._total_processed if self._total_processed > 0 else 0
error_rate = self._sanitization_errors / self._total_processed if self._total_processed > 0 else 0
return {
'total_processed': self._total_processed,
'outliers_removed': self._outliers_removed,
'sanitization_errors': self._sanitization_errors,
'outlier_rate': outlier_rate,
'error_rate': error_rate
}
def reset_statistics(self) -> None:
"""Reset sanitization statistics."""
self._total_processed = 0
self._outliers_removed = 0
self._sanitization_errors = 0

View File

@@ -0,0 +1,340 @@
"""
Router interface for WiFi CSI data collection
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime
import numpy as np
logger = logging.getLogger(__name__)
class RouterInterface:
"""Interface for connecting to WiFi routers and collecting CSI data."""
def __init__(
self,
router_id: str,
host: str,
port: int = 22,
username: str = "admin",
password: str = "",
interface: str = "wlan0",
mock_mode: bool = False
):
"""Initialize router interface.
Args:
router_id: Unique identifier for the router
host: Router IP address or hostname
port: SSH port for connection
username: SSH username
password: SSH password
interface: WiFi interface name
mock_mode: Whether to use mock data instead of real connection
"""
self.router_id = router_id
self.host = host
self.port = port
self.username = username
self.password = password
self.interface = interface
self.mock_mode = mock_mode
self.logger = logging.getLogger(f"{__name__}.{router_id}")
# Connection state
self.is_connected = False
self.connection = None
self.last_error = None
# Data collection state
self.last_data_time = None
self.error_count = 0
self.sample_count = 0
# Mock data generation
self.mock_data_generator = None
if mock_mode:
self._initialize_mock_generator()
def _initialize_mock_generator(self):
"""Initialize mock data generator."""
self.mock_data_generator = {
'phase': 0,
'amplitude_base': 1.0,
'frequency': 0.1,
'noise_level': 0.1
}
async def connect(self):
"""Connect to the router."""
if self.mock_mode:
self.is_connected = True
self.logger.info(f"Mock connection established to router {self.router_id}")
return
try:
self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}")
# In a real implementation, this would establish SSH connection
# For now, we'll simulate the connection
await asyncio.sleep(0.1) # Simulate connection delay
self.is_connected = True
self.error_count = 0
self.logger.info(f"Connected to router {self.router_id}")
except Exception as e:
self.last_error = str(e)
self.error_count += 1
self.logger.error(f"Failed to connect to router {self.router_id}: {e}")
raise
async def disconnect(self):
"""Disconnect from the router."""
try:
if self.connection:
# Close SSH connection
self.connection = None
self.is_connected = False
self.logger.info(f"Disconnected from router {self.router_id}")
except Exception as e:
self.logger.error(f"Error disconnecting from router {self.router_id}: {e}")
async def reconnect(self):
"""Reconnect to the router."""
await self.disconnect()
await asyncio.sleep(1) # Wait before reconnecting
await self.connect()
async def get_csi_data(self) -> Optional[np.ndarray]:
"""Get CSI data from the router.
Returns:
CSI data as numpy array, or None if no data available
"""
if not self.is_connected:
raise RuntimeError(f"Router {self.router_id} is not connected")
try:
if self.mock_mode:
csi_data = self._generate_mock_csi_data()
else:
csi_data = await self._collect_real_csi_data()
if csi_data is not None:
self.last_data_time = datetime.now()
self.sample_count += 1
self.error_count = 0
return csi_data
except Exception as e:
self.last_error = str(e)
self.error_count += 1
self.logger.error(f"Error getting CSI data from router {self.router_id}: {e}")
return None
def _generate_mock_csi_data(self) -> np.ndarray:
"""Generate mock CSI data for testing."""
# Simulate CSI data with realistic characteristics
num_subcarriers = 64
num_antennas = 4
num_samples = 100
# Update mock generator state
self.mock_data_generator['phase'] += self.mock_data_generator['frequency']
# Generate amplitude and phase data
time_axis = np.linspace(0, 1, num_samples)
# Create realistic CSI patterns
csi_data = np.zeros((num_antennas, num_subcarriers, num_samples), dtype=complex)
for antenna in range(num_antennas):
for subcarrier in range(num_subcarriers):
# Base signal with some variation per antenna/subcarrier
amplitude = (
self.mock_data_generator['amplitude_base'] *
(1 + 0.2 * np.sin(2 * np.pi * subcarrier / num_subcarriers)) *
(1 + 0.1 * antenna)
)
# Phase with spatial and frequency variation
phase_offset = (
self.mock_data_generator['phase'] +
2 * np.pi * subcarrier / num_subcarriers +
np.pi * antenna / num_antennas
)
# Add some movement simulation
movement_freq = 0.5 # Hz
movement_amplitude = 0.3
movement = movement_amplitude * np.sin(2 * np.pi * movement_freq * time_axis)
# Generate complex signal
signal_amplitude = amplitude * (1 + movement)
signal_phase = phase_offset + movement * 0.5
# Add noise
noise_real = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
noise_imag = np.random.normal(0, self.mock_data_generator['noise_level'], num_samples)
noise = noise_real + 1j * noise_imag
# Create complex signal
signal = signal_amplitude * np.exp(1j * signal_phase) + noise
csi_data[antenna, subcarrier, :] = signal
return csi_data
async def _collect_real_csi_data(self) -> Optional[np.ndarray]:
"""Collect real CSI data from router (placeholder implementation)."""
# This would implement the actual CSI data collection
# For now, return None to indicate no real implementation
self.logger.warning("Real CSI data collection not implemented")
return None
async def check_health(self) -> bool:
"""Check if the router connection is healthy.
Returns:
True if healthy, False otherwise
"""
if not self.is_connected:
return False
try:
# In mock mode, always healthy
if self.mock_mode:
return True
# For real connections, we could ping the router or check SSH connection
# For now, consider healthy if error count is low
return self.error_count < 5
except Exception as e:
self.logger.error(f"Error checking health of router {self.router_id}: {e}")
return False
async def get_status(self) -> Dict[str, Any]:
"""Get router status information.
Returns:
Dictionary containing router status
"""
return {
"router_id": self.router_id,
"connected": self.is_connected,
"mock_mode": self.mock_mode,
"last_data_time": self.last_data_time.isoformat() if self.last_data_time else None,
"error_count": self.error_count,
"sample_count": self.sample_count,
"last_error": self.last_error,
"configuration": {
"host": self.host,
"port": self.port,
"username": self.username,
"interface": self.interface
}
}
async def get_router_info(self) -> Dict[str, Any]:
"""Get router hardware information.
Returns:
Dictionary containing router information
"""
if self.mock_mode:
return {
"model": "Mock Router",
"firmware": "1.0.0-mock",
"wifi_standard": "802.11ac",
"antennas": 4,
"supported_bands": ["2.4GHz", "5GHz"],
"csi_capabilities": {
"max_subcarriers": 64,
"max_antennas": 4,
"sampling_rate": 1000
}
}
# For real routers, this would query the actual hardware
return {
"model": "Unknown",
"firmware": "Unknown",
"wifi_standard": "Unknown",
"antennas": 1,
"supported_bands": ["Unknown"],
"csi_capabilities": {
"max_subcarriers": 64,
"max_antennas": 1,
"sampling_rate": 100
}
}
async def configure_csi_collection(self, config: Dict[str, Any]) -> bool:
"""Configure CSI data collection parameters.
Args:
config: Configuration dictionary
Returns:
True if configuration successful, False otherwise
"""
try:
if self.mock_mode:
# Update mock generator parameters
if 'sampling_rate' in config:
self.mock_data_generator['frequency'] = config['sampling_rate'] / 1000.0
if 'noise_level' in config:
self.mock_data_generator['noise_level'] = config['noise_level']
self.logger.info(f"Mock CSI collection configured for router {self.router_id}")
return True
# For real routers, this would send configuration commands
self.logger.warning("Real CSI configuration not implemented")
return False
except Exception as e:
self.logger.error(f"Error configuring CSI collection for router {self.router_id}: {e}")
return False
def get_metrics(self) -> Dict[str, Any]:
"""Get router interface metrics.
Returns:
Dictionary containing metrics
"""
uptime = 0
if self.last_data_time:
uptime = (datetime.now() - self.last_data_time).total_seconds()
success_rate = 0
if self.sample_count > 0:
success_rate = (self.sample_count - self.error_count) / self.sample_count
return {
"router_id": self.router_id,
"sample_count": self.sample_count,
"error_count": self.error_count,
"success_rate": success_rate,
"uptime_seconds": uptime,
"is_connected": self.is_connected,
"mock_mode": self.mock_mode
}
def reset_stats(self):
"""Reset statistics counters."""
self.error_count = 0
self.sample_count = 0
self.last_error = None
self.logger.info(f"Statistics reset for router {self.router_id}")

View File

@@ -0,0 +1,640 @@
"""
Database connection management for WiFi-DensePose API
"""
import asyncio
import logging
from typing import Optional, Dict, Any, AsyncGenerator
from contextlib import asynccontextmanager
from datetime import datetime
from sqlalchemy import create_engine, event, pool, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError
import redis.asyncio as redis
from redis.exceptions import ConnectionError as RedisConnectionError
from src.config.settings import Settings
from src.logger import get_logger
logger = get_logger(__name__)
class DatabaseConnectionError(Exception):
"""Database connection error."""
pass
class DatabaseManager:
"""Database connection manager."""
def __init__(self, settings: Settings):
self.settings = settings
self._async_engine = None
self._sync_engine = None
self._async_session_factory = None
self._sync_session_factory = None
self._redis_client = None
self._initialized = False
self._connection_pool_size = settings.db_pool_size
self._max_overflow = settings.db_max_overflow
self._pool_timeout = settings.db_pool_timeout
self._pool_recycle = settings.db_pool_recycle
async def initialize(self):
"""Initialize database connections."""
if self._initialized:
return
logger.info("Initializing database connections")
try:
# Initialize PostgreSQL connections
await self._initialize_postgresql()
# Initialize Redis connection
await self._initialize_redis()
self._initialized = True
logger.info("Database connections initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize database connections: {e}")
raise DatabaseConnectionError(f"Database initialization failed: {e}")
async def _initialize_postgresql(self):
"""Initialize PostgreSQL connections with SQLite failsafe."""
postgresql_failed = False
try:
# Try PostgreSQL first
await self._initialize_postgresql_primary()
logger.info("PostgreSQL connections initialized")
return
except Exception as e:
postgresql_failed = True
logger.error(f"PostgreSQL initialization failed: {e}")
if not self.settings.enable_database_failsafe:
raise DatabaseConnectionError(f"PostgreSQL connection failed and failsafe disabled: {e}")
logger.warning("Falling back to SQLite database")
# Fallback to SQLite if PostgreSQL failed and failsafe is enabled
if postgresql_failed and self.settings.enable_database_failsafe:
await self._initialize_sqlite_fallback()
logger.info("SQLite fallback database initialized")
async def _initialize_postgresql_primary(self):
"""Initialize primary PostgreSQL connections."""
# Build database URL
if self.settings.database_url and "postgresql" in self.settings.database_url:
db_url = self.settings.database_url
async_db_url = self.settings.database_url.replace("postgresql://", "postgresql+asyncpg://")
elif self.settings.db_host and self.settings.db_name and self.settings.db_user:
db_url = (
f"postgresql://{self.settings.db_user}:{self.settings.db_password}"
f"@{self.settings.db_host}:{self.settings.db_port}/{self.settings.db_name}"
)
async_db_url = (
f"postgresql+asyncpg://{self.settings.db_user}:{self.settings.db_password}"
f"@{self.settings.db_host}:{self.settings.db_port}/{self.settings.db_name}"
)
else:
raise ValueError("PostgreSQL connection parameters not configured")
# Create async engine (don't specify poolclass for async engines)
self._async_engine = create_async_engine(
async_db_url,
pool_size=self._connection_pool_size,
max_overflow=self._max_overflow,
pool_timeout=self._pool_timeout,
pool_recycle=self._pool_recycle,
pool_pre_ping=True,
echo=self.settings.db_echo,
future=True,
)
# Create sync engine for migrations and admin tasks
self._sync_engine = create_engine(
db_url,
poolclass=QueuePool,
pool_size=max(2, self._connection_pool_size // 2),
max_overflow=self._max_overflow // 2,
pool_timeout=self._pool_timeout,
pool_recycle=self._pool_recycle,
pool_pre_ping=True,
echo=self.settings.db_echo,
future=True,
)
# Create session factories
self._async_session_factory = async_sessionmaker(
self._async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
self._sync_session_factory = sessionmaker(
self._sync_engine,
expire_on_commit=False,
)
# Add connection event listeners
self._setup_connection_events()
# Test connections
await self._test_postgresql_connection()
async def _initialize_sqlite_fallback(self):
"""Initialize SQLite fallback database."""
import os
# Ensure directory exists
sqlite_path = self.settings.sqlite_fallback_path
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
# Build SQLite URLs
db_url = f"sqlite:///{sqlite_path}"
async_db_url = f"sqlite+aiosqlite:///{sqlite_path}"
# Create async engine for SQLite
self._async_engine = create_async_engine(
async_db_url,
echo=self.settings.db_echo,
future=True,
)
# Create sync engine for SQLite
self._sync_engine = create_engine(
db_url,
poolclass=NullPool, # SQLite doesn't need connection pooling
echo=self.settings.db_echo,
future=True,
)
# Create session factories
self._async_session_factory = async_sessionmaker(
self._async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
self._sync_session_factory = sessionmaker(
self._sync_engine,
expire_on_commit=False,
)
# Add connection event listeners
self._setup_connection_events()
# Test SQLite connection
await self._test_sqlite_connection()
async def _test_sqlite_connection(self):
"""Test SQLite connection."""
try:
async with self._async_engine.begin() as conn:
result = await conn.execute(text("SELECT 1"))
result.fetchone() # Don't await this - fetchone() is not async
logger.debug("SQLite connection test successful")
except Exception as e:
logger.error(f"SQLite connection test failed: {e}")
raise DatabaseConnectionError(f"SQLite connection test failed: {e}")
async def _initialize_redis(self):
"""Initialize Redis connection with failsafe."""
if not self.settings.redis_enabled:
logger.info("Redis disabled, skipping initialization")
return
try:
# Build Redis URL
if self.settings.redis_url:
redis_url = self.settings.redis_url
else:
redis_url = (
f"redis://{self.settings.redis_host}:{self.settings.redis_port}"
f"/{self.settings.redis_db}"
)
# Create Redis client
self._redis_client = redis.from_url(
redis_url,
password=self.settings.redis_password,
encoding="utf-8",
decode_responses=True,
max_connections=self.settings.redis_max_connections,
retry_on_timeout=True,
socket_timeout=self.settings.redis_socket_timeout,
socket_connect_timeout=self.settings.redis_connect_timeout,
)
# Test Redis connection
await self._test_redis_connection()
logger.info("Redis connection initialized")
except Exception as e:
logger.error(f"Failed to initialize Redis: {e}")
if self.settings.redis_required:
raise DatabaseConnectionError(f"Redis connection failed and is required: {e}")
elif self.settings.enable_redis_failsafe:
logger.warning("Redis initialization failed, continuing without Redis (failsafe enabled)")
self._redis_client = None
else:
logger.warning("Redis initialization failed but not required, continuing without Redis")
self._redis_client = None
def _setup_connection_events(self):
"""Setup database connection event listeners."""
@event.listens_for(self._sync_engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
"""Set database-specific settings on connection."""
if "sqlite" in str(self._sync_engine.url):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
@event.listens_for(self._sync_engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
"""Log connection checkout."""
logger.debug("Database connection checked out")
@event.listens_for(self._sync_engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
"""Log connection checkin."""
logger.debug("Database connection checked in")
@event.listens_for(self._sync_engine, "invalidate")
def receive_invalidate(dbapi_connection, connection_record, exception):
"""Handle connection invalidation."""
logger.warning(f"Database connection invalidated: {exception}")
async def _test_postgresql_connection(self):
"""Test PostgreSQL connection."""
try:
async with self._async_engine.begin() as conn:
result = await conn.execute(text("SELECT 1"))
result.fetchone() # Don't await this - fetchone() is not async
logger.debug("PostgreSQL connection test successful")
except Exception as e:
logger.error(f"PostgreSQL connection test failed: {e}")
raise DatabaseConnectionError(f"PostgreSQL connection test failed: {e}")
async def _test_redis_connection(self):
"""Test Redis connection."""
if not self._redis_client:
return
try:
await self._redis_client.ping()
logger.debug("Redis connection test successful")
except Exception as e:
logger.error(f"Redis connection test failed: {e}")
if self.settings.redis_required:
raise DatabaseConnectionError(f"Redis connection test failed: {e}")
@asynccontextmanager
async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Get async database session."""
if not self._initialized:
await self.initialize()
if not self._async_session_factory:
raise DatabaseConnectionError("Async session factory not initialized")
session = self._async_session_factory()
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Database session error: {e}")
raise
finally:
await session.close()
@asynccontextmanager
async def get_sync_session(self) -> Session:
"""Get sync database session."""
if not self._initialized:
await self.initialize()
if not self._sync_session_factory:
raise DatabaseConnectionError("Sync session factory not initialized")
session = self._sync_session_factory()
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logger.error(f"Database session error: {e}")
raise
finally:
session.close()
async def get_redis_client(self) -> Optional[redis.Redis]:
"""Get Redis client."""
if not self._initialized:
await self.initialize()
return self._redis_client
async def health_check(self) -> Dict[str, Any]:
"""Perform database health check."""
health_status = {
"database": {"status": "unknown", "details": {}},
"redis": {"status": "unknown", "details": {}},
"overall": "unknown"
}
# Check Database (PostgreSQL or SQLite)
try:
start_time = datetime.utcnow()
async with self.get_async_session() as session:
result = await session.execute(text("SELECT 1"))
result.fetchone() # Don't await this - fetchone() is not async
response_time = (datetime.utcnow() - start_time).total_seconds()
# Determine database type and status
is_sqlite = self.is_using_sqlite_fallback()
db_type = "sqlite_fallback" if is_sqlite else "postgresql"
details = {
"type": db_type,
"response_time_ms": round(response_time * 1000, 2),
}
# Add pool info for PostgreSQL
if not is_sqlite and hasattr(self._async_engine, 'pool'):
details.update({
"pool_size": self._async_engine.pool.size(),
"checked_out": self._async_engine.pool.checkedout(),
"overflow": self._async_engine.pool.overflow(),
})
# Add failsafe info
if is_sqlite:
details["failsafe_active"] = True
details["fallback_path"] = self.settings.sqlite_fallback_path
health_status["database"] = {
"status": "healthy",
"details": details
}
except Exception as e:
health_status["database"] = {
"status": "unhealthy",
"details": {"error": str(e)}
}
# Check Redis
if self._redis_client:
try:
start_time = datetime.utcnow()
await self._redis_client.ping()
response_time = (datetime.utcnow() - start_time).total_seconds()
info = await self._redis_client.info()
health_status["redis"] = {
"status": "healthy",
"details": {
"response_time_ms": round(response_time * 1000, 2),
"connected_clients": info.get("connected_clients", 0),
"used_memory": info.get("used_memory_human", "unknown"),
"uptime": info.get("uptime_in_seconds", 0),
}
}
except Exception as e:
health_status["redis"] = {
"status": "unhealthy",
"details": {"error": str(e)}
}
else:
health_status["redis"] = {
"status": "disabled",
"details": {"message": "Redis not enabled"}
}
# Determine overall status
database_healthy = health_status["database"]["status"] == "healthy"
redis_healthy = (
health_status["redis"]["status"] in ["healthy", "disabled"] or
not self.settings.redis_required
)
# Check if using failsafe modes
using_sqlite_fallback = self.is_using_sqlite_fallback()
redis_unavailable = not self.is_redis_available() and self.settings.redis_enabled
if database_healthy and redis_healthy:
if using_sqlite_fallback or redis_unavailable:
health_status["overall"] = "degraded" # Working but using failsafe
else:
health_status["overall"] = "healthy"
elif database_healthy:
health_status["overall"] = "degraded"
else:
health_status["overall"] = "unhealthy"
return health_status
async def get_connection_stats(self) -> Dict[str, Any]:
"""Get database connection statistics."""
stats = {
"postgresql": {},
"redis": {}
}
# PostgreSQL stats
if self._async_engine:
pool = self._async_engine.pool
stats["postgresql"] = {
"pool_size": pool.size(),
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"checked_in": pool.checkedin(),
"total_connections": pool.size() + pool.overflow(),
"available_connections": pool.size() - pool.checkedout(),
}
# Redis stats
if self._redis_client:
try:
info = await self._redis_client.info()
stats["redis"] = {
"connected_clients": info.get("connected_clients", 0),
"blocked_clients": info.get("blocked_clients", 0),
"total_connections_received": info.get("total_connections_received", 0),
"rejected_connections": info.get("rejected_connections", 0),
}
except Exception as e:
stats["redis"] = {"error": str(e)}
return stats
async def close_connections(self):
"""Close all database connections."""
logger.info("Closing database connections")
# Close PostgreSQL connections
if self._async_engine:
await self._async_engine.dispose()
logger.debug("Async PostgreSQL engine disposed")
if self._sync_engine:
self._sync_engine.dispose()
logger.debug("Sync PostgreSQL engine disposed")
# Close Redis connection
if self._redis_client:
await self._redis_client.close()
logger.debug("Redis connection closed")
self._initialized = False
logger.info("Database connections closed")
def is_using_sqlite_fallback(self) -> bool:
"""Check if currently using SQLite fallback database."""
if not self._async_engine:
return False
return "sqlite" in str(self._async_engine.url)
def is_redis_available(self) -> bool:
"""Check if Redis is available."""
return self._redis_client is not None
async def test_connection(self) -> bool:
"""Test database connection for CLI validation."""
try:
if not self._initialized:
await self.initialize()
# Test database connection (PostgreSQL or SQLite)
async with self.get_async_session() as session:
result = await session.execute(text("SELECT 1"))
result.fetchone() # Don't await this - fetchone() is not async
# Test Redis connection if enabled
if self._redis_client:
await self._redis_client.ping()
return True
except Exception as e:
logger.error(f"Database connection test failed: {e}")
return False
async def reset_connections(self):
"""Reset all database connections."""
logger.info("Resetting database connections")
await self.close_connections()
await self.initialize()
logger.info("Database connections reset")
# Global database manager instance
_db_manager: Optional[DatabaseManager] = None
def get_database_manager(settings: Settings) -> DatabaseManager:
"""Get database manager instance."""
global _db_manager
if _db_manager is None:
_db_manager = DatabaseManager(settings)
return _db_manager
async def get_async_session(settings: Settings) -> AsyncGenerator[AsyncSession, None]:
"""Dependency to get async database session."""
db_manager = get_database_manager(settings)
async with db_manager.get_async_session() as session:
yield session
async def get_redis_client(settings: Settings) -> Optional[redis.Redis]:
"""Dependency to get Redis client."""
db_manager = get_database_manager(settings)
return await db_manager.get_redis_client()
class DatabaseHealthCheck:
"""Database health check utility."""
def __init__(self, db_manager: DatabaseManager):
self.db_manager = db_manager
async def check_postgresql(self) -> Dict[str, Any]:
"""Check PostgreSQL health."""
try:
start_time = datetime.utcnow()
async with self.db_manager.get_async_session() as session:
result = await session.execute(text("SELECT version()"))
version = result.fetchone()[0] # Don't await this - fetchone() is not async
response_time = (datetime.utcnow() - start_time).total_seconds()
return {
"status": "healthy",
"version": version,
"response_time_ms": round(response_time * 1000, 2),
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
}
async def check_redis(self) -> Dict[str, Any]:
"""Check Redis health."""
redis_client = await self.db_manager.get_redis_client()
if not redis_client:
return {
"status": "disabled",
"message": "Redis not configured"
}
try:
start_time = datetime.utcnow()
pong = await redis_client.ping()
response_time = (datetime.utcnow() - start_time).total_seconds()
info = await redis_client.info("server")
return {
"status": "healthy",
"ping": pong,
"version": info.get("redis_version", "unknown"),
"response_time_ms": round(response_time * 1000, 2),
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
}
async def full_health_check(self) -> Dict[str, Any]:
"""Perform full database health check."""
postgresql_health = await self.check_postgresql()
redis_health = await self.check_redis()
overall_status = "healthy"
if postgresql_health["status"] != "healthy":
overall_status = "unhealthy"
elif redis_health["status"] == "unhealthy":
overall_status = "degraded"
return {
"overall_status": overall_status,
"postgresql": postgresql_health,
"redis": redis_health,
"timestamp": datetime.utcnow().isoformat(),
}

View File

@@ -0,0 +1,370 @@
"""
Initial database migration for WiFi-DensePose API
Revision ID: 001_initial
Revises:
Create Date: 2025-01-07 07:58:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers
revision = '001_initial'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
"""Create initial database schema."""
# Create devices table
op.create_table(
'devices',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('device_type', sa.String(length=50), nullable=False),
sa.Column('mac_address', sa.String(length=17), nullable=False),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('status', sa.String(length=20), nullable=False),
sa.Column('firmware_version', sa.String(length=50), nullable=True),
sa.Column('hardware_version', sa.String(length=50), nullable=True),
sa.Column('location_name', sa.String(length=255), nullable=True),
sa.Column('room_id', sa.String(length=100), nullable=True),
sa.Column('coordinates_x', sa.Float(), nullable=True),
sa.Column('coordinates_y', sa.Float(), nullable=True),
sa.Column('coordinates_z', sa.Float(), nullable=True),
sa.Column('config', sa.JSON(), nullable=True),
sa.Column('capabilities', postgresql.ARRAY(sa.String()), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
sa.CheckConstraint("status IN ('active', 'inactive', 'maintenance', 'error')", name='check_device_status'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('mac_address')
)
# Create indexes for devices table
op.create_index('idx_device_mac_address', 'devices', ['mac_address'])
op.create_index('idx_device_status', 'devices', ['status'])
op.create_index('idx_device_type', 'devices', ['device_type'])
# Create sessions table
op.create_table(
'sessions',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('ended_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('duration_seconds', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=20), nullable=False),
sa.Column('config', sa.JSON(), nullable=True),
sa.Column('device_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('total_frames', sa.Integer(), nullable=False),
sa.Column('processed_frames', sa.Integer(), nullable=False),
sa.Column('error_count', sa.Integer(), nullable=False),
sa.CheckConstraint("status IN ('active', 'completed', 'failed', 'cancelled')", name='check_session_status'),
sa.CheckConstraint('total_frames >= 0', name='check_total_frames_positive'),
sa.CheckConstraint('processed_frames >= 0', name='check_processed_frames_positive'),
sa.CheckConstraint('error_count >= 0', name='check_error_count_positive'),
sa.ForeignKeyConstraint(['device_id'], ['devices.id'], ),
sa.PrimaryKeyConstraint('id')
)
# Create indexes for sessions table
op.create_index('idx_session_device_id', 'sessions', ['device_id'])
op.create_index('idx_session_status', 'sessions', ['status'])
op.create_index('idx_session_started_at', 'sessions', ['started_at'])
# Create csi_data table
op.create_table(
'csi_data',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('sequence_number', sa.Integer(), nullable=False),
sa.Column('timestamp_ns', sa.BigInteger(), nullable=False),
sa.Column('device_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('session_id', postgresql.UUID(as_uuid=True), nullable=True),
sa.Column('amplitude', postgresql.ARRAY(sa.Float()), nullable=False),
sa.Column('phase', postgresql.ARRAY(sa.Float()), nullable=False),
sa.Column('frequency', sa.Float(), nullable=False),
sa.Column('bandwidth', sa.Float(), nullable=False),
sa.Column('rssi', sa.Float(), nullable=True),
sa.Column('snr', sa.Float(), nullable=True),
sa.Column('noise_floor', sa.Float(), nullable=True),
sa.Column('tx_antenna', sa.Integer(), nullable=True),
sa.Column('rx_antenna', sa.Integer(), nullable=True),
sa.Column('num_subcarriers', sa.Integer(), nullable=False),
sa.Column('processing_status', sa.String(length=20), nullable=False),
sa.Column('processed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('quality_score', sa.Float(), nullable=True),
sa.Column('is_valid', sa.Boolean(), nullable=False),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.CheckConstraint('frequency > 0', name='check_frequency_positive'),
sa.CheckConstraint('bandwidth > 0', name='check_bandwidth_positive'),
sa.CheckConstraint('num_subcarriers > 0', name='check_subcarriers_positive'),
sa.CheckConstraint("processing_status IN ('pending', 'processing', 'completed', 'failed')", name='check_processing_status'),
sa.ForeignKeyConstraint(['device_id'], ['devices.id'], ),
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('device_id', 'sequence_number', 'timestamp_ns', name='uq_csi_device_seq_time')
)
# Create indexes for csi_data table
op.create_index('idx_csi_device_id', 'csi_data', ['device_id'])
op.create_index('idx_csi_session_id', 'csi_data', ['session_id'])
op.create_index('idx_csi_timestamp', 'csi_data', ['timestamp_ns'])
op.create_index('idx_csi_sequence', 'csi_data', ['sequence_number'])
op.create_index('idx_csi_processing_status', 'csi_data', ['processing_status'])
# Create pose_detections table
op.create_table(
'pose_detections',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('frame_number', sa.Integer(), nullable=False),
sa.Column('timestamp_ns', sa.BigInteger(), nullable=False),
sa.Column('session_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('person_count', sa.Integer(), nullable=False),
sa.Column('keypoints', sa.JSON(), nullable=True),
sa.Column('bounding_boxes', sa.JSON(), nullable=True),
sa.Column('detection_confidence', sa.Float(), nullable=True),
sa.Column('pose_confidence', sa.Float(), nullable=True),
sa.Column('overall_confidence', sa.Float(), nullable=True),
sa.Column('processing_time_ms', sa.Float(), nullable=True),
sa.Column('model_version', sa.String(length=50), nullable=True),
sa.Column('algorithm', sa.String(length=100), nullable=True),
sa.Column('image_quality', sa.Float(), nullable=True),
sa.Column('pose_quality', sa.Float(), nullable=True),
sa.Column('is_valid', sa.Boolean(), nullable=False),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.CheckConstraint('person_count >= 0', name='check_person_count_positive'),
sa.CheckConstraint('detection_confidence >= 0 AND detection_confidence <= 1', name='check_detection_confidence_range'),
sa.CheckConstraint('pose_confidence >= 0 AND pose_confidence <= 1', name='check_pose_confidence_range'),
sa.CheckConstraint('overall_confidence >= 0 AND overall_confidence <= 1', name='check_overall_confidence_range'),
sa.ForeignKeyConstraint(['session_id'], ['sessions.id'], ),
sa.PrimaryKeyConstraint('id')
)
# Create indexes for pose_detections table
op.create_index('idx_pose_session_id', 'pose_detections', ['session_id'])
op.create_index('idx_pose_timestamp', 'pose_detections', ['timestamp_ns'])
op.create_index('idx_pose_frame', 'pose_detections', ['frame_number'])
op.create_index('idx_pose_person_count', 'pose_detections', ['person_count'])
# Create system_metrics table
op.create_table(
'system_metrics',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('metric_name', sa.String(length=255), nullable=False),
sa.Column('metric_type', sa.String(length=50), nullable=False),
sa.Column('value', sa.Float(), nullable=False),
sa.Column('unit', sa.String(length=50), nullable=True),
sa.Column('labels', sa.JSON(), nullable=True),
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
sa.Column('source', sa.String(length=255), nullable=True),
sa.Column('component', sa.String(length=100), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# Create indexes for system_metrics table
op.create_index('idx_metric_name', 'system_metrics', ['metric_name'])
op.create_index('idx_metric_type', 'system_metrics', ['metric_type'])
op.create_index('idx_metric_created_at', 'system_metrics', ['created_at'])
op.create_index('idx_metric_source', 'system_metrics', ['source'])
op.create_index('idx_metric_component', 'system_metrics', ['component'])
# Create audit_logs table
op.create_table(
'audit_logs',
sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('event_type', sa.String(length=100), nullable=False),
sa.Column('event_name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('user_id', sa.String(length=255), nullable=True),
sa.Column('session_id', sa.String(length=255), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.Text(), nullable=True),
sa.Column('resource_type', sa.String(length=100), nullable=True),
sa.Column('resource_id', sa.String(length=255), nullable=True),
sa.Column('before_state', sa.JSON(), nullable=True),
sa.Column('after_state', sa.JSON(), nullable=True),
sa.Column('changes', sa.JSON(), nullable=True),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('tags', postgresql.ARRAY(sa.String()), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# Create indexes for audit_logs table
op.create_index('idx_audit_event_type', 'audit_logs', ['event_type'])
op.create_index('idx_audit_user_id', 'audit_logs', ['user_id'])
op.create_index('idx_audit_resource', 'audit_logs', ['resource_type', 'resource_id'])
op.create_index('idx_audit_created_at', 'audit_logs', ['created_at'])
op.create_index('idx_audit_success', 'audit_logs', ['success'])
# Create triggers for updated_at columns
op.execute("""
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = now();
RETURN NEW;
END;
$$ language 'plpgsql';
""")
# Add triggers to all tables with updated_at column
tables_with_updated_at = [
'devices', 'sessions', 'csi_data', 'pose_detections',
'system_metrics', 'audit_logs'
]
for table in tables_with_updated_at:
op.execute(f"""
CREATE TRIGGER update_{table}_updated_at
BEFORE UPDATE ON {table}
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
""")
# Insert initial data
_insert_initial_data()
def downgrade():
"""Drop all tables and functions."""
# Drop triggers first
tables_with_updated_at = [
'devices', 'sessions', 'csi_data', 'pose_detections',
'system_metrics', 'audit_logs'
]
for table in tables_with_updated_at:
op.execute(f"DROP TRIGGER IF EXISTS update_{table}_updated_at ON {table};")
# Drop function
op.execute("DROP FUNCTION IF EXISTS update_updated_at_column();")
# Drop tables in reverse order (respecting foreign key constraints)
op.drop_table('audit_logs')
op.drop_table('system_metrics')
op.drop_table('pose_detections')
op.drop_table('csi_data')
op.drop_table('sessions')
op.drop_table('devices')
def _insert_initial_data():
"""Insert initial data into tables."""
# Insert sample device
op.execute("""
INSERT INTO devices (
id, name, device_type, mac_address, ip_address, status,
firmware_version, hardware_version, location_name, room_id,
coordinates_x, coordinates_y, coordinates_z,
config, capabilities, description, tags
) VALUES (
gen_random_uuid(),
'Demo Router',
'router',
'00:11:22:33:44:55',
'192.168.1.1',
'active',
'1.0.0',
'v1.0',
'Living Room',
'room_001',
0.0,
0.0,
2.5,
'{"channel": 6, "power": 20, "bandwidth": 80}',
ARRAY['wifi6', 'csi', 'beamforming'],
'Demo WiFi router for testing',
ARRAY['demo', 'testing']
);
""")
# Insert sample session
op.execute("""
INSERT INTO sessions (
id, name, description, started_at, status, config,
device_id, tags, metadata, total_frames, processed_frames, error_count
) VALUES (
gen_random_uuid(),
'Demo Session',
'Initial demo session for testing',
now(),
'active',
'{"duration": 3600, "sampling_rate": 100}',
(SELECT id FROM devices WHERE name = 'Demo Router' LIMIT 1),
ARRAY['demo', 'initial'],
'{"purpose": "testing", "environment": "lab"}',
0,
0,
0
);
""")
# Insert initial system metrics
metrics_data = [
('system_startup', 'counter', 1.0, 'count', 'system', 'application'),
('database_connections', 'gauge', 0.0, 'count', 'database', 'postgresql'),
('api_requests_total', 'counter', 0.0, 'count', 'api', 'http'),
('memory_usage', 'gauge', 0.0, 'bytes', 'system', 'memory'),
('cpu_usage', 'gauge', 0.0, 'percent', 'system', 'cpu'),
]
for metric_name, metric_type, value, unit, source, component in metrics_data:
op.execute(f"""
INSERT INTO system_metrics (
id, metric_name, metric_type, value, unit, source, component,
description, metadata
) VALUES (
gen_random_uuid(),
'{metric_name}',
'{metric_type}',
{value},
'{unit}',
'{source}',
'{component}',
'Initial {metric_name} metric',
'{{"initial": true, "version": "1.0.0"}}'
);
""")
# Insert initial audit log
op.execute("""
INSERT INTO audit_logs (
id, event_type, event_name, description, user_id, success,
resource_type, metadata
) VALUES (
gen_random_uuid(),
'system',
'database_migration',
'Initial database schema created',
'system',
true,
'database',
'{"migration": "001_initial", "version": "1.0.0"}'
);
""")

View File

@@ -0,0 +1,109 @@
"""Alembic environment configuration for WiFi-DensePose API."""
import asyncio
import os
import sys
from logging.config import fileConfig
from pathlib import Path
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
# Import the models and settings
from src.database.models import Base
from src.config.settings import get_settings
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_database_url():
"""Get the database URL from settings."""
try:
settings = get_settings()
return settings.get_database_url()
except Exception:
# Fallback to SQLite if settings can't be loaded
return "sqlite:///./data/wifi_densepose_fallback.db"
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = get_database_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
"""Run migrations with a database connection."""
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""Run migrations in async mode."""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = get_database_url()
connectable = async_engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade database schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade database schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,60 @@
"""
Database type compatibility helpers for WiFi-DensePose API
"""
from typing import Type, Any
from sqlalchemy import String, Text, JSON
from sqlalchemy.dialects.postgresql import ARRAY as PostgreSQL_ARRAY
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
class ArrayType(sqltypes.TypeDecorator):
"""Array type that works with both PostgreSQL and SQLite."""
impl = Text
cache_ok = True
def __init__(self, item_type: Type = String):
super().__init__()
self.item_type = item_type
def load_dialect_impl(self, dialect):
"""Load dialect-specific implementation."""
if dialect.name == 'postgresql':
return dialect.type_descriptor(PostgreSQL_ARRAY(self.item_type))
else:
# For SQLite and others, use JSON
return dialect.type_descriptor(JSON)
def process_bind_param(self, value, dialect):
"""Process value before saving to database."""
if value is None:
return value
if dialect.name == 'postgresql':
return value
else:
# For SQLite, convert to JSON
return value if isinstance(value, (list, type(None))) else list(value)
def process_result_value(self, value, dialect):
"""Process value after loading from database."""
if value is None:
return value
if dialect.name == 'postgresql':
return value
else:
# For SQLite, value is already a list from JSON
return value if isinstance(value, list) else []
def get_array_type(item_type: Type = String) -> Type:
"""Get appropriate array type based on database."""
return ArrayType(item_type)
# Convenience types
StringArray = ArrayType(String)
FloatArray = ArrayType(sqltypes.Float)

498
v1/src/database/models.py Normal file
View File

@@ -0,0 +1,498 @@
"""
SQLAlchemy models for WiFi-DensePose API
"""
import uuid
from datetime import datetime
from typing import Optional, Dict, Any, List
from enum import Enum
from sqlalchemy import (
Column, String, Integer, Float, Boolean, DateTime, Text, JSON,
ForeignKey, Index, UniqueConstraint, CheckConstraint
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, validates
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
# Import custom array type for compatibility
from src.database.model_types import StringArray, FloatArray
Base = declarative_base()
class TimestampMixin:
"""Mixin for timestamp fields."""
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
class UUIDMixin:
"""Mixin for UUID primary key."""
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False)
class DeviceStatus(str, Enum):
"""Device status enumeration."""
ACTIVE = "active"
INACTIVE = "inactive"
MAINTENANCE = "maintenance"
ERROR = "error"
class SessionStatus(str, Enum):
"""Session status enumeration."""
ACTIVE = "active"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class ProcessingStatus(str, Enum):
"""Processing status enumeration."""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class Device(Base, UUIDMixin, TimestampMixin):
"""Device model for WiFi routers and sensors."""
__tablename__ = "devices"
# Basic device information
name = Column(String(255), nullable=False)
device_type = Column(String(50), nullable=False) # router, sensor, etc.
mac_address = Column(String(17), unique=True, nullable=False)
ip_address = Column(String(45), nullable=True) # IPv4 or IPv6
# Device status and configuration
status = Column(String(20), default=DeviceStatus.INACTIVE, nullable=False)
firmware_version = Column(String(50), nullable=True)
hardware_version = Column(String(50), nullable=True)
# Location information
location_name = Column(String(255), nullable=True)
room_id = Column(String(100), nullable=True)
coordinates_x = Column(Float, nullable=True)
coordinates_y = Column(Float, nullable=True)
coordinates_z = Column(Float, nullable=True)
# Configuration
config = Column(JSON, nullable=True)
capabilities = Column(StringArray, nullable=True)
# Metadata
description = Column(Text, nullable=True)
tags = Column(StringArray, nullable=True)
# Relationships
sessions = relationship("Session", back_populates="device", cascade="all, delete-orphan")
csi_data = relationship("CSIData", back_populates="device", cascade="all, delete-orphan")
# Constraints and indexes
__table_args__ = (
Index("idx_device_mac_address", "mac_address"),
Index("idx_device_status", "status"),
Index("idx_device_type", "device_type"),
CheckConstraint("status IN ('active', 'inactive', 'maintenance', 'error')", name="check_device_status"),
)
@validates('mac_address')
def validate_mac_address(self, key, address):
"""Validate MAC address format."""
if address and len(address) == 17:
# Basic MAC address format validation
parts = address.split(':')
if len(parts) == 6 and all(len(part) == 2 for part in parts):
return address.lower()
raise ValueError("Invalid MAC address format")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"name": self.name,
"device_type": self.device_type,
"mac_address": self.mac_address,
"ip_address": self.ip_address,
"status": self.status,
"firmware_version": self.firmware_version,
"hardware_version": self.hardware_version,
"location_name": self.location_name,
"room_id": self.room_id,
"coordinates": {
"x": self.coordinates_x,
"y": self.coordinates_y,
"z": self.coordinates_z,
} if any([self.coordinates_x, self.coordinates_y, self.coordinates_z]) else None,
"config": self.config,
"capabilities": self.capabilities,
"description": self.description,
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class Session(Base, UUIDMixin, TimestampMixin):
"""Session model for tracking data collection sessions."""
__tablename__ = "sessions"
# Session identification
name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
# Session timing
started_at = Column(DateTime(timezone=True), nullable=True)
ended_at = Column(DateTime(timezone=True), nullable=True)
duration_seconds = Column(Integer, nullable=True)
# Session status and configuration
status = Column(String(20), default=SessionStatus.ACTIVE, nullable=False)
config = Column(JSON, nullable=True)
# Device relationship
device_id = Column(UUID(as_uuid=True), ForeignKey("devices.id"), nullable=False)
device = relationship("Device", back_populates="sessions")
# Data relationships
csi_data = relationship("CSIData", back_populates="session", cascade="all, delete-orphan")
pose_detections = relationship("PoseDetection", back_populates="session", cascade="all, delete-orphan")
# Metadata
tags = Column(StringArray, nullable=True)
meta_data = Column(JSON, nullable=True)
# Statistics
total_frames = Column(Integer, default=0, nullable=False)
processed_frames = Column(Integer, default=0, nullable=False)
error_count = Column(Integer, default=0, nullable=False)
# Constraints and indexes
__table_args__ = (
Index("idx_session_device_id", "device_id"),
Index("idx_session_status", "status"),
Index("idx_session_started_at", "started_at"),
CheckConstraint("status IN ('active', 'completed', 'failed', 'cancelled')", name="check_session_status"),
CheckConstraint("total_frames >= 0", name="check_total_frames_positive"),
CheckConstraint("processed_frames >= 0", name="check_processed_frames_positive"),
CheckConstraint("error_count >= 0", name="check_error_count_positive"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"name": self.name,
"description": self.description,
"started_at": self.started_at.isoformat() if self.started_at else None,
"ended_at": self.ended_at.isoformat() if self.ended_at else None,
"duration_seconds": self.duration_seconds,
"status": self.status,
"config": self.config,
"device_id": str(self.device_id),
"tags": self.tags,
"metadata": self.meta_data,
"total_frames": self.total_frames,
"processed_frames": self.processed_frames,
"error_count": self.error_count,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class CSIData(Base, UUIDMixin, TimestampMixin):
"""CSI (Channel State Information) data model."""
__tablename__ = "csi_data"
# Data identification
sequence_number = Column(Integer, nullable=False)
timestamp_ns = Column(Integer, nullable=False) # Nanosecond timestamp
# Device and session relationships
device_id = Column(UUID(as_uuid=True), ForeignKey("devices.id"), nullable=False)
session_id = Column(UUID(as_uuid=True), ForeignKey("sessions.id"), nullable=True)
device = relationship("Device", back_populates="csi_data")
session = relationship("Session", back_populates="csi_data")
# CSI data
amplitude = Column(FloatArray, nullable=False)
phase = Column(FloatArray, nullable=False)
frequency = Column(Float, nullable=False) # MHz
bandwidth = Column(Float, nullable=False) # MHz
# Signal characteristics
rssi = Column(Float, nullable=True) # dBm
snr = Column(Float, nullable=True) # dB
noise_floor = Column(Float, nullable=True) # dBm
# Antenna information
tx_antenna = Column(Integer, nullable=True)
rx_antenna = Column(Integer, nullable=True)
num_subcarriers = Column(Integer, nullable=False)
# Processing status
processing_status = Column(String(20), default=ProcessingStatus.PENDING, nullable=False)
processed_at = Column(DateTime(timezone=True), nullable=True)
# Quality metrics
quality_score = Column(Float, nullable=True)
is_valid = Column(Boolean, default=True, nullable=False)
# Metadata
meta_data = Column(JSON, nullable=True)
# Constraints and indexes
__table_args__ = (
Index("idx_csi_device_id", "device_id"),
Index("idx_csi_session_id", "session_id"),
Index("idx_csi_timestamp", "timestamp_ns"),
Index("idx_csi_sequence", "sequence_number"),
Index("idx_csi_processing_status", "processing_status"),
UniqueConstraint("device_id", "sequence_number", "timestamp_ns", name="uq_csi_device_seq_time"),
CheckConstraint("frequency > 0", name="check_frequency_positive"),
CheckConstraint("bandwidth > 0", name="check_bandwidth_positive"),
CheckConstraint("num_subcarriers > 0", name="check_subcarriers_positive"),
CheckConstraint("processing_status IN ('pending', 'processing', 'completed', 'failed')", name="check_processing_status"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"sequence_number": self.sequence_number,
"timestamp_ns": self.timestamp_ns,
"device_id": str(self.device_id),
"session_id": str(self.session_id) if self.session_id else None,
"amplitude": self.amplitude,
"phase": self.phase,
"frequency": self.frequency,
"bandwidth": self.bandwidth,
"rssi": self.rssi,
"snr": self.snr,
"noise_floor": self.noise_floor,
"tx_antenna": self.tx_antenna,
"rx_antenna": self.rx_antenna,
"num_subcarriers": self.num_subcarriers,
"processing_status": self.processing_status,
"processed_at": self.processed_at.isoformat() if self.processed_at else None,
"quality_score": self.quality_score,
"is_valid": self.is_valid,
"metadata": self.meta_data,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class PoseDetection(Base, UUIDMixin, TimestampMixin):
"""Pose detection results model."""
__tablename__ = "pose_detections"
# Detection identification
frame_number = Column(Integer, nullable=False)
timestamp_ns = Column(Integer, nullable=False)
# Session relationship
session_id = Column(UUID(as_uuid=True), ForeignKey("sessions.id"), nullable=False)
session = relationship("Session", back_populates="pose_detections")
# Detection results
person_count = Column(Integer, default=0, nullable=False)
keypoints = Column(JSON, nullable=True) # Array of person keypoints
bounding_boxes = Column(JSON, nullable=True) # Array of bounding boxes
# Confidence scores
detection_confidence = Column(Float, nullable=True)
pose_confidence = Column(Float, nullable=True)
overall_confidence = Column(Float, nullable=True)
# Processing information
processing_time_ms = Column(Float, nullable=True)
model_version = Column(String(50), nullable=True)
algorithm = Column(String(100), nullable=True)
# Quality metrics
image_quality = Column(Float, nullable=True)
pose_quality = Column(Float, nullable=True)
is_valid = Column(Boolean, default=True, nullable=False)
# Metadata
meta_data = Column(JSON, nullable=True)
# Constraints and indexes
__table_args__ = (
Index("idx_pose_session_id", "session_id"),
Index("idx_pose_timestamp", "timestamp_ns"),
Index("idx_pose_frame", "frame_number"),
Index("idx_pose_person_count", "person_count"),
CheckConstraint("person_count >= 0", name="check_person_count_positive"),
CheckConstraint("detection_confidence >= 0 AND detection_confidence <= 1", name="check_detection_confidence_range"),
CheckConstraint("pose_confidence >= 0 AND pose_confidence <= 1", name="check_pose_confidence_range"),
CheckConstraint("overall_confidence >= 0 AND overall_confidence <= 1", name="check_overall_confidence_range"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"frame_number": self.frame_number,
"timestamp_ns": self.timestamp_ns,
"session_id": str(self.session_id),
"person_count": self.person_count,
"keypoints": self.keypoints,
"bounding_boxes": self.bounding_boxes,
"detection_confidence": self.detection_confidence,
"pose_confidence": self.pose_confidence,
"overall_confidence": self.overall_confidence,
"processing_time_ms": self.processing_time_ms,
"model_version": self.model_version,
"algorithm": self.algorithm,
"image_quality": self.image_quality,
"pose_quality": self.pose_quality,
"is_valid": self.is_valid,
"metadata": self.meta_data,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class SystemMetric(Base, UUIDMixin, TimestampMixin):
"""System metrics model for monitoring."""
__tablename__ = "system_metrics"
# Metric identification
metric_name = Column(String(255), nullable=False)
metric_type = Column(String(50), nullable=False) # counter, gauge, histogram
# Metric value
value = Column(Float, nullable=False)
unit = Column(String(50), nullable=True)
# Labels and tags
labels = Column(JSON, nullable=True)
tags = Column(StringArray, nullable=True)
# Source information
source = Column(String(255), nullable=True)
component = Column(String(100), nullable=True)
# Metadata
description = Column(Text, nullable=True)
meta_data = Column(JSON, nullable=True)
# Constraints and indexes
__table_args__ = (
Index("idx_metric_name", "metric_name"),
Index("idx_metric_type", "metric_type"),
Index("idx_metric_created_at", "created_at"),
Index("idx_metric_source", "source"),
Index("idx_metric_component", "component"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"metric_name": self.metric_name,
"metric_type": self.metric_type,
"value": self.value,
"unit": self.unit,
"labels": self.labels,
"tags": self.tags,
"source": self.source,
"component": self.component,
"description": self.description,
"metadata": self.meta_data,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
class AuditLog(Base, UUIDMixin, TimestampMixin):
"""Audit log model for tracking system events."""
__tablename__ = "audit_logs"
# Event information
event_type = Column(String(100), nullable=False)
event_name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
# User and session information
user_id = Column(String(255), nullable=True)
session_id = Column(String(255), nullable=True)
ip_address = Column(String(45), nullable=True)
user_agent = Column(Text, nullable=True)
# Resource information
resource_type = Column(String(100), nullable=True)
resource_id = Column(String(255), nullable=True)
# Event details
before_state = Column(JSON, nullable=True)
after_state = Column(JSON, nullable=True)
changes = Column(JSON, nullable=True)
# Result information
success = Column(Boolean, nullable=False)
error_message = Column(Text, nullable=True)
# Metadata
meta_data = Column(JSON, nullable=True)
tags = Column(StringArray, nullable=True)
# Constraints and indexes
__table_args__ = (
Index("idx_audit_event_type", "event_type"),
Index("idx_audit_user_id", "user_id"),
Index("idx_audit_resource", "resource_type", "resource_id"),
Index("idx_audit_created_at", "created_at"),
Index("idx_audit_success", "success"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"event_type": self.event_type,
"event_name": self.event_name,
"description": self.description,
"user_id": self.user_id,
"session_id": self.session_id,
"ip_address": self.ip_address,
"user_agent": self.user_agent,
"resource_type": self.resource_type,
"resource_id": self.resource_id,
"before_state": self.before_state,
"after_state": self.after_state,
"changes": self.changes,
"success": self.success,
"error_message": self.error_message,
"metadata": self.meta_data,
"tags": self.tags,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
# Model registry for easy access
MODEL_REGISTRY = {
"Device": Device,
"Session": Session,
"CSIData": CSIData,
"PoseDetection": PoseDetection,
"SystemMetric": SystemMetric,
"AuditLog": AuditLog,
}
def get_model_by_name(name: str):
"""Get model class by name."""
return MODEL_REGISTRY.get(name)
def get_all_models() -> List:
"""Get all model classes."""
return list(MODEL_REGISTRY.values())

View File

@@ -0,0 +1 @@
"""Hardware abstraction layer for WiFi-DensePose system."""

View File

@@ -0,0 +1,326 @@
"""CSI data extraction from WiFi hardware using Test-Driven Development approach."""
import asyncio
import numpy as np
from datetime import datetime, timezone
from typing import Dict, Any, Optional, Callable, Protocol
from dataclasses import dataclass
from abc import ABC, abstractmethod
import logging
class CSIParseError(Exception):
"""Exception raised for CSI parsing errors."""
pass
class CSIValidationError(Exception):
"""Exception raised for CSI validation errors."""
pass
@dataclass
class CSIData:
"""Data structure for CSI measurements."""
timestamp: datetime
amplitude: np.ndarray
phase: np.ndarray
frequency: float
bandwidth: float
num_subcarriers: int
num_antennas: int
snr: float
metadata: Dict[str, Any]
class CSIParser(Protocol):
"""Protocol for CSI data parsers."""
def parse(self, raw_data: bytes) -> CSIData:
"""Parse raw CSI data into structured format."""
...
class ESP32CSIParser:
"""Parser for ESP32 CSI data format."""
def parse(self, raw_data: bytes) -> CSIData:
"""Parse ESP32 CSI data format.
Args:
raw_data: Raw bytes from ESP32
Returns:
Parsed CSI data
Raises:
CSIParseError: If data format is invalid
"""
if not raw_data:
raise CSIParseError("Empty data received")
try:
data_str = raw_data.decode('utf-8')
if not data_str.startswith('CSI_DATA:'):
raise CSIParseError("Invalid ESP32 CSI data format")
# Parse ESP32 format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp],[phase]
parts = data_str[9:].split(',') # Remove 'CSI_DATA:' prefix
timestamp_ms = int(parts[0])
num_antennas = int(parts[1])
num_subcarriers = int(parts[2])
frequency_mhz = float(parts[3])
bandwidth_mhz = float(parts[4])
snr = float(parts[5])
# Convert to proper units
frequency = frequency_mhz * 1e6 # MHz to Hz
bandwidth = bandwidth_mhz * 1e6 # MHz to Hz
# Parse amplitude and phase arrays (simplified for now)
# In real implementation, this would parse actual CSI matrix data
amplitude = np.random.rand(num_antennas, num_subcarriers)
phase = np.random.rand(num_antennas, num_subcarriers)
return CSIData(
timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc),
amplitude=amplitude,
phase=phase,
frequency=frequency,
bandwidth=bandwidth,
num_subcarriers=num_subcarriers,
num_antennas=num_antennas,
snr=snr,
metadata={'source': 'esp32', 'raw_length': len(raw_data)}
)
except (ValueError, IndexError) as e:
raise CSIParseError(f"Failed to parse ESP32 data: {e}")
class RouterCSIParser:
"""Parser for router CSI data format."""
def parse(self, raw_data: bytes) -> CSIData:
"""Parse router CSI data format.
Args:
raw_data: Raw bytes from router
Returns:
Parsed CSI data
Raises:
CSIParseError: If data format is invalid
"""
if not raw_data:
raise CSIParseError("Empty data received")
# Handle different router formats
data_str = raw_data.decode('utf-8')
if data_str.startswith('ATHEROS_CSI:'):
return self._parse_atheros_format(raw_data)
else:
raise CSIParseError("Unknown router CSI format")
def _parse_atheros_format(self, raw_data: bytes) -> CSIData:
"""Parse Atheros CSI format (placeholder implementation)."""
# This would implement actual Atheros CSI parsing
# For now, return mock data for testing
return CSIData(
timestamp=datetime.now(timezone.utc),
amplitude=np.random.rand(3, 56),
phase=np.random.rand(3, 56),
frequency=2.4e9,
bandwidth=20e6,
num_subcarriers=56,
num_antennas=3,
snr=12.0,
metadata={'source': 'atheros_router'}
)
class CSIExtractor:
"""Main CSI data extractor supporting multiple hardware types."""
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
"""Initialize CSI extractor.
Args:
config: Configuration dictionary
logger: Optional logger instance
Raises:
ValueError: If configuration is invalid
"""
self._validate_config(config)
self.config = config
self.logger = logger or logging.getLogger(__name__)
self.hardware_type = config['hardware_type']
self.sampling_rate = config['sampling_rate']
self.buffer_size = config['buffer_size']
self.timeout = config['timeout']
self.validation_enabled = config.get('validation_enabled', True)
self.retry_attempts = config.get('retry_attempts', 3)
# State management
self.is_connected = False
self.is_streaming = False
# Create appropriate parser
if self.hardware_type == 'esp32':
self.parser = ESP32CSIParser()
elif self.hardware_type == 'router':
self.parser = RouterCSIParser()
else:
raise ValueError(f"Unsupported hardware type: {self.hardware_type}")
def _validate_config(self, config: Dict[str, Any]) -> None:
"""Validate configuration parameters.
Args:
config: Configuration to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout']
missing_fields = [field for field in required_fields if field not in config]
if missing_fields:
raise ValueError(f"Missing required configuration: {missing_fields}")
if config['sampling_rate'] <= 0:
raise ValueError("sampling_rate must be positive")
if config['buffer_size'] <= 0:
raise ValueError("buffer_size must be positive")
if config['timeout'] <= 0:
raise ValueError("timeout must be positive")
async def connect(self) -> bool:
"""Establish connection to CSI hardware.
Returns:
True if connection successful, False otherwise
"""
try:
success = await self._establish_hardware_connection()
self.is_connected = success
return success
except Exception as e:
self.logger.error(f"Failed to connect to hardware: {e}")
self.is_connected = False
return False
async def disconnect(self) -> None:
"""Disconnect from CSI hardware."""
if self.is_connected:
await self._close_hardware_connection()
self.is_connected = False
async def extract_csi(self) -> CSIData:
"""Extract CSI data from hardware.
Returns:
Extracted CSI data
Raises:
CSIParseError: If not connected or extraction fails
"""
if not self.is_connected:
raise CSIParseError("Not connected to hardware")
# Retry mechanism for temporary failures
for attempt in range(self.retry_attempts):
try:
raw_data = await self._read_raw_data()
csi_data = self.parser.parse(raw_data)
if self.validation_enabled:
self.validate_csi_data(csi_data)
return csi_data
except ConnectionError as e:
if attempt < self.retry_attempts - 1:
self.logger.warning(f"Extraction attempt {attempt + 1} failed, retrying: {e}")
await asyncio.sleep(0.1) # Brief delay before retry
else:
raise CSIParseError(f"Extraction failed after {self.retry_attempts} attempts: {e}")
def validate_csi_data(self, csi_data: CSIData) -> bool:
"""Validate CSI data structure and values.
Args:
csi_data: CSI data to validate
Returns:
True if valid
Raises:
CSIValidationError: If data is invalid
"""
if csi_data.amplitude.size == 0:
raise CSIValidationError("Empty amplitude data")
if csi_data.phase.size == 0:
raise CSIValidationError("Empty phase data")
if csi_data.frequency <= 0:
raise CSIValidationError("Invalid frequency")
if csi_data.bandwidth <= 0:
raise CSIValidationError("Invalid bandwidth")
if csi_data.num_subcarriers <= 0:
raise CSIValidationError("Invalid number of subcarriers")
if csi_data.num_antennas <= 0:
raise CSIValidationError("Invalid number of antennas")
if csi_data.snr < -50 or csi_data.snr > 50: # Reasonable SNR range
raise CSIValidationError("Invalid SNR value")
return True
async def start_streaming(self, callback: Callable[[CSIData], None]) -> None:
"""Start streaming CSI data.
Args:
callback: Function to call with each CSI sample
"""
self.is_streaming = True
try:
while self.is_streaming:
csi_data = await self.extract_csi()
callback(csi_data)
await asyncio.sleep(1.0 / self.sampling_rate)
except Exception as e:
self.logger.error(f"Streaming error: {e}")
finally:
self.is_streaming = False
def stop_streaming(self) -> None:
"""Stop streaming CSI data."""
self.is_streaming = False
async def _establish_hardware_connection(self) -> bool:
"""Establish connection to hardware (to be implemented by subclasses)."""
# Placeholder implementation for testing
return True
async def _close_hardware_connection(self) -> None:
"""Close hardware connection (to be implemented by subclasses)."""
# Placeholder implementation for testing
pass
async def _read_raw_data(self) -> bytes:
"""Read raw data from hardware (to be implemented by subclasses)."""
# Placeholder implementation for testing
return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]"

View File

@@ -0,0 +1,238 @@
"""Router interface for WiFi-DensePose system using TDD approach."""
import asyncio
import logging
from typing import Dict, Any, Optional
import asyncssh
from datetime import datetime, timezone
import numpy as np
try:
from .csi_extractor import CSIData
except ImportError:
# Handle import for testing
from src.hardware.csi_extractor import CSIData
class RouterConnectionError(Exception):
"""Exception raised for router connection errors."""
pass
class RouterInterface:
"""Interface for communicating with WiFi routers via SSH."""
def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None):
"""Initialize router interface.
Args:
config: Configuration dictionary with connection parameters
logger: Optional logger instance
Raises:
ValueError: If configuration is invalid
"""
self._validate_config(config)
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Connection parameters
self.host = config['host']
self.port = config['port']
self.username = config['username']
self.password = config['password']
self.command_timeout = config.get('command_timeout', 30)
self.connection_timeout = config.get('connection_timeout', 10)
self.max_retries = config.get('max_retries', 3)
self.retry_delay = config.get('retry_delay', 1.0)
# Connection state
self.is_connected = False
self.ssh_client = None
def _validate_config(self, config: Dict[str, Any]) -> None:
"""Validate configuration parameters.
Args:
config: Configuration to validate
Raises:
ValueError: If configuration is invalid
"""
required_fields = ['host', 'port', 'username', 'password']
missing_fields = [field for field in required_fields if field not in config]
if missing_fields:
raise ValueError(f"Missing required configuration: {missing_fields}")
if not isinstance(config['port'], int) or config['port'] <= 0:
raise ValueError("Port must be a positive integer")
async def connect(self) -> bool:
"""Establish SSH connection to router.
Returns:
True if connection successful, False otherwise
"""
try:
self.ssh_client = await asyncssh.connect(
self.host,
port=self.port,
username=self.username,
password=self.password,
connect_timeout=self.connection_timeout
)
self.is_connected = True
self.logger.info(f"Connected to router at {self.host}:{self.port}")
return True
except Exception as e:
self.logger.error(f"Failed to connect to router: {e}")
self.is_connected = False
self.ssh_client = None
return False
async def disconnect(self) -> None:
"""Disconnect from router."""
if self.is_connected and self.ssh_client:
self.ssh_client.close()
self.is_connected = False
self.ssh_client = None
self.logger.info("Disconnected from router")
async def execute_command(self, command: str) -> str:
"""Execute command on router via SSH.
Args:
command: Command to execute
Returns:
Command output
Raises:
RouterConnectionError: If not connected or command fails
"""
if not self.is_connected:
raise RouterConnectionError("Not connected to router")
# Retry mechanism for temporary failures
for attempt in range(self.max_retries):
try:
result = await self.ssh_client.run(command, timeout=self.command_timeout)
if result.returncode != 0:
raise RouterConnectionError(f"Command failed: {result.stderr}")
return result.stdout
except ConnectionError as e:
if attempt < self.max_retries - 1:
self.logger.warning(f"Command attempt {attempt + 1} failed, retrying: {e}")
await asyncio.sleep(self.retry_delay)
else:
raise RouterConnectionError(f"Command execution failed after {self.max_retries} retries: {e}")
except Exception as e:
raise RouterConnectionError(f"Command execution error: {e}")
async def get_csi_data(self) -> CSIData:
"""Retrieve CSI data from router.
Returns:
CSI data structure
Raises:
RouterConnectionError: If data retrieval fails
"""
try:
response = await self.execute_command("iwlist scan | grep CSI")
return self._parse_csi_response(response)
except Exception as e:
raise RouterConnectionError(f"Failed to retrieve CSI data: {e}")
async def get_router_status(self) -> Dict[str, Any]:
"""Get router system status.
Returns:
Dictionary containing router status information
Raises:
RouterConnectionError: If status retrieval fails
"""
try:
response = await self.execute_command("cat /proc/stat && free && iwconfig")
return self._parse_status_response(response)
except Exception as e:
raise RouterConnectionError(f"Failed to retrieve router status: {e}")
async def configure_csi_monitoring(self, config: Dict[str, Any]) -> bool:
"""Configure CSI monitoring on router.
Args:
config: CSI monitoring configuration
Returns:
True if configuration successful, False otherwise
"""
try:
channel = config.get('channel', 6)
command = f"iwconfig wlan0 channel {channel} && echo 'CSI monitoring configured'"
await self.execute_command(command)
return True
except Exception as e:
self.logger.error(f"Failed to configure CSI monitoring: {e}")
return False
async def health_check(self) -> bool:
"""Perform health check on router.
Returns:
True if router is healthy, False otherwise
"""
try:
response = await self.execute_command("echo 'ping' && echo 'pong'")
return "pong" in response
except Exception as e:
self.logger.error(f"Health check failed: {e}")
return False
def _parse_csi_response(self, response: str) -> CSIData:
"""Parse CSI response data.
Args:
response: Raw response from router
Returns:
Parsed CSI data
"""
# Mock implementation for testing
# In real implementation, this would parse actual router CSI format
return CSIData(
timestamp=datetime.now(timezone.utc),
amplitude=np.random.rand(3, 56),
phase=np.random.rand(3, 56),
frequency=2.4e9,
bandwidth=20e6,
num_subcarriers=56,
num_antennas=3,
snr=15.0,
metadata={'source': 'router', 'raw_response': response}
)
def _parse_status_response(self, response: str) -> Dict[str, Any]:
"""Parse router status response.
Args:
response: Raw response from router
Returns:
Parsed status information
"""
# Mock implementation for testing
# In real implementation, this would parse actual system status
return {
'cpu_usage': 25.5,
'memory_usage': 60.2,
'wifi_status': 'active',
'uptime': '5 days, 3 hours',
'raw_response': response
}

330
v1/src/logger.py Normal file
View File

@@ -0,0 +1,330 @@
"""
Logging configuration for WiFi-DensePose API
"""
import logging
import logging.config
import logging.handlers
import sys
import os
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime
from src.config.settings import Settings
class ColoredFormatter(logging.Formatter):
"""Colored log formatter for console output."""
# ANSI color codes
COLORS = {
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
'RESET': '\033[0m' # Reset
}
def format(self, record):
"""Format log record with colors."""
if hasattr(record, 'levelname'):
color = self.COLORS.get(record.levelname, self.COLORS['RESET'])
record.levelname = f"{color}{record.levelname}{self.COLORS['RESET']}"
return super().format(record)
class StructuredFormatter(logging.Formatter):
"""Structured JSON formatter for log files."""
def format(self, record):
"""Format log record as structured JSON."""
import json
log_entry = {
'timestamp': datetime.utcnow().isoformat(),
'level': record.levelname,
'logger': record.name,
'message': record.getMessage(),
'module': record.module,
'function': record.funcName,
'line': record.lineno,
}
# Add exception info if present
if record.exc_info:
log_entry['exception'] = self.formatException(record.exc_info)
# Add extra fields
for key, value in record.__dict__.items():
if key not in ['name', 'msg', 'args', 'levelname', 'levelno', 'pathname',
'filename', 'module', 'lineno', 'funcName', 'created',
'msecs', 'relativeCreated', 'thread', 'threadName',
'processName', 'process', 'getMessage', 'exc_info',
'exc_text', 'stack_info']:
log_entry[key] = value
return json.dumps(log_entry)
class RequestContextFilter(logging.Filter):
"""Filter to add request context to log records."""
def filter(self, record):
"""Add request context to log record."""
# Try to get request context from contextvars or thread local
try:
import contextvars
request_id = contextvars.ContextVar('request_id', default=None).get()
user_id = contextvars.ContextVar('user_id', default=None).get()
if request_id:
record.request_id = request_id
if user_id:
record.user_id = user_id
except (ImportError, LookupError):
pass
return True
def setup_logging(settings: Settings) -> None:
"""Setup application logging configuration."""
# Create log directory if file logging is enabled
if settings.log_file:
log_path = Path(settings.log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
# Build logging configuration
config = build_logging_config(settings)
# Apply configuration
logging.config.dictConfig(config)
# Set up root logger
root_logger = logging.getLogger()
root_logger.setLevel(settings.log_level)
# Add request context filter to all handlers
request_filter = RequestContextFilter()
for handler in root_logger.handlers:
handler.addFilter(request_filter)
# Log startup message
logger = logging.getLogger(__name__)
logger.info(f"Logging configured - Level: {settings.log_level}, File: {settings.log_file}")
def build_logging_config(settings: Settings) -> Dict[str, Any]:
"""Build logging configuration dictionary."""
config = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'console': {
'()': ColoredFormatter,
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S'
},
'file': {
'format': '%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S'
},
'structured': {
'()': StructuredFormatter
}
},
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': settings.log_level,
'formatter': 'console',
'stream': 'ext://sys.stdout'
}
},
'loggers': {
'': { # Root logger
'level': settings.log_level,
'handlers': ['console'],
'propagate': False
},
'src': { # Application logger
'level': settings.log_level,
'handlers': ['console'],
'propagate': False
},
'uvicorn': {
'level': 'INFO',
'handlers': ['console'],
'propagate': False
},
'uvicorn.access': {
'level': 'INFO',
'handlers': ['console'],
'propagate': False
},
'fastapi': {
'level': 'INFO',
'handlers': ['console'],
'propagate': False
},
'sqlalchemy': {
'level': 'WARNING',
'handlers': ['console'],
'propagate': False
},
'sqlalchemy.engine': {
'level': 'INFO' if settings.debug else 'WARNING',
'handlers': ['console'],
'propagate': False
}
}
}
# Add file handler if log file is specified
if settings.log_file:
config['handlers']['file'] = {
'class': 'logging.handlers.RotatingFileHandler',
'level': settings.log_level,
'formatter': 'file',
'filename': settings.log_file,
'maxBytes': settings.log_max_size,
'backupCount': settings.log_backup_count,
'encoding': 'utf-8'
}
# Add structured log handler for JSON logs
structured_log_file = str(Path(settings.log_file).with_suffix('.json'))
config['handlers']['structured'] = {
'class': 'logging.handlers.RotatingFileHandler',
'level': settings.log_level,
'formatter': 'structured',
'filename': structured_log_file,
'maxBytes': settings.log_max_size,
'backupCount': settings.log_backup_count,
'encoding': 'utf-8'
}
# Add file handlers to all loggers
for logger_config in config['loggers'].values():
logger_config['handlers'].extend(['file', 'structured'])
return config
def get_logger(name: str) -> logging.Logger:
"""Get a logger with the specified name."""
return logging.getLogger(name)
def configure_third_party_loggers(settings: Settings) -> None:
"""Configure third-party library loggers."""
# Suppress noisy loggers in production
if settings.is_production:
logging.getLogger('urllib3').setLevel(logging.WARNING)
logging.getLogger('requests').setLevel(logging.WARNING)
logging.getLogger('asyncio').setLevel(logging.WARNING)
logging.getLogger('multipart').setLevel(logging.WARNING)
# Configure SQLAlchemy logging
if settings.debug and settings.is_development:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
logging.getLogger('sqlalchemy.pool').setLevel(logging.DEBUG)
else:
logging.getLogger('sqlalchemy').setLevel(logging.WARNING)
# Configure Redis logging
logging.getLogger('redis').setLevel(logging.WARNING)
# Configure WebSocket logging
logging.getLogger('websockets').setLevel(logging.INFO)
class LoggerMixin:
"""Mixin class to add logging capabilities to any class."""
@property
def logger(self) -> logging.Logger:
"""Get logger for this class."""
return logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
def log_function_call(func):
"""Decorator to log function calls."""
import functools
@functools.wraps(func)
def wrapper(*args, **kwargs):
logger = logging.getLogger(func.__module__)
logger.debug(f"Calling {func.__name__} with args={args}, kwargs={kwargs}")
try:
result = func(*args, **kwargs)
logger.debug(f"{func.__name__} completed successfully")
return result
except Exception as e:
logger.error(f"{func.__name__} failed with error: {e}")
raise
return wrapper
def log_async_function_call(func):
"""Decorator to log async function calls."""
import functools
@functools.wraps(func)
async def wrapper(*args, **kwargs):
logger = logging.getLogger(func.__module__)
logger.debug(f"Calling async {func.__name__} with args={args}, kwargs={kwargs}")
try:
result = await func(*args, **kwargs)
logger.debug(f"Async {func.__name__} completed successfully")
return result
except Exception as e:
logger.error(f"Async {func.__name__} failed with error: {e}")
raise
return wrapper
def setup_request_logging():
"""Setup request-specific logging context."""
import contextvars
import uuid
# Create context variables for request tracking
request_id_var = contextvars.ContextVar('request_id')
user_id_var = contextvars.ContextVar('user_id')
def set_request_context(request_id: Optional[str] = None, user_id: Optional[str] = None):
"""Set request context for logging."""
if request_id is None:
request_id = str(uuid.uuid4())
request_id_var.set(request_id)
if user_id:
user_id_var.set(user_id)
def get_request_context():
"""Get current request context."""
try:
return {
'request_id': request_id_var.get(),
'user_id': user_id_var.get(None)
}
except LookupError:
return {}
return set_request_context, get_request_context
# Initialize request logging context
set_request_context, get_request_context = setup_request_logging()

117
v1/src/main.py Normal file
View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""
Main application entry point for WiFi-DensePose API
"""
import sys
import os
import asyncio
import logging
import signal
from pathlib import Path
from typing import Optional
# Add src to Python path
sys.path.insert(0, str(Path(__file__).parent))
from src.config.settings import get_settings, validate_settings
from src.logger import setup_logging
from src.app import create_app
from src.services.orchestrator import ServiceOrchestrator
from src.cli import create_cli
def setup_signal_handlers(orchestrator: ServiceOrchestrator):
"""Setup signal handlers for graceful shutdown."""
def signal_handler(signum, frame):
logging.info(f"Received signal {signum}, initiating graceful shutdown...")
asyncio.create_task(orchestrator.shutdown())
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async def main():
"""Main application entry point."""
try:
# Load settings
settings = get_settings()
# Setup logging
setup_logging(settings)
logger = logging.getLogger(__name__)
logger.info(f"Starting {settings.app_name} v{settings.version}")
logger.info(f"Environment: {settings.environment}")
# Validate settings
issues = validate_settings(settings)
if issues:
logger.error("Configuration issues found:")
for issue in issues:
logger.error(f" - {issue}")
if settings.is_production:
sys.exit(1)
else:
logger.warning("Continuing with configuration issues in development mode")
# Create service orchestrator
orchestrator = ServiceOrchestrator(settings)
# Setup signal handlers
setup_signal_handlers(orchestrator)
# Initialize services
await orchestrator.initialize()
# Create FastAPI app
app = create_app(settings, orchestrator)
# Start the application
if len(sys.argv) > 1:
# CLI mode
cli = create_cli(orchestrator)
await cli.run(sys.argv[1:])
else:
# Server mode
import uvicorn
logger.info(f"Starting server on {settings.host}:{settings.port}")
config = uvicorn.Config(
app,
host=settings.host,
port=settings.port,
reload=settings.reload and settings.is_development,
workers=settings.workers if not settings.reload else 1,
log_level=settings.log_level.lower(),
access_log=True,
use_colors=True
)
server = uvicorn.Server(config)
await server.serve()
except KeyboardInterrupt:
logger.info("Received keyboard interrupt, shutting down...")
except Exception as e:
logger.error(f"Application failed to start: {e}", exc_info=True)
sys.exit(1)
finally:
# Cleanup
if 'orchestrator' in locals():
await orchestrator.shutdown()
logger.info("Application shutdown complete")
def run():
"""Entry point for package installation."""
try:
asyncio.run(main())
except KeyboardInterrupt:
pass
if __name__ == "__main__":
run()

467
v1/src/middleware/auth.py Normal file
View File

@@ -0,0 +1,467 @@
"""
Authentication middleware for WiFi-DensePose API
"""
import logging
import time
from typing import Optional, Dict, Any, Callable
from datetime import datetime, timedelta
from fastapi import Request, Response, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from src.config.settings import Settings
from src.logger import set_request_context
logger = logging.getLogger(__name__)
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT token handler
security = HTTPBearer(auto_error=False)
class AuthenticationError(Exception):
"""Authentication error."""
pass
class AuthorizationError(Exception):
"""Authorization error."""
pass
class TokenManager:
"""JWT token management."""
def __init__(self, settings: Settings):
self.settings = settings
self.secret_key = settings.secret_key
self.algorithm = settings.jwt_algorithm
self.expire_hours = settings.jwt_expire_hours
def create_access_token(self, data: Dict[str, Any]) -> str:
"""Create JWT access token."""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
return encoded_jwt
def verify_token(self, token: str) -> Dict[str, Any]:
"""Verify and decode JWT token."""
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
return payload
except JWTError as e:
logger.warning(f"JWT verification failed: {e}")
raise AuthenticationError("Invalid token")
def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Decode token without verification (for debugging)."""
try:
return jwt.decode(token, options={"verify_signature": False})
except JWTError:
return None
class UserManager:
"""User management for authentication."""
def __init__(self):
# In a real application, this would connect to a database
# For now, we'll use a simple in-memory store
self._users: Dict[str, Dict[str, Any]] = {
"admin": {
"username": "admin",
"email": "admin@example.com",
"hashed_password": self.hash_password("admin123"),
"roles": ["admin"],
"is_active": True,
"created_at": datetime.utcnow(),
},
"user": {
"username": "user",
"email": "user@example.com",
"hashed_password": self.hash_password("user123"),
"roles": ["user"],
"is_active": True,
"created_at": datetime.utcnow(),
}
}
@staticmethod
def hash_password(password: str) -> str:
"""Hash a password."""
return pwd_context.hash(password)
@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_user(self, username: str) -> Optional[Dict[str, Any]]:
"""Get user by username."""
return self._users.get(username)
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
"""Authenticate user with username and password."""
user = self.get_user(username)
if not user:
return None
if not self.verify_password(password, user["hashed_password"]):
return None
if not user.get("is_active", False):
return None
return user
def create_user(self, username: str, email: str, password: str, roles: list = None) -> Dict[str, Any]:
"""Create a new user."""
if username in self._users:
raise ValueError("User already exists")
user = {
"username": username,
"email": email,
"hashed_password": self.hash_password(password),
"roles": roles or ["user"],
"is_active": True,
"created_at": datetime.utcnow(),
}
self._users[username] = user
return user
def update_user(self, username: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Update user information."""
user = self._users.get(username)
if not user:
return None
# Don't allow updating certain fields
protected_fields = {"username", "created_at", "hashed_password"}
updates = {k: v for k, v in updates.items() if k not in protected_fields}
user.update(updates)
return user
def deactivate_user(self, username: str) -> bool:
"""Deactivate a user."""
user = self._users.get(username)
if user:
user["is_active"] = False
return True
return False
class AuthenticationMiddleware:
"""Authentication middleware for FastAPI."""
def __init__(self, settings: Settings):
self.settings = settings
self.token_manager = TokenManager(settings)
self.user_manager = UserManager()
self.enabled = settings.enable_authentication
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""Process request through authentication middleware."""
start_time = time.time()
try:
# Skip authentication for certain paths
if self._should_skip_auth(request):
response = await call_next(request)
return response
# Skip if authentication is disabled
if not self.enabled:
response = await call_next(request)
return response
# Extract and verify token
user_info = await self._authenticate_request(request)
# Set user context
if user_info:
request.state.user = user_info
set_request_context(user_id=user_info.get("username"))
# Process request
response = await call_next(request)
# Add authentication headers
self._add_auth_headers(response, user_info)
return response
except AuthenticationError as e:
logger.warning(f"Authentication failed: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
except AuthorizationError as e:
logger.warning(f"Authorization failed: {e}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except Exception as e:
logger.error(f"Authentication middleware error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Authentication service error",
)
finally:
# Log request processing time
processing_time = time.time() - start_time
logger.debug(f"Auth middleware processing time: {processing_time:.3f}s")
def _should_skip_auth(self, request: Request) -> bool:
"""Check if authentication should be skipped for this request."""
path = request.url.path
# Skip authentication for these paths
skip_paths = [
"/health",
"/metrics",
"/docs",
"/redoc",
"/openapi.json",
"/auth/login",
"/auth/register",
"/static",
]
return any(path.startswith(skip_path) for skip_path in skip_paths)
async def _authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
"""Authenticate the request and return user info."""
# Try to get token from Authorization header
authorization = request.headers.get("Authorization")
if not authorization:
# For WebSocket connections, try to get token from query parameters
if request.url.path.startswith("/ws"):
token = request.query_params.get("token")
if token:
authorization = f"Bearer {token}"
if not authorization:
if self._requires_auth(request):
raise AuthenticationError("Missing authorization header")
return None
# Extract token
try:
scheme, token = authorization.split()
if scheme.lower() != "bearer":
raise AuthenticationError("Invalid authentication scheme")
except ValueError:
raise AuthenticationError("Invalid authorization header format")
# Verify token
try:
payload = self.token_manager.verify_token(token)
username = payload.get("sub")
if not username:
raise AuthenticationError("Invalid token payload")
# Get user info
user = self.user_manager.get_user(username)
if not user:
raise AuthenticationError("User not found")
if not user.get("is_active", False):
raise AuthenticationError("User account is disabled")
# Return user info without sensitive data
return {
"username": user["username"],
"email": user["email"],
"roles": user["roles"],
"is_active": user["is_active"],
}
except AuthenticationError:
raise
except Exception as e:
logger.error(f"Token verification error: {e}")
raise AuthenticationError("Token verification failed")
def _requires_auth(self, request: Request) -> bool:
"""Check if the request requires authentication."""
# All API endpoints require authentication by default
path = request.url.path
return path.startswith("/api/") or path.startswith("/ws/")
def _add_auth_headers(self, response: Response, user_info: Optional[Dict[str, Any]]):
"""Add authentication-related headers to response."""
if user_info:
response.headers["X-User"] = user_info["username"]
response.headers["X-User-Roles"] = ",".join(user_info["roles"])
async def login(self, username: str, password: str) -> Dict[str, Any]:
"""Authenticate user and return token."""
user = self.user_manager.authenticate_user(username, password)
if not user:
raise AuthenticationError("Invalid username or password")
# Create token
token_data = {
"sub": user["username"],
"email": user["email"],
"roles": user["roles"],
}
access_token = self.token_manager.create_access_token(token_data)
return {
"access_token": access_token,
"token_type": "bearer",
"expires_in": self.settings.jwt_expire_hours * 3600,
"user": {
"username": user["username"],
"email": user["email"],
"roles": user["roles"],
}
}
async def register(self, username: str, email: str, password: str) -> Dict[str, Any]:
"""Register a new user."""
try:
user = self.user_manager.create_user(username, email, password)
# Create token for new user
token_data = {
"sub": user["username"],
"email": user["email"],
"roles": user["roles"],
}
access_token = self.token_manager.create_access_token(token_data)
return {
"access_token": access_token,
"token_type": "bearer",
"expires_in": self.settings.jwt_expire_hours * 3600,
"user": {
"username": user["username"],
"email": user["email"],
"roles": user["roles"],
}
}
except ValueError as e:
raise AuthenticationError(str(e))
async def refresh_token(self, token: str) -> Dict[str, Any]:
"""Refresh an access token."""
try:
payload = self.token_manager.verify_token(token)
username = payload.get("sub")
user = self.user_manager.get_user(username)
if not user or not user.get("is_active", False):
raise AuthenticationError("User not found or inactive")
# Create new token
token_data = {
"sub": user["username"],
"email": user["email"],
"roles": user["roles"],
}
new_token = self.token_manager.create_access_token(token_data)
return {
"access_token": new_token,
"token_type": "bearer",
"expires_in": self.settings.jwt_expire_hours * 3600,
}
except Exception as e:
raise AuthenticationError("Token refresh failed")
def check_permission(self, user_info: Dict[str, Any], required_role: str) -> bool:
"""Check if user has required role/permission."""
user_roles = user_info.get("roles", [])
# Admin role has all permissions
if "admin" in user_roles:
return True
# Check specific role
return required_role in user_roles
def require_role(self, required_role: str):
"""Decorator to require specific role."""
def decorator(func):
import functools
@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs):
user_info = getattr(request.state, "user", None)
if not user_info:
raise AuthorizationError("Authentication required")
if not self.check_permission(user_info, required_role):
raise AuthorizationError(f"Role '{required_role}' required")
return await func(request, *args, **kwargs)
return wrapper
return decorator
# Global authentication middleware instance
_auth_middleware: Optional[AuthenticationMiddleware] = None
def get_auth_middleware(settings: Settings) -> AuthenticationMiddleware:
"""Get authentication middleware instance."""
global _auth_middleware
if _auth_middleware is None:
_auth_middleware = AuthenticationMiddleware(settings)
return _auth_middleware
def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
"""Get current authenticated user from request."""
return getattr(request.state, "user", None)
def require_authentication(request: Request) -> Dict[str, Any]:
"""Require authentication and return user info."""
user = get_current_user(request)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"},
)
return user
def require_role(role: str):
"""Dependency to require specific role."""
def dependency(request: Request) -> Dict[str, Any]:
user = require_authentication(request)
auth_middleware = get_auth_middleware(request.app.state.settings)
if not auth_middleware.check_permission(user, role):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Role '{role}' required",
)
return user
return dependency

375
v1/src/middleware/cors.py Normal file
View File

@@ -0,0 +1,375 @@
"""
CORS middleware for WiFi-DensePose API
"""
import logging
from typing import List, Optional, Union, Callable
from urllib.parse import urlparse
from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware
from starlette.types import ASGIApp
from src.config.settings import Settings
logger = logging.getLogger(__name__)
class CORSMiddleware:
"""Enhanced CORS middleware with additional security features."""
def __init__(
self,
app: ASGIApp,
settings: Settings,
allow_origins: Optional[List[str]] = None,
allow_methods: Optional[List[str]] = None,
allow_headers: Optional[List[str]] = None,
allow_credentials: bool = False,
expose_headers: Optional[List[str]] = None,
max_age: int = 600,
):
self.app = app
self.settings = settings
self.allow_origins = allow_origins or settings.cors_origins
self.allow_methods = allow_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]
self.allow_headers = allow_headers or [
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-Request-ID",
"X-User-Agent",
]
self.allow_credentials = allow_credentials or settings.cors_allow_credentials
self.expose_headers = expose_headers or [
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
]
self.max_age = max_age
# Security settings
self.strict_origin_check = settings.is_production
self.log_cors_violations = True
async def __call__(self, scope, receive, send):
"""ASGI middleware implementation."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive)
# Check if this is a CORS preflight request
if request.method == "OPTIONS" and "access-control-request-method" in request.headers:
response = await self._handle_preflight(request)
await response(scope, receive, send)
return
# Handle actual request
async def send_wrapper(message):
if message["type"] == "http.response.start":
# Add CORS headers to response
headers = dict(message.get("headers", []))
cors_headers = self._get_cors_headers(request)
for key, value in cors_headers.items():
headers[key.encode()] = value.encode()
message["headers"] = list(headers.items())
await send(message)
await self.app(scope, receive, send_wrapper)
async def _handle_preflight(self, request: Request) -> Response:
"""Handle CORS preflight request."""
origin = request.headers.get("origin")
requested_method = request.headers.get("access-control-request-method")
requested_headers = request.headers.get("access-control-request-headers", "")
# Validate origin
if not self._is_origin_allowed(origin):
if self.log_cors_violations:
logger.warning(f"CORS preflight rejected for origin: {origin}")
return Response(
status_code=403,
content="CORS preflight request rejected",
headers={"Content-Type": "text/plain"}
)
# Validate method
if requested_method not in self.allow_methods:
if self.log_cors_violations:
logger.warning(f"CORS preflight rejected for method: {requested_method}")
return Response(
status_code=405,
content="Method not allowed",
headers={"Content-Type": "text/plain"}
)
# Validate headers
if requested_headers:
requested_header_list = [h.strip().lower() for h in requested_headers.split(",")]
allowed_headers_lower = [h.lower() for h in self.allow_headers]
for header in requested_header_list:
if header not in allowed_headers_lower:
if self.log_cors_violations:
logger.warning(f"CORS preflight rejected for header: {header}")
return Response(
status_code=400,
content="Header not allowed",
headers={"Content-Type": "text/plain"}
)
# Build preflight response headers
headers = {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": ", ".join(self.allow_methods),
"Access-Control-Allow-Headers": ", ".join(self.allow_headers),
"Access-Control-Max-Age": str(self.max_age),
}
if self.allow_credentials:
headers["Access-Control-Allow-Credentials"] = "true"
if self.expose_headers:
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
logger.debug(f"CORS preflight approved for origin: {origin}")
return Response(
status_code=200,
headers=headers
)
def _get_cors_headers(self, request: Request) -> dict:
"""Get CORS headers for actual request."""
origin = request.headers.get("origin")
headers = {}
if self._is_origin_allowed(origin):
headers["Access-Control-Allow-Origin"] = origin
if self.allow_credentials:
headers["Access-Control-Allow-Credentials"] = "true"
if self.expose_headers:
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
return headers
def _is_origin_allowed(self, origin: Optional[str]) -> bool:
"""Check if origin is allowed."""
if not origin:
return not self.strict_origin_check
# Allow all origins in development
if not self.settings.is_production and "*" in self.allow_origins:
return True
# Check exact matches
if origin in self.allow_origins:
return True
# Check wildcard patterns
for allowed_origin in self.allow_origins:
if allowed_origin == "*":
return not self.strict_origin_check
if self._match_origin_pattern(origin, allowed_origin):
return True
return False
def _match_origin_pattern(self, origin: str, pattern: str) -> bool:
"""Match origin against pattern with wildcard support."""
if "*" not in pattern:
return origin == pattern
# Simple wildcard matching
if pattern.startswith("*."):
domain = pattern[2:]
parsed_origin = urlparse(origin)
origin_host = parsed_origin.netloc
# Check if origin ends with the domain
return origin_host.endswith(domain) or origin_host == domain[1:] if domain.startswith('.') else origin_host == domain
return False
def setup_cors_middleware(app: ASGIApp, settings: Settings) -> ASGIApp:
"""Setup CORS middleware for the application."""
if settings.cors_enabled:
logger.info("Setting up CORS middleware")
# Use FastAPI's built-in CORS middleware for basic functionality
app = FastAPICORSMiddleware(
app,
allow_origins=settings.cors_origins,
allow_credentials=settings.cors_allow_credentials,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
allow_headers=[
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-Request-ID",
"X-User-Agent",
],
expose_headers=[
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
max_age=600,
)
logger.info(f"CORS enabled for origins: {settings.cors_origins}")
else:
logger.info("CORS middleware disabled")
return app
class CORSConfig:
"""CORS configuration helper."""
@staticmethod
def development_config() -> dict:
"""Get CORS configuration for development."""
return {
"allow_origins": ["*"],
"allow_credentials": True,
"allow_methods": ["*"],
"allow_headers": ["*"],
"expose_headers": [
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
"max_age": 600,
}
@staticmethod
def production_config(allowed_origins: List[str]) -> dict:
"""Get CORS configuration for production."""
return {
"allow_origins": allowed_origins,
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
"allow_headers": [
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-Request-ID",
"X-User-Agent",
],
"expose_headers": [
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
"max_age": 3600, # 1 hour for production
}
@staticmethod
def api_only_config(allowed_origins: List[str]) -> dict:
"""Get CORS configuration for API-only access."""
return {
"allow_origins": allowed_origins,
"allow_credentials": False,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": [
"Accept",
"Content-Type",
"Authorization",
"X-Request-ID",
],
"expose_headers": [
"X-Request-ID",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
"max_age": 3600,
}
@staticmethod
def websocket_config(allowed_origins: List[str]) -> dict:
"""Get CORS configuration for WebSocket connections."""
return {
"allow_origins": allowed_origins,
"allow_credentials": True,
"allow_methods": ["GET", "OPTIONS"],
"allow_headers": [
"Accept",
"Authorization",
"Sec-WebSocket-Protocol",
"Sec-WebSocket-Extensions",
],
"expose_headers": [],
"max_age": 86400, # 24 hours for WebSocket
}
def validate_cors_config(settings: Settings) -> List[str]:
"""Validate CORS configuration and return issues."""
issues = []
if not settings.cors_enabled:
return issues
# Check origins
if not settings.cors_origins:
issues.append("CORS is enabled but no origins are configured")
# Check for wildcard in production
if settings.is_production and "*" in settings.cors_origins:
issues.append("Wildcard origin (*) should not be used in production")
# Validate origin formats
for origin in settings.cors_origins:
if origin != "*" and not origin.startswith(("http://", "https://")):
issues.append(f"Invalid origin format: {origin}")
# Check credentials with wildcard
if settings.cors_allow_credentials and "*" in settings.cors_origins:
issues.append("Cannot use credentials with wildcard origin")
return issues
def get_cors_headers_for_origin(origin: str, settings: Settings) -> dict:
"""Get appropriate CORS headers for a specific origin."""
headers = {}
if not settings.cors_enabled:
return headers
# Check if origin is allowed
cors_middleware = CORSMiddleware(None, settings)
if cors_middleware._is_origin_allowed(origin):
headers["Access-Control-Allow-Origin"] = origin
if settings.cors_allow_credentials:
headers["Access-Control-Allow-Credentials"] = "true"
return headers

View File

@@ -0,0 +1,505 @@
"""
Global error handling middleware for WiFi-DensePose API
"""
import logging
import traceback
import time
from typing import Dict, Any, Optional, Callable, Union
from datetime import datetime
from fastapi import Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import ValidationError
from src.config.settings import Settings
from src.logger import get_request_context
logger = logging.getLogger(__name__)
class ErrorResponse:
"""Standardized error response format."""
def __init__(
self,
error_code: str,
message: str,
details: Optional[Dict[str, Any]] = None,
status_code: int = 500,
request_id: Optional[str] = None,
):
self.error_code = error_code
self.message = message
self.details = details or {}
self.status_code = status_code
self.request_id = request_id
self.timestamp = datetime.utcnow().isoformat()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON response."""
response = {
"error": {
"code": self.error_code,
"message": self.message,
"timestamp": self.timestamp,
}
}
if self.details:
response["error"]["details"] = self.details
if self.request_id:
response["error"]["request_id"] = self.request_id
return response
def to_response(self) -> JSONResponse:
"""Convert to FastAPI JSONResponse."""
headers = {}
if self.request_id:
headers["X-Request-ID"] = self.request_id
return JSONResponse(
status_code=self.status_code,
content=self.to_dict(),
headers=headers
)
class ErrorHandler:
"""Central error handler for the application."""
def __init__(self, settings: Settings):
self.settings = settings
self.include_traceback = settings.debug and settings.is_development
self.log_errors = True
def handle_http_exception(self, request: Request, exc: HTTPException) -> ErrorResponse:
"""Handle HTTP exceptions."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.warning(
f"HTTP {exc.status_code}: {exc.detail} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}"
)
# Determine error code
error_code = self._get_error_code_for_status(exc.status_code)
# Build error details
details = {}
if hasattr(exc, "headers") and exc.headers:
details["headers"] = exc.headers
if self.include_traceback and hasattr(exc, "__traceback__"):
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
return ErrorResponse(
error_code=error_code,
message=str(exc.detail),
details=details,
status_code=exc.status_code,
request_id=request_id
)
def handle_validation_error(self, request: Request, exc: RequestValidationError) -> ErrorResponse:
"""Handle request validation errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.warning(
f"Validation error: {exc.errors()} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}"
)
# Format validation errors
validation_details = []
for error in exc.errors():
validation_details.append({
"field": ".".join(str(loc) for loc in error["loc"]),
"message": error["msg"],
"type": error["type"],
"input": error.get("input"),
})
details = {
"validation_errors": validation_details,
"error_count": len(validation_details)
}
if self.include_traceback:
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
return ErrorResponse(
error_code="VALIDATION_ERROR",
message="Request validation failed",
details=details,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
request_id=request_id
)
def handle_pydantic_error(self, request: Request, exc: ValidationError) -> ErrorResponse:
"""Handle Pydantic validation errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.warning(
f"Pydantic validation error: {exc.errors()} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}"
)
# Format validation errors
validation_details = []
for error in exc.errors():
validation_details.append({
"field": ".".join(str(loc) for loc in error["loc"]),
"message": error["msg"],
"type": error["type"],
})
details = {
"validation_errors": validation_details,
"error_count": len(validation_details)
}
return ErrorResponse(
error_code="DATA_VALIDATION_ERROR",
message="Data validation failed",
details=details,
status_code=status.HTTP_400_BAD_REQUEST,
request_id=request_id
)
def handle_generic_exception(self, request: Request, exc: Exception) -> ErrorResponse:
"""Handle generic exceptions."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.error(
f"Unhandled exception: {type(exc).__name__}: {exc} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}",
exc_info=True
)
# Determine error details
details = {
"exception_type": type(exc).__name__,
}
if self.include_traceback:
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
# Don't expose internal error details in production
if self.settings.is_production:
message = "An internal server error occurred"
else:
message = str(exc) or "An unexpected error occurred"
return ErrorResponse(
error_code="INTERNAL_SERVER_ERROR",
message=message,
details=details,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
request_id=request_id
)
def handle_database_error(self, request: Request, exc: Exception) -> ErrorResponse:
"""Handle database-related errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.error(
f"Database error: {type(exc).__name__}: {exc} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}",
exc_info=True
)
details = {
"exception_type": type(exc).__name__,
"category": "database"
}
if self.include_traceback:
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
return ErrorResponse(
error_code="DATABASE_ERROR",
message="Database operation failed" if self.settings.is_production else str(exc),
details=details,
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
request_id=request_id
)
def handle_external_service_error(self, request: Request, exc: Exception) -> ErrorResponse:
"""Handle external service errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.error(
f"External service error: {type(exc).__name__}: {exc} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}",
exc_info=True
)
details = {
"exception_type": type(exc).__name__,
"category": "external_service"
}
return ErrorResponse(
error_code="EXTERNAL_SERVICE_ERROR",
message="External service unavailable" if self.settings.is_production else str(exc),
details=details,
status_code=status.HTTP_502_BAD_GATEWAY,
request_id=request_id
)
def _get_error_code_for_status(self, status_code: int) -> str:
"""Get error code for HTTP status code."""
error_codes = {
400: "BAD_REQUEST",
401: "UNAUTHORIZED",
403: "FORBIDDEN",
404: "NOT_FOUND",
405: "METHOD_NOT_ALLOWED",
409: "CONFLICT",
422: "UNPROCESSABLE_ENTITY",
429: "TOO_MANY_REQUESTS",
500: "INTERNAL_SERVER_ERROR",
502: "BAD_GATEWAY",
503: "SERVICE_UNAVAILABLE",
504: "GATEWAY_TIMEOUT",
}
return error_codes.get(status_code, "HTTP_ERROR")
class ErrorHandlingMiddleware:
"""Error handling middleware for FastAPI."""
def __init__(self, app, settings: Settings):
self.app = app
self.settings = settings
self.error_handler = ErrorHandler(settings)
async def __call__(self, scope, receive, send):
"""Process request through error handling middleware."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time()
try:
await self.app(scope, receive, send)
except Exception as exc:
# Create a mock request for error handling
from starlette.requests import Request
request = Request(scope, receive)
# Handle different exception types
if isinstance(exc, HTTPException):
error_response = self.error_handler.handle_http_exception(request, exc)
elif isinstance(exc, RequestValidationError):
error_response = self.error_handler.handle_validation_error(request, exc)
elif isinstance(exc, ValidationError):
error_response = self.error_handler.handle_pydantic_error(request, exc)
else:
# Check for specific error types
if self._is_database_error(exc):
error_response = self.error_handler.handle_database_error(request, exc)
elif self._is_external_service_error(exc):
error_response = self.error_handler.handle_external_service_error(request, exc)
else:
error_response = self.error_handler.handle_generic_exception(request, exc)
# Send the error response
response = error_response.to_response()
await response(scope, receive, send)
finally:
# Log request processing time
processing_time = time.time() - start_time
logger.debug(f"Error handling middleware processing time: {processing_time:.3f}s")
def _is_database_error(self, exc: Exception) -> bool:
"""Check if exception is database-related."""
database_exceptions = [
"sqlalchemy",
"psycopg2",
"pymongo",
"redis",
"ConnectionError",
"OperationalError",
"IntegrityError",
]
exc_module = getattr(type(exc), "__module__", "")
exc_name = type(exc).__name__
return any(
db_exc in exc_module or db_exc in exc_name
for db_exc in database_exceptions
)
def _is_external_service_error(self, exc: Exception) -> bool:
"""Check if exception is external service-related."""
external_exceptions = [
"requests",
"httpx",
"aiohttp",
"urllib",
"ConnectionError",
"TimeoutError",
"ConnectTimeout",
"ReadTimeout",
]
exc_module = getattr(type(exc), "__module__", "")
exc_name = type(exc).__name__
return any(
ext_exc in exc_module or ext_exc in exc_name
for ext_exc in external_exceptions
)
def setup_error_handling(app, settings: Settings):
"""Setup error handling for the application."""
logger.info("Setting up error handling middleware")
error_handler = ErrorHandler(settings)
# Add exception handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
error_response = error_handler.handle_http_exception(request, exc)
return error_response.to_response()
@app.exception_handler(StarletteHTTPException)
async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
# Convert Starlette HTTPException to FastAPI HTTPException
fastapi_exc = HTTPException(status_code=exc.status_code, detail=exc.detail)
error_response = error_handler.handle_http_exception(request, fastapi_exc)
return error_response.to_response()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
error_response = error_handler.handle_validation_error(request, exc)
return error_response.to_response()
@app.exception_handler(ValidationError)
async def pydantic_exception_handler(request: Request, exc: ValidationError):
error_response = error_handler.handle_pydantic_error(request, exc)
return error_response.to_response()
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
error_response = error_handler.handle_generic_exception(request, exc)
return error_response.to_response()
# Add middleware for additional error handling
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
# The middleware approach is commented out but kept for reference
# middleware = ErrorHandlingMiddleware(app, settings)
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
logger.info("Error handling configured")
class CustomHTTPException(HTTPException):
"""Custom HTTP exception with additional context."""
def __init__(
self,
status_code: int,
detail: str,
error_code: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
):
super().__init__(status_code=status_code, detail=detail, headers=headers)
self.error_code = error_code
self.context = context or {}
class BusinessLogicError(CustomHTTPException):
"""Exception for business logic errors."""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=message,
error_code="BUSINESS_LOGIC_ERROR",
context=context
)
class ResourceNotFoundError(CustomHTTPException):
"""Exception for resource not found errors."""
def __init__(self, resource: str, identifier: str):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{resource} not found",
error_code="RESOURCE_NOT_FOUND",
context={"resource": resource, "identifier": identifier}
)
class ConflictError(CustomHTTPException):
"""Exception for conflict errors."""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=message,
error_code="CONFLICT_ERROR",
context=context
)
class ServiceUnavailableError(CustomHTTPException):
"""Exception for service unavailable errors."""
def __init__(self, service: str, reason: Optional[str] = None):
detail = f"{service} service is unavailable"
if reason:
detail += f": {reason}"
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=detail,
error_code="SERVICE_UNAVAILABLE",
context={"service": service, "reason": reason}
)

View File

@@ -0,0 +1,465 @@
"""
Rate limiting middleware for WiFi-DensePose API
"""
import asyncio
import logging
import time
from typing import Dict, Any, Optional, Callable, Tuple
from datetime import datetime, timedelta
from collections import defaultdict, deque
from dataclasses import dataclass
from fastapi import Request, Response, HTTPException, status
from starlette.types import ASGIApp
from src.config.settings import Settings
logger = logging.getLogger(__name__)
@dataclass
class RateLimitInfo:
"""Rate limit information."""
requests: int
window_start: float
window_size: int
limit: int
@property
def remaining(self) -> int:
"""Get remaining requests in current window."""
return max(0, self.limit - self.requests)
@property
def reset_time(self) -> float:
"""Get time when window resets."""
return self.window_start + self.window_size
@property
def is_exceeded(self) -> bool:
"""Check if rate limit is exceeded."""
return self.requests >= self.limit
class TokenBucket:
"""Token bucket algorithm for rate limiting."""
def __init__(self, capacity: int, refill_rate: float):
self.capacity = capacity
self.tokens = capacity
self.refill_rate = refill_rate
self.last_refill = time.time()
self._lock = asyncio.Lock()
async def consume(self, tokens: int = 1) -> bool:
"""Try to consume tokens from bucket."""
async with self._lock:
now = time.time()
# Refill tokens based on time elapsed
time_passed = now - self.last_refill
tokens_to_add = time_passed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
self.last_refill = now
# Check if we have enough tokens
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
def get_info(self) -> Dict[str, Any]:
"""Get bucket information."""
return {
"capacity": self.capacity,
"tokens": self.tokens,
"refill_rate": self.refill_rate,
"last_refill": self.last_refill
}
class SlidingWindowCounter:
"""Sliding window counter for rate limiting."""
def __init__(self, window_size: int, limit: int):
self.window_size = window_size
self.limit = limit
self.requests = deque()
self._lock = asyncio.Lock()
async def is_allowed(self) -> Tuple[bool, RateLimitInfo]:
"""Check if request is allowed."""
async with self._lock:
now = time.time()
window_start = now - self.window_size
# Remove old requests outside the window
while self.requests and self.requests[0] < window_start:
self.requests.popleft()
# Check if limit is exceeded
current_requests = len(self.requests)
allowed = current_requests < self.limit
if allowed:
self.requests.append(now)
rate_limit_info = RateLimitInfo(
requests=current_requests + (1 if allowed else 0),
window_start=window_start,
window_size=self.window_size,
limit=self.limit
)
return allowed, rate_limit_info
class RateLimiter:
"""Rate limiter with multiple algorithms."""
def __init__(self, settings: Settings):
self.settings = settings
self.enabled = settings.enable_rate_limiting
# Rate limit configurations
self.default_limit = settings.rate_limit_requests
self.authenticated_limit = settings.rate_limit_authenticated_requests
self.window_size = settings.rate_limit_window
# Storage for rate limit data
self._sliding_windows: Dict[str, SlidingWindowCounter] = {}
self._token_buckets: Dict[str, TokenBucket] = {}
# Cleanup task
self._cleanup_task: Optional[asyncio.Task] = None
self._cleanup_interval = 300 # 5 minutes
async def start(self):
"""Start rate limiter background tasks."""
if self.enabled:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("Rate limiter started")
async def stop(self):
"""Stop rate limiter background tasks."""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("Rate limiter stopped")
async def _cleanup_loop(self):
"""Background task to cleanup old rate limit data."""
while True:
try:
await asyncio.sleep(self._cleanup_interval)
await self._cleanup_old_data()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in rate limiter cleanup: {e}")
async def _cleanup_old_data(self):
"""Remove old rate limit data."""
now = time.time()
cutoff = now - (self.window_size * 2) # Keep data for 2 windows
# Cleanup sliding windows
keys_to_remove = []
for key, window in self._sliding_windows.items():
# Remove old requests
while window.requests and window.requests[0] < cutoff:
window.requests.popleft()
# Remove empty windows
if not window.requests:
keys_to_remove.append(key)
for key in keys_to_remove:
del self._sliding_windows[key]
logger.debug(f"Cleaned up {len(keys_to_remove)} old rate limit windows")
def _get_client_identifier(self, request: Request) -> str:
"""Get client identifier for rate limiting."""
# Try to get user ID from authenticated request
user = getattr(request.state, "user", None)
if user:
return f"user:{user.get('username', 'unknown')}"
# Fall back to IP address
client_ip = self._get_client_ip(request)
return f"ip:{client_ip}"
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address."""
# Check for forwarded headers
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fall back to direct connection
return request.client.host if request.client else "unknown"
def _get_rate_limit(self, request: Request) -> int:
"""Get rate limit for request."""
# Check if user is authenticated
user = getattr(request.state, "user", None)
if user:
return self.authenticated_limit
return self.default_limit
def _get_rate_limit_key(self, request: Request) -> str:
"""Get rate limit key for request."""
client_id = self._get_client_identifier(request)
endpoint = f"{request.method}:{request.url.path}"
return f"{client_id}:{endpoint}"
async def check_rate_limit(self, request: Request) -> Tuple[bool, RateLimitInfo]:
"""Check if request is within rate limits."""
if not self.enabled:
# Return dummy info when rate limiting is disabled
return True, RateLimitInfo(
requests=0,
window_start=time.time(),
window_size=self.window_size,
limit=float('inf')
)
key = self._get_rate_limit_key(request)
limit = self._get_rate_limit(request)
# Get or create sliding window counter
if key not in self._sliding_windows:
self._sliding_windows[key] = SlidingWindowCounter(self.window_size, limit)
window = self._sliding_windows[key]
# Update limit if it changed (e.g., user authenticated)
window.limit = limit
return await window.is_allowed()
async def check_token_bucket(self, request: Request, tokens: int = 1) -> bool:
"""Check rate limit using token bucket algorithm."""
if not self.enabled:
return True
key = self._get_client_identifier(request)
limit = self._get_rate_limit(request)
# Get or create token bucket
if key not in self._token_buckets:
# Refill rate: limit per window size
refill_rate = limit / self.window_size
self._token_buckets[key] = TokenBucket(limit, refill_rate)
bucket = self._token_buckets[key]
return await bucket.consume(tokens)
def get_rate_limit_headers(self, rate_limit_info: RateLimitInfo) -> Dict[str, str]:
"""Get rate limit headers for response."""
return {
"X-RateLimit-Limit": str(rate_limit_info.limit),
"X-RateLimit-Remaining": str(rate_limit_info.remaining),
"X-RateLimit-Reset": str(int(rate_limit_info.reset_time)),
"X-RateLimit-Window": str(rate_limit_info.window_size),
}
async def get_stats(self) -> Dict[str, Any]:
"""Get rate limiter statistics."""
return {
"enabled": self.enabled,
"default_limit": self.default_limit,
"authenticated_limit": self.authenticated_limit,
"window_size": self.window_size,
"active_windows": len(self._sliding_windows),
"active_buckets": len(self._token_buckets),
}
class RateLimitMiddleware:
"""Rate limiting middleware for FastAPI."""
def __init__(self, settings: Settings):
self.settings = settings
self.rate_limiter = RateLimiter(settings)
self.enabled = settings.enable_rate_limiting
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""Process request through rate limiting middleware."""
if not self.enabled:
return await call_next(request)
# Skip rate limiting for certain paths
if self._should_skip_rate_limit(request):
return await call_next(request)
try:
# Check rate limit
allowed, rate_limit_info = await self.rate_limiter.check_rate_limit(request)
if not allowed:
# Rate limit exceeded
logger.warning(
f"Rate limit exceeded for {self.rate_limiter._get_client_identifier(request)} "
f"on {request.method} {request.url.path}"
)
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
headers["Retry-After"] = str(int(rate_limit_info.reset_time - time.time()))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded",
headers=headers
)
# Process request
response = await call_next(request)
# Add rate limit headers to response
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
for key, value in headers.items():
response.headers[key] = value
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Rate limiting middleware error: {e}")
# Continue without rate limiting on error
return await call_next(request)
def _should_skip_rate_limit(self, request: Request) -> bool:
"""Check if rate limiting should be skipped for this request."""
path = request.url.path
# Skip rate limiting for these paths
skip_paths = [
"/health",
"/metrics",
"/docs",
"/redoc",
"/openapi.json",
"/static",
]
return any(path.startswith(skip_path) for skip_path in skip_paths)
async def start(self):
"""Start rate limiting middleware."""
await self.rate_limiter.start()
async def stop(self):
"""Stop rate limiting middleware."""
await self.rate_limiter.stop()
# Global rate limit middleware instance
_rate_limit_middleware: Optional[RateLimitMiddleware] = None
def get_rate_limit_middleware(settings: Settings) -> RateLimitMiddleware:
"""Get rate limit middleware instance."""
global _rate_limit_middleware
if _rate_limit_middleware is None:
_rate_limit_middleware = RateLimitMiddleware(settings)
return _rate_limit_middleware
def setup_rate_limiting(app: ASGIApp, settings: Settings) -> ASGIApp:
"""Setup rate limiting middleware for the application."""
if settings.enable_rate_limiting:
logger.info("Setting up rate limiting middleware")
middleware = get_rate_limit_middleware(settings)
# Add middleware to app
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
return await middleware(request, call_next)
logger.info(
f"Rate limiting enabled - Default: {settings.rate_limit_requests}/"
f"{settings.rate_limit_window}s, Authenticated: "
f"{settings.rate_limit_authenticated_requests}/{settings.rate_limit_window}s"
)
else:
logger.info("Rate limiting disabled")
return app
class RateLimitConfig:
"""Rate limiting configuration helper."""
@staticmethod
def development_config() -> dict:
"""Get rate limiting configuration for development."""
return {
"enable_rate_limiting": False, # Disabled in development
"rate_limit_requests": 1000,
"rate_limit_authenticated_requests": 5000,
"rate_limit_window": 3600, # 1 hour
}
@staticmethod
def production_config() -> dict:
"""Get rate limiting configuration for production."""
return {
"enable_rate_limiting": True,
"rate_limit_requests": 100, # 100 requests per hour for unauthenticated
"rate_limit_authenticated_requests": 1000, # 1000 requests per hour for authenticated
"rate_limit_window": 3600, # 1 hour
}
@staticmethod
def api_config() -> dict:
"""Get rate limiting configuration for API access."""
return {
"enable_rate_limiting": True,
"rate_limit_requests": 60, # 60 requests per minute
"rate_limit_authenticated_requests": 300, # 300 requests per minute
"rate_limit_window": 60, # 1 minute
}
@staticmethod
def strict_config() -> dict:
"""Get strict rate limiting configuration."""
return {
"enable_rate_limiting": True,
"rate_limit_requests": 10, # 10 requests per minute
"rate_limit_authenticated_requests": 100, # 100 requests per minute
"rate_limit_window": 60, # 1 minute
}
def validate_rate_limit_config(settings: Settings) -> list:
"""Validate rate limiting configuration."""
issues = []
if settings.enable_rate_limiting:
if settings.rate_limit_requests <= 0:
issues.append("Rate limit requests must be positive")
if settings.rate_limit_authenticated_requests <= 0:
issues.append("Authenticated rate limit requests must be positive")
if settings.rate_limit_window <= 0:
issues.append("Rate limit window must be positive")
if settings.rate_limit_authenticated_requests < settings.rate_limit_requests:
issues.append("Authenticated rate limit should be higher than default rate limit")
return issues

View File

View File

@@ -0,0 +1,279 @@
"""DensePose head for WiFi-DensePose system."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Tuple, List
class DensePoseError(Exception):
"""Exception raised for DensePose head errors."""
pass
class DensePoseHead(nn.Module):
"""DensePose head for body part segmentation and UV coordinate regression."""
def __init__(self, config: Dict[str, Any]):
"""Initialize DensePose head.
Args:
config: Configuration dictionary with head parameters
"""
super().__init__()
self._validate_config(config)
self.config = config
self.input_channels = config['input_channels']
self.num_body_parts = config['num_body_parts']
self.num_uv_coordinates = config['num_uv_coordinates']
self.hidden_channels = config.get('hidden_channels', [128, 64])
self.kernel_size = config.get('kernel_size', 3)
self.padding = config.get('padding', 1)
self.dropout_rate = config.get('dropout_rate', 0.1)
self.use_deformable_conv = config.get('use_deformable_conv', False)
self.use_fpn = config.get('use_fpn', False)
self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5])
self.output_stride = config.get('output_stride', 4)
# Feature Pyramid Network (optional)
if self.use_fpn:
self.fpn = self._build_fpn()
# Shared feature processing
self.shared_conv = self._build_shared_layers()
# Segmentation head for body part classification
self.segmentation_head = self._build_segmentation_head()
# UV regression head for coordinate prediction
self.uv_regression_head = self._build_uv_regression_head()
# Initialize weights
self._initialize_weights()
def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters."""
required_fields = ['input_channels', 'num_body_parts', 'num_uv_coordinates']
for field in required_fields:
if field not in config:
raise ValueError(f"Missing required field: {field}")
if config['input_channels'] <= 0:
raise ValueError("input_channels must be positive")
if config['num_body_parts'] <= 0:
raise ValueError("num_body_parts must be positive")
if config['num_uv_coordinates'] <= 0:
raise ValueError("num_uv_coordinates must be positive")
def _build_fpn(self) -> nn.Module:
"""Build Feature Pyramid Network."""
return nn.ModuleDict({
f'level_{level}': nn.Conv2d(self.input_channels, self.input_channels, 1)
for level in self.fpn_levels
})
def _build_shared_layers(self) -> nn.Module:
"""Build shared feature processing layers."""
layers = []
in_channels = self.input_channels
for hidden_dim in self.hidden_channels:
layers.extend([
nn.Conv2d(in_channels, hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate)
])
in_channels = hidden_dim
return nn.Sequential(*layers)
def _build_segmentation_head(self) -> nn.Module:
"""Build segmentation head for body part classification."""
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
return nn.Sequential(
nn.Conv2d(final_hidden, final_hidden // 2,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(final_hidden // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
# Upsampling to increase resolution
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(final_hidden // 4),
nn.ReLU(inplace=True),
nn.Conv2d(final_hidden // 4, self.num_body_parts + 1, kernel_size=1),
# +1 for background class
)
def _build_uv_regression_head(self) -> nn.Module:
"""Build UV regression head for coordinate prediction."""
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
return nn.Sequential(
nn.Conv2d(final_hidden, final_hidden // 2,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(final_hidden // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
# Upsampling to increase resolution
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(final_hidden // 4),
nn.ReLU(inplace=True),
nn.Conv2d(final_hidden // 4, self.num_uv_coordinates, kernel_size=1),
)
def _initialize_weights(self):
"""Initialize network weights."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Forward pass through the DensePose head.
Args:
x: Input feature tensor of shape (batch_size, channels, height, width)
Returns:
Dictionary containing:
- segmentation: Body part logits (batch_size, num_parts+1, height, width)
- uv_coordinates: UV coordinates (batch_size, 2, height, width)
"""
# Validate input shape
if x.shape[1] != self.input_channels:
raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
# Apply FPN if enabled
if self.use_fpn:
# Simple FPN processing - in practice this would be more sophisticated
x = self.fpn['level_2'](x)
# Shared feature processing
shared_features = self.shared_conv(x)
# Segmentation branch
segmentation_logits = self.segmentation_head(shared_features)
# UV regression branch
uv_coordinates = self.uv_regression_head(shared_features)
uv_coordinates = torch.sigmoid(uv_coordinates) # Normalize to [0, 1]
return {
'segmentation': segmentation_logits,
'uv_coordinates': uv_coordinates
}
def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute segmentation loss.
Args:
pred_logits: Predicted segmentation logits
target: Target segmentation masks
Returns:
Computed cross-entropy loss
"""
return F.cross_entropy(pred_logits, target, ignore_index=-1)
def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor:
"""Compute UV coordinate regression loss.
Args:
pred_uv: Predicted UV coordinates
target_uv: Target UV coordinates
Returns:
Computed L1 loss
"""
return F.l1_loss(pred_uv, target_uv)
def compute_total_loss(self, predictions: Dict[str, torch.Tensor],
seg_target: torch.Tensor,
uv_target: torch.Tensor,
seg_weight: float = 1.0,
uv_weight: float = 1.0) -> torch.Tensor:
"""Compute total loss combining segmentation and UV losses.
Args:
predictions: Dictionary of predictions
seg_target: Target segmentation masks
uv_target: Target UV coordinates
seg_weight: Weight for segmentation loss
uv_weight: Weight for UV loss
Returns:
Combined loss
"""
seg_loss = self.compute_segmentation_loss(predictions['segmentation'], seg_target)
uv_loss = self.compute_uv_loss(predictions['uv_coordinates'], uv_target)
return seg_weight * seg_loss + uv_weight * uv_loss
def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Get prediction confidence scores.
Args:
predictions: Dictionary of predictions
Returns:
Dictionary of confidence scores
"""
seg_logits = predictions['segmentation']
uv_coords = predictions['uv_coordinates']
# Segmentation confidence: max probability
seg_probs = F.softmax(seg_logits, dim=1)
seg_confidence = torch.max(seg_probs, dim=1)[0]
# UV confidence: inverse of prediction variance
uv_variance = torch.var(uv_coords, dim=1, keepdim=True)
uv_confidence = 1.0 / (1.0 + uv_variance)
return {
'segmentation_confidence': seg_confidence,
'uv_confidence': uv_confidence.squeeze(1)
}
def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Post-process predictions for final output.
Args:
predictions: Raw predictions from forward pass
Returns:
Post-processed predictions
"""
seg_logits = predictions['segmentation']
uv_coords = predictions['uv_coordinates']
# Convert logits to class predictions
body_parts = torch.argmax(seg_logits, dim=1)
# Get confidence scores
confidence = self.get_prediction_confidence(predictions)
return {
'body_parts': body_parts,
'uv_coordinates': uv_coords,
'confidence_scores': confidence
}

View File

@@ -0,0 +1,301 @@
"""Modality translation network for WiFi-DensePose system."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, List
class ModalityTranslationError(Exception):
"""Exception raised for modality translation errors."""
pass
class ModalityTranslationNetwork(nn.Module):
"""Neural network for translating CSI data to visual feature space."""
def __init__(self, config: Dict[str, Any]):
"""Initialize modality translation network.
Args:
config: Configuration dictionary with network parameters
"""
super().__init__()
self._validate_config(config)
self.config = config
self.input_channels = config['input_channels']
self.hidden_channels = config['hidden_channels']
self.output_channels = config['output_channels']
self.kernel_size = config.get('kernel_size', 3)
self.stride = config.get('stride', 1)
self.padding = config.get('padding', 1)
self.dropout_rate = config.get('dropout_rate', 0.1)
self.activation = config.get('activation', 'relu')
self.normalization = config.get('normalization', 'batch')
self.use_attention = config.get('use_attention', False)
self.attention_heads = config.get('attention_heads', 8)
# Encoder: CSI -> Feature space
self.encoder = self._build_encoder()
# Decoder: Feature space -> Visual-like features
self.decoder = self._build_decoder()
# Attention mechanism
if self.use_attention:
self.attention = self._build_attention()
# Initialize weights
self._initialize_weights()
def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters."""
required_fields = ['input_channels', 'hidden_channels', 'output_channels']
for field in required_fields:
if field not in config:
raise ValueError(f"Missing required field: {field}")
if config['input_channels'] <= 0:
raise ValueError("input_channels must be positive")
if not config['hidden_channels'] or len(config['hidden_channels']) == 0:
raise ValueError("hidden_channels must be a non-empty list")
if config['output_channels'] <= 0:
raise ValueError("output_channels must be positive")
def _build_encoder(self) -> nn.ModuleList:
"""Build encoder network."""
layers = nn.ModuleList()
# Initial convolution
in_channels = self.input_channels
for i, out_channels in enumerate(self.hidden_channels):
layer_block = nn.Sequential(
nn.Conv2d(in_channels, out_channels,
kernel_size=self.kernel_size,
stride=self.stride if i == 0 else 2,
padding=self.padding),
self._get_normalization(out_channels),
self._get_activation(),
nn.Dropout2d(self.dropout_rate)
)
layers.append(layer_block)
in_channels = out_channels
return layers
def _build_decoder(self) -> nn.ModuleList:
"""Build decoder network."""
layers = nn.ModuleList()
# Start with the last hidden channel size
in_channels = self.hidden_channels[-1]
# Progressive upsampling (reverse of encoder)
for i, out_channels in enumerate(reversed(self.hidden_channels[:-1])):
layer_block = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels,
kernel_size=self.kernel_size,
stride=2,
padding=self.padding,
output_padding=1),
self._get_normalization(out_channels),
self._get_activation(),
nn.Dropout2d(self.dropout_rate)
)
layers.append(layer_block)
in_channels = out_channels
# Final output layer
final_layer = nn.Sequential(
nn.Conv2d(in_channels, self.output_channels,
kernel_size=self.kernel_size,
padding=self.padding),
nn.Tanh() # Normalize output
)
layers.append(final_layer)
return layers
def _get_normalization(self, channels: int) -> nn.Module:
"""Get normalization layer."""
if self.normalization == 'batch':
return nn.BatchNorm2d(channels)
elif self.normalization == 'instance':
return nn.InstanceNorm2d(channels)
elif self.normalization == 'layer':
return nn.GroupNorm(1, channels)
else:
return nn.Identity()
def _get_activation(self) -> nn.Module:
"""Get activation function."""
if self.activation == 'relu':
return nn.ReLU(inplace=True)
elif self.activation == 'leaky_relu':
return nn.LeakyReLU(0.2, inplace=True)
elif self.activation == 'gelu':
return nn.GELU()
else:
return nn.ReLU(inplace=True)
def _build_attention(self) -> nn.Module:
"""Build attention mechanism."""
return nn.MultiheadAttention(
embed_dim=self.hidden_channels[-1],
num_heads=self.attention_heads,
dropout=self.dropout_rate,
batch_first=True
)
def _initialize_weights(self):
"""Initialize network weights."""
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the network.
Args:
x: Input CSI tensor of shape (batch_size, channels, height, width)
Returns:
Translated features tensor
"""
# Validate input shape
if x.shape[1] != self.input_channels:
raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
# Encode CSI data
encoded_features = self.encode(x)
# Decode to visual-like features
decoded = self.decode(encoded_features)
return decoded
def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Encode input through encoder layers.
Args:
x: Input tensor
Returns:
List of feature maps from each encoder layer
"""
features = []
current = x
for layer in self.encoder:
current = layer(current)
features.append(current)
return features
def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor:
"""Decode features through decoder layers.
Args:
encoded_features: List of encoded feature maps
Returns:
Decoded output tensor
"""
# Start with the last encoded feature
current = encoded_features[-1]
# Apply attention if enabled
if self.use_attention:
batch_size, channels, height, width = current.shape
# Reshape for attention: (batch, seq_len, embed_dim)
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
attended, _ = self.attention(current_flat, current_flat, current_flat)
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
# Apply decoder layers
for layer in self.decoder:
current = layer(current)
return current
def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor:
"""Compute translation loss between predicted and target features.
Args:
predicted: Predicted feature tensor
target: Target feature tensor
loss_type: Type of loss ('mse', 'l1', 'smooth_l1')
Returns:
Computed loss tensor
"""
if loss_type == 'mse':
return F.mse_loss(predicted, target)
elif loss_type == 'l1':
return F.l1_loss(predicted, target)
elif loss_type == 'smooth_l1':
return F.smooth_l1_loss(predicted, target)
else:
return F.mse_loss(predicted, target)
def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]:
"""Get statistics of feature tensor.
Args:
features: Feature tensor to analyze
Returns:
Dictionary of feature statistics
"""
with torch.no_grad():
return {
'mean': features.mean().item(),
'std': features.std().item(),
'min': features.min().item(),
'max': features.max().item(),
'sparsity': (features == 0).float().mean().item()
}
def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]:
"""Get intermediate features for visualization.
Args:
x: Input tensor
Returns:
Dictionary containing intermediate features
"""
result = {}
# Get encoder features
encoder_features = self.encode(x)
result['encoder_features'] = encoder_features
# Get decoder features
decoder_features = []
current = encoder_features[-1]
if self.use_attention:
batch_size, channels, height, width = current.shape
current_flat = current.view(batch_size, channels, -1).transpose(1, 2)
attended, attention_weights = self.attention(current_flat, current_flat, current_flat)
current = attended.transpose(1, 2).view(batch_size, channels, height, width)
result['attention_weights'] = attention_weights
for layer in self.decoder:
current = layer(current)
decoder_features.append(current)
result['decoder_features'] = decoder_features
return result

View File

@@ -0,0 +1,19 @@
"""
Services package for WiFi-DensePose API
"""
from .orchestrator import ServiceOrchestrator
from .health_check import HealthCheckService
from .metrics import MetricsService
from .pose_service import PoseService
from .stream_service import StreamService
from .hardware_service import HardwareService
__all__ = [
'ServiceOrchestrator',
'HealthCheckService',
'MetricsService',
'PoseService',
'StreamService',
'HardwareService'
]

View File

@@ -0,0 +1,482 @@
"""
Hardware interface service for WiFi-DensePose API
"""
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import numpy as np
from src.config.settings import Settings
from src.config.domains import DomainConfig
from src.core.router_interface import RouterInterface
logger = logging.getLogger(__name__)
class HardwareService:
"""Service for hardware interface operations."""
def __init__(self, settings: Settings, domain_config: DomainConfig):
"""Initialize hardware service."""
self.settings = settings
self.domain_config = domain_config
self.logger = logging.getLogger(__name__)
# Router interfaces
self.router_interfaces: Dict[str, RouterInterface] = {}
# Service state
self.is_running = False
self.last_error = None
# Data collection statistics
self.stats = {
"total_samples": 0,
"successful_samples": 0,
"failed_samples": 0,
"average_sample_rate": 0.0,
"last_sample_time": None,
"connected_routers": 0
}
# Background tasks
self.collection_task = None
self.monitoring_task = None
# Data buffers
self.recent_samples = []
self.max_recent_samples = 1000
async def initialize(self):
"""Initialize the hardware service."""
await self.start()
async def start(self):
"""Start the hardware service."""
if self.is_running:
return
try:
self.logger.info("Starting hardware service...")
# Initialize router interfaces
await self._initialize_routers()
self.is_running = True
# Start background tasks
if not self.settings.mock_hardware:
self.collection_task = asyncio.create_task(self._data_collection_loop())
self.monitoring_task = asyncio.create_task(self._monitoring_loop())
self.logger.info("Hardware service started successfully")
except Exception as e:
self.last_error = str(e)
self.logger.error(f"Failed to start hardware service: {e}")
raise
async def stop(self):
"""Stop the hardware service."""
self.is_running = False
# Cancel background tasks
if self.collection_task:
self.collection_task.cancel()
try:
await self.collection_task
except asyncio.CancelledError:
pass
if self.monitoring_task:
self.monitoring_task.cancel()
try:
await self.monitoring_task
except asyncio.CancelledError:
pass
# Disconnect from routers
await self._disconnect_routers()
self.logger.info("Hardware service stopped")
async def _initialize_routers(self):
"""Initialize router interfaces."""
try:
# Get router configurations from domain config
routers = self.domain_config.get_all_routers()
for router_config in routers:
if not router_config.enabled:
continue
router_id = router_config.router_id
# Create router interface
router_interface = RouterInterface(
router_id=router_id,
host=router_config.ip_address,
port=22, # Default SSH port
username="admin", # Default username
password="admin", # Default password
interface=router_config.interface,
mock_mode=self.settings.mock_hardware
)
# Connect to router (always connect, even in mock mode)
await router_interface.connect()
self.router_interfaces[router_id] = router_interface
self.logger.info(f"Router interface initialized: {router_id}")
self.stats["connected_routers"] = len(self.router_interfaces)
if not self.router_interfaces:
self.logger.warning("No router interfaces configured")
except Exception as e:
self.logger.error(f"Failed to initialize routers: {e}")
raise
async def _disconnect_routers(self):
"""Disconnect from all routers."""
for router_id, interface in self.router_interfaces.items():
try:
await interface.disconnect()
self.logger.info(f"Disconnected from router: {router_id}")
except Exception as e:
self.logger.error(f"Error disconnecting from router {router_id}: {e}")
self.router_interfaces.clear()
self.stats["connected_routers"] = 0
async def _data_collection_loop(self):
"""Background loop for data collection."""
try:
while self.is_running:
start_time = time.time()
# Collect data from all routers
await self._collect_data_from_routers()
# Calculate sleep time to maintain polling interval
elapsed = time.time() - start_time
sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)
if sleep_time > 0:
await asyncio.sleep(sleep_time)
except asyncio.CancelledError:
self.logger.info("Data collection loop cancelled")
except Exception as e:
self.logger.error(f"Error in data collection loop: {e}")
self.last_error = str(e)
async def _monitoring_loop(self):
"""Background loop for hardware monitoring."""
try:
while self.is_running:
# Monitor router connections
await self._monitor_router_health()
# Update statistics
self._update_sample_rate_stats()
# Wait before next check
await asyncio.sleep(30) # Check every 30 seconds
except asyncio.CancelledError:
self.logger.info("Monitoring loop cancelled")
except Exception as e:
self.logger.error(f"Error in monitoring loop: {e}")
async def _collect_data_from_routers(self):
"""Collect CSI data from all connected routers."""
for router_id, interface in self.router_interfaces.items():
try:
# Get CSI data from router
csi_data = await interface.get_csi_data()
if csi_data is not None:
# Process the collected data
await self._process_collected_data(router_id, csi_data)
self.stats["successful_samples"] += 1
self.stats["last_sample_time"] = datetime.now().isoformat()
else:
self.stats["failed_samples"] += 1
self.stats["total_samples"] += 1
except Exception as e:
self.logger.error(f"Error collecting data from router {router_id}: {e}")
self.stats["failed_samples"] += 1
self.stats["total_samples"] += 1
async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
"""Process collected CSI data."""
try:
# Create sample metadata
metadata = {
"router_id": router_id,
"timestamp": datetime.now().isoformat(),
"sample_rate": self.stats["average_sample_rate"],
"data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
}
# Add to recent samples buffer
sample = {
"router_id": router_id,
"timestamp": metadata["timestamp"],
"data": csi_data,
"metadata": metadata
}
self.recent_samples.append(sample)
# Maintain buffer size
if len(self.recent_samples) > self.max_recent_samples:
self.recent_samples.pop(0)
# Notify other services (this would typically be done through an event system)
# For now, we'll just log the data collection
self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")
except Exception as e:
self.logger.error(f"Error processing collected data: {e}")
async def _monitor_router_health(self):
"""Monitor health of router connections."""
healthy_routers = 0
for router_id, interface in self.router_interfaces.items():
try:
is_healthy = await interface.check_health()
if is_healthy:
healthy_routers += 1
else:
self.logger.warning(f"Router {router_id} is unhealthy")
# Try to reconnect if not in mock mode
if not self.settings.mock_hardware:
try:
await interface.reconnect()
self.logger.info(f"Reconnected to router {router_id}")
except Exception as e:
self.logger.error(f"Failed to reconnect to router {router_id}: {e}")
except Exception as e:
self.logger.error(f"Error checking health of router {router_id}: {e}")
self.stats["connected_routers"] = healthy_routers
def _update_sample_rate_stats(self):
"""Update sample rate statistics."""
if len(self.recent_samples) < 2:
return
# Calculate sample rate from recent samples
recent_count = min(100, len(self.recent_samples))
recent_samples = self.recent_samples[-recent_count:]
if len(recent_samples) >= 2:
# Calculate time differences
time_diffs = []
for i in range(1, len(recent_samples)):
try:
t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
diff = (t2 - t1).total_seconds()
if diff > 0:
time_diffs.append(diff)
except Exception:
continue
if time_diffs:
avg_interval = sum(time_diffs) / len(time_diffs)
self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0
async def get_router_status(self, router_id: str) -> Dict[str, Any]:
"""Get status of a specific router."""
if router_id not in self.router_interfaces:
raise ValueError(f"Router {router_id} not found")
interface = self.router_interfaces[router_id]
try:
is_healthy = await interface.check_health()
status = await interface.get_status()
return {
"router_id": router_id,
"healthy": is_healthy,
"connected": status.get("connected", False),
"last_data_time": status.get("last_data_time"),
"error_count": status.get("error_count", 0),
"configuration": status.get("configuration", {})
}
except Exception as e:
return {
"router_id": router_id,
"healthy": False,
"connected": False,
"error": str(e)
}
async def get_all_router_status(self) -> List[Dict[str, Any]]:
"""Get status of all routers."""
statuses = []
for router_id in self.router_interfaces:
try:
status = await self.get_router_status(router_id)
statuses.append(status)
except Exception as e:
statuses.append({
"router_id": router_id,
"healthy": False,
"error": str(e)
})
return statuses
async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
"""Get recent CSI data samples."""
samples = self.recent_samples[-limit:] if limit else self.recent_samples
if router_id:
samples = [s for s in samples if s["router_id"] == router_id]
# Convert numpy arrays to lists for JSON serialization
result = []
for sample in samples:
sample_copy = sample.copy()
if isinstance(sample_copy["data"], np.ndarray):
sample_copy["data"] = sample_copy["data"].tolist()
result.append(sample_copy)
return result
async def get_status(self) -> Dict[str, Any]:
"""Get service status."""
return {
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
"running": self.is_running,
"last_error": self.last_error,
"statistics": self.stats.copy(),
"configuration": {
"mock_hardware": self.settings.mock_hardware,
"wifi_interface": self.settings.wifi_interface,
"polling_interval": self.settings.hardware_polling_interval,
"buffer_size": self.settings.csi_buffer_size
},
"routers": await self.get_all_router_status()
}
async def get_metrics(self) -> Dict[str, Any]:
"""Get service metrics."""
total_samples = self.stats["total_samples"]
success_rate = self.stats["successful_samples"] / max(1, total_samples)
return {
"hardware_service": {
"total_samples": total_samples,
"successful_samples": self.stats["successful_samples"],
"failed_samples": self.stats["failed_samples"],
"success_rate": success_rate,
"average_sample_rate": self.stats["average_sample_rate"],
"connected_routers": self.stats["connected_routers"],
"last_sample_time": self.stats["last_sample_time"]
}
}
async def reset(self):
"""Reset service state."""
self.stats = {
"total_samples": 0,
"successful_samples": 0,
"failed_samples": 0,
"average_sample_rate": 0.0,
"last_sample_time": None,
"connected_routers": len(self.router_interfaces)
}
self.recent_samples.clear()
self.last_error = None
self.logger.info("Hardware service reset")
async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
"""Manually trigger data collection."""
if not self.is_running:
raise RuntimeError("Hardware service is not running")
results = {}
if router_id:
# Collect from specific router
if router_id not in self.router_interfaces:
raise ValueError(f"Router {router_id} not found")
interface = self.router_interfaces[router_id]
try:
csi_data = await interface.get_csi_data()
if csi_data is not None:
await self._process_collected_data(router_id, csi_data)
results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
else:
results[router_id] = {"success": False, "error": "No data received"}
except Exception as e:
results[router_id] = {"success": False, "error": str(e)}
else:
# Collect from all routers
await self._collect_data_from_routers()
results = {"message": "Manual collection triggered for all routers"}
return results
async def health_check(self) -> Dict[str, Any]:
"""Perform health check."""
try:
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
# Check router health
healthy_routers = 0
total_routers = len(self.router_interfaces)
for router_id, interface in self.router_interfaces.items():
try:
if await interface.check_health():
healthy_routers += 1
except Exception:
pass
return {
"status": status,
"message": self.last_error if self.last_error else "Hardware service is running normally",
"connected_routers": f"{healthy_routers}/{total_routers}",
"metrics": {
"total_samples": self.stats["total_samples"],
"success_rate": (
self.stats["successful_samples"] / max(1, self.stats["total_samples"])
),
"average_sample_rate": self.stats["average_sample_rate"]
}
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}"
}
async def is_ready(self) -> bool:
"""Check if service is ready."""
return self.is_running and len(self.router_interfaces) > 0

View File

@@ -0,0 +1,465 @@
"""
Health check service for WiFi-DensePose API
"""
import asyncio
import logging
import time
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from enum import Enum
from src.config.settings import Settings
logger = logging.getLogger(__name__)
class HealthStatus(Enum):
"""Health status enumeration."""
HEALTHY = "healthy"
DEGRADED = "degraded"
UNHEALTHY = "unhealthy"
UNKNOWN = "unknown"
@dataclass
class HealthCheck:
"""Health check result."""
name: str
status: HealthStatus
message: str
timestamp: datetime = field(default_factory=datetime.utcnow)
duration_ms: float = 0.0
details: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ServiceHealth:
"""Service health information."""
name: str
status: HealthStatus
last_check: Optional[datetime] = None
checks: List[HealthCheck] = field(default_factory=list)
uptime: float = 0.0
error_count: int = 0
last_error: Optional[str] = None
class HealthCheckService:
"""Service for monitoring application health."""
def __init__(self, settings: Settings):
self.settings = settings
self._services: Dict[str, ServiceHealth] = {}
self._start_time = time.time()
self._initialized = False
self._running = False
async def initialize(self):
"""Initialize health check service."""
if self._initialized:
return
logger.info("Initializing health check service")
# Initialize service health tracking
self._services = {
"api": ServiceHealth("api", HealthStatus.UNKNOWN),
"database": ServiceHealth("database", HealthStatus.UNKNOWN),
"redis": ServiceHealth("redis", HealthStatus.UNKNOWN),
"hardware": ServiceHealth("hardware", HealthStatus.UNKNOWN),
"pose": ServiceHealth("pose", HealthStatus.UNKNOWN),
"stream": ServiceHealth("stream", HealthStatus.UNKNOWN),
}
self._initialized = True
logger.info("Health check service initialized")
async def start(self):
"""Start health check service."""
if not self._initialized:
await self.initialize()
self._running = True
logger.info("Health check service started")
async def shutdown(self):
"""Shutdown health check service."""
self._running = False
logger.info("Health check service shut down")
async def perform_health_checks(self) -> Dict[str, HealthCheck]:
"""Perform all health checks."""
if not self._running:
return {}
logger.debug("Performing health checks")
results = {}
# Perform individual health checks
checks = [
self._check_api_health(),
self._check_database_health(),
self._check_redis_health(),
self._check_hardware_health(),
self._check_pose_health(),
self._check_stream_health(),
]
# Run checks concurrently
check_results = await asyncio.gather(*checks, return_exceptions=True)
# Process results
for i, result in enumerate(check_results):
check_name = ["api", "database", "redis", "hardware", "pose", "stream"][i]
if isinstance(result, Exception):
health_check = HealthCheck(
name=check_name,
status=HealthStatus.UNHEALTHY,
message=f"Health check failed: {result}"
)
else:
health_check = result
results[check_name] = health_check
self._update_service_health(check_name, health_check)
logger.debug(f"Completed {len(results)} health checks")
return results
async def _check_api_health(self) -> HealthCheck:
"""Check API health."""
start_time = time.time()
try:
# Basic API health check
uptime = time.time() - self._start_time
status = HealthStatus.HEALTHY
message = "API is running normally"
details = {
"uptime_seconds": uptime,
"uptime_formatted": str(timedelta(seconds=int(uptime)))
}
except Exception as e:
status = HealthStatus.UNHEALTHY
message = f"API health check failed: {e}"
details = {"error": str(e)}
duration_ms = (time.time() - start_time) * 1000
return HealthCheck(
name="api",
status=status,
message=message,
duration_ms=duration_ms,
details=details
)
async def _check_database_health(self) -> HealthCheck:
"""Check database health."""
start_time = time.time()
try:
# Import here to avoid circular imports
from src.database.connection import get_database_manager
db_manager = get_database_manager()
if not db_manager.is_connected():
status = HealthStatus.UNHEALTHY
message = "Database is not connected"
details = {"connected": False}
else:
# Test database connection
await db_manager.test_connection()
status = HealthStatus.HEALTHY
message = "Database is connected and responsive"
details = {
"connected": True,
"pool_size": db_manager.get_pool_size(),
"active_connections": db_manager.get_active_connections()
}
except Exception as e:
status = HealthStatus.UNHEALTHY
message = f"Database health check failed: {e}"
details = {"error": str(e)}
duration_ms = (time.time() - start_time) * 1000
return HealthCheck(
name="database",
status=status,
message=message,
duration_ms=duration_ms,
details=details
)
async def _check_redis_health(self) -> HealthCheck:
"""Check Redis health."""
start_time = time.time()
try:
redis_config = self.settings.get_redis_url()
if not redis_config:
status = HealthStatus.UNKNOWN
message = "Redis is not configured"
details = {"configured": False}
else:
# Test Redis connection
import redis.asyncio as redis
redis_client = redis.from_url(redis_config)
await redis_client.ping()
await redis_client.close()
status = HealthStatus.HEALTHY
message = "Redis is connected and responsive"
details = {"connected": True}
except Exception as e:
status = HealthStatus.UNHEALTHY
message = f"Redis health check failed: {e}"
details = {"error": str(e)}
duration_ms = (time.time() - start_time) * 1000
return HealthCheck(
name="redis",
status=status,
message=message,
duration_ms=duration_ms,
details=details
)
async def _check_hardware_health(self) -> HealthCheck:
"""Check hardware service health."""
start_time = time.time()
try:
# Import here to avoid circular imports
from src.api.dependencies import get_hardware_service
hardware_service = get_hardware_service()
if hasattr(hardware_service, 'get_status'):
status_info = await hardware_service.get_status()
if status_info.get("status") == "healthy":
status = HealthStatus.HEALTHY
message = "Hardware service is operational"
else:
status = HealthStatus.DEGRADED
message = f"Hardware service status: {status_info.get('status', 'unknown')}"
details = status_info
else:
status = HealthStatus.UNKNOWN
message = "Hardware service status unavailable"
details = {}
except Exception as e:
status = HealthStatus.UNHEALTHY
message = f"Hardware health check failed: {e}"
details = {"error": str(e)}
duration_ms = (time.time() - start_time) * 1000
return HealthCheck(
name="hardware",
status=status,
message=message,
duration_ms=duration_ms,
details=details
)
async def _check_pose_health(self) -> HealthCheck:
"""Check pose service health."""
start_time = time.time()
try:
# Import here to avoid circular imports
from src.api.dependencies import get_pose_service
pose_service = get_pose_service()
if hasattr(pose_service, 'get_status'):
status_info = await pose_service.get_status()
if status_info.get("status") == "healthy":
status = HealthStatus.HEALTHY
message = "Pose service is operational"
else:
status = HealthStatus.DEGRADED
message = f"Pose service status: {status_info.get('status', 'unknown')}"
details = status_info
else:
status = HealthStatus.UNKNOWN
message = "Pose service status unavailable"
details = {}
except Exception as e:
status = HealthStatus.UNHEALTHY
message = f"Pose health check failed: {e}"
details = {"error": str(e)}
duration_ms = (time.time() - start_time) * 1000
return HealthCheck(
name="pose",
status=status,
message=message,
duration_ms=duration_ms,
details=details
)
async def _check_stream_health(self) -> HealthCheck:
"""Check stream service health."""
start_time = time.time()
try:
# Import here to avoid circular imports
from src.api.dependencies import get_stream_service
stream_service = get_stream_service()
if hasattr(stream_service, 'get_status'):
status_info = await stream_service.get_status()
if status_info.get("status") == "healthy":
status = HealthStatus.HEALTHY
message = "Stream service is operational"
else:
status = HealthStatus.DEGRADED
message = f"Stream service status: {status_info.get('status', 'unknown')}"
details = status_info
else:
status = HealthStatus.UNKNOWN
message = "Stream service status unavailable"
details = {}
except Exception as e:
status = HealthStatus.UNHEALTHY
message = f"Stream health check failed: {e}"
details = {"error": str(e)}
duration_ms = (time.time() - start_time) * 1000
return HealthCheck(
name="stream",
status=status,
message=message,
duration_ms=duration_ms,
details=details
)
def _update_service_health(self, service_name: str, health_check: HealthCheck):
"""Update service health information."""
if service_name not in self._services:
self._services[service_name] = ServiceHealth(service_name, HealthStatus.UNKNOWN)
service_health = self._services[service_name]
service_health.status = health_check.status
service_health.last_check = health_check.timestamp
service_health.uptime = time.time() - self._start_time
# Keep last 10 checks
service_health.checks.append(health_check)
if len(service_health.checks) > 10:
service_health.checks.pop(0)
# Update error tracking
if health_check.status == HealthStatus.UNHEALTHY:
service_health.error_count += 1
service_health.last_error = health_check.message
async def get_overall_health(self) -> Dict[str, Any]:
"""Get overall system health."""
if not self._services:
return {
"status": HealthStatus.UNKNOWN.value,
"message": "Health checks not initialized"
}
# Determine overall status
statuses = [service.status for service in self._services.values()]
if all(status == HealthStatus.HEALTHY for status in statuses):
overall_status = HealthStatus.HEALTHY
message = "All services are healthy"
elif any(status == HealthStatus.UNHEALTHY for status in statuses):
overall_status = HealthStatus.UNHEALTHY
unhealthy_services = [
name for name, service in self._services.items()
if service.status == HealthStatus.UNHEALTHY
]
message = f"Unhealthy services: {', '.join(unhealthy_services)}"
elif any(status == HealthStatus.DEGRADED for status in statuses):
overall_status = HealthStatus.DEGRADED
degraded_services = [
name for name, service in self._services.items()
if service.status == HealthStatus.DEGRADED
]
message = f"Degraded services: {', '.join(degraded_services)}"
else:
overall_status = HealthStatus.UNKNOWN
message = "System health status unknown"
return {
"status": overall_status.value,
"message": message,
"timestamp": datetime.utcnow().isoformat(),
"uptime": time.time() - self._start_time,
"services": {
name: {
"status": service.status.value,
"last_check": service.last_check.isoformat() if service.last_check else None,
"error_count": service.error_count,
"last_error": service.last_error
}
for name, service in self._services.items()
}
}
async def get_service_health(self, service_name: str) -> Optional[Dict[str, Any]]:
"""Get health information for a specific service."""
service = self._services.get(service_name)
if not service:
return None
return {
"name": service.name,
"status": service.status.value,
"last_check": service.last_check.isoformat() if service.last_check else None,
"uptime": service.uptime,
"error_count": service.error_count,
"last_error": service.last_error,
"recent_checks": [
{
"timestamp": check.timestamp.isoformat(),
"status": check.status.value,
"message": check.message,
"duration_ms": check.duration_ms,
"details": check.details
}
for check in service.checks[-5:] # Last 5 checks
]
}
async def get_status(self) -> Dict[str, Any]:
"""Get health check service status."""
return {
"status": "healthy" if self._running else "stopped",
"initialized": self._initialized,
"running": self._running,
"services_monitored": len(self._services),
"uptime": time.time() - self._start_time
}

431
v1/src/services/metrics.py Normal file
View File

@@ -0,0 +1,431 @@
"""
Metrics collection service for WiFi-DensePose API
"""
import asyncio
import logging
import time
import psutil
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from collections import defaultdict, deque
from src.config.settings import Settings
logger = logging.getLogger(__name__)
@dataclass
class MetricPoint:
"""Single metric data point."""
timestamp: datetime
value: float
labels: Dict[str, str] = field(default_factory=dict)
@dataclass
class MetricSeries:
"""Time series of metric points."""
name: str
description: str
unit: str
points: deque = field(default_factory=lambda: deque(maxlen=1000))
def add_point(self, value: float, labels: Optional[Dict[str, str]] = None):
"""Add a metric point."""
point = MetricPoint(
timestamp=datetime.utcnow(),
value=value,
labels=labels or {}
)
self.points.append(point)
def get_latest(self) -> Optional[MetricPoint]:
"""Get the latest metric point."""
return self.points[-1] if self.points else None
def get_average(self, duration: timedelta) -> Optional[float]:
"""Get average value over a time duration."""
cutoff = datetime.utcnow() - duration
relevant_points = [
point for point in self.points
if point.timestamp >= cutoff
]
if not relevant_points:
return None
return sum(point.value for point in relevant_points) / len(relevant_points)
def get_max(self, duration: timedelta) -> Optional[float]:
"""Get maximum value over a time duration."""
cutoff = datetime.utcnow() - duration
relevant_points = [
point for point in self.points
if point.timestamp >= cutoff
]
if not relevant_points:
return None
return max(point.value for point in relevant_points)
class MetricsService:
"""Service for collecting and managing application metrics."""
def __init__(self, settings: Settings):
self.settings = settings
self._metrics: Dict[str, MetricSeries] = {}
self._counters: Dict[str, float] = defaultdict(float)
self._gauges: Dict[str, float] = {}
self._histograms: Dict[str, List[float]] = defaultdict(list)
self._start_time = time.time()
self._initialized = False
self._running = False
# Initialize standard metrics
self._initialize_standard_metrics()
def _initialize_standard_metrics(self):
"""Initialize standard system and application metrics."""
self._metrics.update({
# System metrics
"system_cpu_usage": MetricSeries(
"system_cpu_usage", "System CPU usage percentage", "percent"
),
"system_memory_usage": MetricSeries(
"system_memory_usage", "System memory usage percentage", "percent"
),
"system_disk_usage": MetricSeries(
"system_disk_usage", "System disk usage percentage", "percent"
),
"system_network_bytes_sent": MetricSeries(
"system_network_bytes_sent", "Network bytes sent", "bytes"
),
"system_network_bytes_recv": MetricSeries(
"system_network_bytes_recv", "Network bytes received", "bytes"
),
# Application metrics
"app_requests_total": MetricSeries(
"app_requests_total", "Total HTTP requests", "count"
),
"app_request_duration": MetricSeries(
"app_request_duration", "HTTP request duration", "seconds"
),
"app_active_connections": MetricSeries(
"app_active_connections", "Active WebSocket connections", "count"
),
"app_pose_detections": MetricSeries(
"app_pose_detections", "Pose detections performed", "count"
),
"app_pose_processing_time": MetricSeries(
"app_pose_processing_time", "Pose processing time", "seconds"
),
"app_csi_data_points": MetricSeries(
"app_csi_data_points", "CSI data points processed", "count"
),
"app_stream_fps": MetricSeries(
"app_stream_fps", "Streaming frames per second", "fps"
),
# Error metrics
"app_errors_total": MetricSeries(
"app_errors_total", "Total application errors", "count"
),
"app_http_errors": MetricSeries(
"app_http_errors", "HTTP errors", "count"
),
})
async def initialize(self):
"""Initialize metrics service."""
if self._initialized:
return
logger.info("Initializing metrics service")
self._initialized = True
logger.info("Metrics service initialized")
async def start(self):
"""Start metrics service."""
if not self._initialized:
await self.initialize()
self._running = True
logger.info("Metrics service started")
async def shutdown(self):
"""Shutdown metrics service."""
self._running = False
logger.info("Metrics service shut down")
async def collect_metrics(self):
"""Collect all metrics."""
if not self._running:
return
logger.debug("Collecting metrics")
# Collect system metrics
await self._collect_system_metrics()
# Collect application metrics
await self._collect_application_metrics()
logger.debug("Metrics collection completed")
async def _collect_system_metrics(self):
"""Collect system-level metrics."""
try:
# CPU usage
cpu_percent = psutil.cpu_percent(interval=1)
self._metrics["system_cpu_usage"].add_point(cpu_percent)
# Memory usage
memory = psutil.virtual_memory()
self._metrics["system_memory_usage"].add_point(memory.percent)
# Disk usage
disk = psutil.disk_usage('/')
disk_percent = (disk.used / disk.total) * 100
self._metrics["system_disk_usage"].add_point(disk_percent)
# Network I/O
network = psutil.net_io_counters()
self._metrics["system_network_bytes_sent"].add_point(network.bytes_sent)
self._metrics["system_network_bytes_recv"].add_point(network.bytes_recv)
except Exception as e:
logger.error(f"Error collecting system metrics: {e}")
async def _collect_application_metrics(self):
"""Collect application-specific metrics."""
try:
# Import here to avoid circular imports
from src.api.websocket.connection_manager import connection_manager
# Active connections
connection_stats = await connection_manager.get_connection_stats()
active_connections = connection_stats.get("active_connections", 0)
self._metrics["app_active_connections"].add_point(active_connections)
# Update counters as metrics
for name, value in self._counters.items():
if name in self._metrics:
self._metrics[name].add_point(value)
# Update gauges as metrics
for name, value in self._gauges.items():
if name in self._metrics:
self._metrics[name].add_point(value)
except Exception as e:
logger.error(f"Error collecting application metrics: {e}")
def increment_counter(self, name: str, value: float = 1.0, labels: Optional[Dict[str, str]] = None):
"""Increment a counter metric."""
self._counters[name] += value
if name in self._metrics:
self._metrics[name].add_point(self._counters[name], labels)
def set_gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
"""Set a gauge metric value."""
self._gauges[name] = value
if name in self._metrics:
self._metrics[name].add_point(value, labels)
def record_histogram(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
"""Record a histogram value."""
self._histograms[name].append(value)
# Keep only last 1000 values
if len(self._histograms[name]) > 1000:
self._histograms[name] = self._histograms[name][-1000:]
if name in self._metrics:
self._metrics[name].add_point(value, labels)
def time_function(self, metric_name: str):
"""Decorator to time function execution."""
def decorator(func):
import functools
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
start_time = time.time()
try:
result = await func(*args, **kwargs)
return result
finally:
duration = time.time() - start_time
self.record_histogram(metric_name, duration)
@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
start_time = time.time()
try:
result = func(*args, **kwargs)
return result
finally:
duration = time.time() - start_time
self.record_histogram(metric_name, duration)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
return decorator
def get_metric(self, name: str) -> Optional[MetricSeries]:
"""Get a metric series by name."""
return self._metrics.get(name)
def get_metric_value(self, name: str) -> Optional[float]:
"""Get the latest value of a metric."""
metric = self._metrics.get(name)
if metric:
latest = metric.get_latest()
return latest.value if latest else None
return None
def get_counter_value(self, name: str) -> float:
"""Get current counter value."""
return self._counters.get(name, 0.0)
def get_gauge_value(self, name: str) -> Optional[float]:
"""Get current gauge value."""
return self._gauges.get(name)
def get_histogram_stats(self, name: str) -> Dict[str, float]:
"""Get histogram statistics."""
values = self._histograms.get(name, [])
if not values:
return {}
sorted_values = sorted(values)
count = len(sorted_values)
return {
"count": count,
"sum": sum(sorted_values),
"min": sorted_values[0],
"max": sorted_values[-1],
"mean": sum(sorted_values) / count,
"p50": sorted_values[int(count * 0.5)],
"p90": sorted_values[int(count * 0.9)],
"p95": sorted_values[int(count * 0.95)],
"p99": sorted_values[int(count * 0.99)],
}
async def get_all_metrics(self) -> Dict[str, Any]:
"""Get all current metrics."""
metrics = {}
# Current metric values
for name, metric_series in self._metrics.items():
latest = metric_series.get_latest()
if latest:
metrics[name] = {
"value": latest.value,
"timestamp": latest.timestamp.isoformat(),
"description": metric_series.description,
"unit": metric_series.unit,
"labels": latest.labels
}
# Counter values
metrics.update({
f"counter_{name}": value
for name, value in self._counters.items()
})
# Gauge values
metrics.update({
f"gauge_{name}": value
for name, value in self._gauges.items()
})
# Histogram statistics
for name, values in self._histograms.items():
if values:
stats = self.get_histogram_stats(name)
metrics[f"histogram_{name}"] = stats
return metrics
async def get_system_metrics(self) -> Dict[str, Any]:
"""Get system metrics summary."""
return {
"cpu_usage": self.get_metric_value("system_cpu_usage"),
"memory_usage": self.get_metric_value("system_memory_usage"),
"disk_usage": self.get_metric_value("system_disk_usage"),
"network_bytes_sent": self.get_metric_value("system_network_bytes_sent"),
"network_bytes_recv": self.get_metric_value("system_network_bytes_recv"),
}
async def get_application_metrics(self) -> Dict[str, Any]:
"""Get application metrics summary."""
return {
"requests_total": self.get_counter_value("app_requests_total"),
"active_connections": self.get_metric_value("app_active_connections"),
"pose_detections": self.get_counter_value("app_pose_detections"),
"csi_data_points": self.get_counter_value("app_csi_data_points"),
"errors_total": self.get_counter_value("app_errors_total"),
"uptime_seconds": time.time() - self._start_time,
"request_duration_stats": self.get_histogram_stats("app_request_duration"),
"pose_processing_time_stats": self.get_histogram_stats("app_pose_processing_time"),
}
async def get_performance_summary(self) -> Dict[str, Any]:
"""Get performance metrics summary."""
one_hour = timedelta(hours=1)
return {
"system": {
"cpu_avg_1h": self._metrics["system_cpu_usage"].get_average(one_hour),
"memory_avg_1h": self._metrics["system_memory_usage"].get_average(one_hour),
"cpu_max_1h": self._metrics["system_cpu_usage"].get_max(one_hour),
"memory_max_1h": self._metrics["system_memory_usage"].get_max(one_hour),
},
"application": {
"avg_request_duration": self.get_histogram_stats("app_request_duration").get("mean"),
"avg_pose_processing_time": self.get_histogram_stats("app_pose_processing_time").get("mean"),
"total_requests": self.get_counter_value("app_requests_total"),
"total_errors": self.get_counter_value("app_errors_total"),
"error_rate": (
self.get_counter_value("app_errors_total") /
max(self.get_counter_value("app_requests_total"), 1)
) * 100,
}
}
async def get_status(self) -> Dict[str, Any]:
"""Get metrics service status."""
return {
"status": "healthy" if self._running else "stopped",
"initialized": self._initialized,
"running": self._running,
"metrics_count": len(self._metrics),
"counters_count": len(self._counters),
"gauges_count": len(self._gauges),
"histograms_count": len(self._histograms),
"uptime": time.time() - self._start_time
}
def reset_metrics(self):
"""Reset all metrics."""
logger.info("Resetting all metrics")
# Clear metric points but keep series definitions
for metric_series in self._metrics.values():
metric_series.points.clear()
# Reset counters, gauges, and histograms
self._counters.clear()
self._gauges.clear()
self._histograms.clear()
logger.info("All metrics reset")

View File

@@ -0,0 +1,395 @@
"""
Main service orchestrator for WiFi-DensePose API
"""
import asyncio
import logging
from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from src.config.settings import Settings
from src.services.health_check import HealthCheckService
from src.services.metrics import MetricsService
from src.api.dependencies import (
get_hardware_service,
get_pose_service,
get_stream_service
)
from src.api.websocket.connection_manager import connection_manager
from src.api.websocket.pose_stream import PoseStreamHandler
logger = logging.getLogger(__name__)
class ServiceOrchestrator:
"""Main service orchestrator that manages all application services."""
def __init__(self, settings: Settings):
self.settings = settings
self._services: Dict[str, Any] = {}
self._background_tasks: List[asyncio.Task] = []
self._initialized = False
self._started = False
# Core services
self.health_service = HealthCheckService(settings)
self.metrics_service = MetricsService(settings)
# Application services (will be initialized later)
self.hardware_service = None
self.pose_service = None
self.stream_service = None
self.pose_stream_handler = None
async def initialize(self):
"""Initialize all services."""
if self._initialized:
logger.warning("Services already initialized")
return
logger.info("Initializing services...")
try:
# Initialize core services
await self.health_service.initialize()
await self.metrics_service.initialize()
# Initialize application services
await self._initialize_application_services()
# Store services in registry
self._services = {
'health': self.health_service,
'metrics': self.metrics_service,
'hardware': self.hardware_service,
'pose': self.pose_service,
'stream': self.stream_service,
'pose_stream_handler': self.pose_stream_handler,
'connection_manager': connection_manager
}
self._initialized = True
logger.info("All services initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize services: {e}")
await self.shutdown()
raise
async def _initialize_application_services(self):
"""Initialize application-specific services."""
try:
# Initialize hardware service
self.hardware_service = get_hardware_service()
await self.hardware_service.initialize()
logger.info("Hardware service initialized")
# Initialize pose service
self.pose_service = get_pose_service()
await self.pose_service.initialize()
logger.info("Pose service initialized")
# Initialize stream service
self.stream_service = get_stream_service()
await self.stream_service.initialize()
logger.info("Stream service initialized")
# Initialize pose stream handler
self.pose_stream_handler = PoseStreamHandler(
connection_manager=connection_manager,
pose_service=self.pose_service,
stream_service=self.stream_service
)
logger.info("Pose stream handler initialized")
except Exception as e:
logger.error(f"Failed to initialize application services: {e}")
raise
async def start(self):
"""Start all services and background tasks."""
if not self._initialized:
await self.initialize()
if self._started:
logger.warning("Services already started")
return
logger.info("Starting services...")
try:
# Start core services
await self.health_service.start()
await self.metrics_service.start()
# Start application services
await self._start_application_services()
# Start background tasks
await self._start_background_tasks()
self._started = True
logger.info("All services started successfully")
except Exception as e:
logger.error(f"Failed to start services: {e}")
await self.shutdown()
raise
async def _start_application_services(self):
"""Start application-specific services."""
try:
# Start hardware service
if hasattr(self.hardware_service, 'start'):
await self.hardware_service.start()
# Start pose service
if hasattr(self.pose_service, 'start'):
await self.pose_service.start()
# Start stream service
if hasattr(self.stream_service, 'start'):
await self.stream_service.start()
logger.info("Application services started")
except Exception as e:
logger.error(f"Failed to start application services: {e}")
raise
async def _start_background_tasks(self):
"""Start background tasks."""
try:
# Start health check monitoring
if self.settings.health_check_interval > 0:
task = asyncio.create_task(self._health_check_loop())
self._background_tasks.append(task)
# Start metrics collection
if self.settings.metrics_enabled:
task = asyncio.create_task(self._metrics_collection_loop())
self._background_tasks.append(task)
# Start pose streaming if enabled
if self.settings.enable_real_time_processing:
await self.pose_stream_handler.start_streaming()
logger.info(f"Started {len(self._background_tasks)} background tasks")
except Exception as e:
logger.error(f"Failed to start background tasks: {e}")
raise
async def _health_check_loop(self):
"""Background health check loop."""
logger.info("Starting health check loop")
while True:
try:
await self.health_service.perform_health_checks()
await asyncio.sleep(self.settings.health_check_interval)
except asyncio.CancelledError:
logger.info("Health check loop cancelled")
break
except Exception as e:
logger.error(f"Error in health check loop: {e}")
await asyncio.sleep(self.settings.health_check_interval)
async def _metrics_collection_loop(self):
"""Background metrics collection loop."""
logger.info("Starting metrics collection loop")
while True:
try:
await self.metrics_service.collect_metrics()
await asyncio.sleep(60) # Collect metrics every minute
except asyncio.CancelledError:
logger.info("Metrics collection loop cancelled")
break
except Exception as e:
logger.error(f"Error in metrics collection loop: {e}")
await asyncio.sleep(60)
async def shutdown(self):
"""Shutdown all services and cleanup resources."""
logger.info("Shutting down services...")
try:
# Cancel background tasks
for task in self._background_tasks:
if not task.done():
task.cancel()
if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
# Stop pose streaming
if self.pose_stream_handler:
await self.pose_stream_handler.shutdown()
# Shutdown connection manager
await connection_manager.shutdown()
# Shutdown application services
await self._shutdown_application_services()
# Shutdown core services
await self.health_service.shutdown()
await self.metrics_service.shutdown()
self._started = False
self._initialized = False
logger.info("All services shut down successfully")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
async def _shutdown_application_services(self):
"""Shutdown application-specific services."""
try:
# Shutdown services in reverse order
if self.stream_service and hasattr(self.stream_service, 'shutdown'):
await self.stream_service.shutdown()
if self.pose_service and hasattr(self.pose_service, 'shutdown'):
await self.pose_service.shutdown()
if self.hardware_service and hasattr(self.hardware_service, 'shutdown'):
await self.hardware_service.shutdown()
logger.info("Application services shut down")
except Exception as e:
logger.error(f"Error shutting down application services: {e}")
async def restart_service(self, service_name: str):
"""Restart a specific service."""
logger.info(f"Restarting service: {service_name}")
service = self._services.get(service_name)
if not service:
raise ValueError(f"Service not found: {service_name}")
try:
# Stop service
if hasattr(service, 'stop'):
await service.stop()
elif hasattr(service, 'shutdown'):
await service.shutdown()
# Reinitialize service
if hasattr(service, 'initialize'):
await service.initialize()
# Start service
if hasattr(service, 'start'):
await service.start()
logger.info(f"Service restarted successfully: {service_name}")
except Exception as e:
logger.error(f"Failed to restart service {service_name}: {e}")
raise
async def reset_services(self):
"""Reset all services to initial state."""
logger.info("Resetting all services")
try:
# Reset application services
if self.hardware_service and hasattr(self.hardware_service, 'reset'):
await self.hardware_service.reset()
if self.pose_service and hasattr(self.pose_service, 'reset'):
await self.pose_service.reset()
if self.stream_service and hasattr(self.stream_service, 'reset'):
await self.stream_service.reset()
# Reset connection manager
await connection_manager.reset()
logger.info("All services reset successfully")
except Exception as e:
logger.error(f"Failed to reset services: {e}")
raise
async def get_service_status(self) -> Dict[str, Any]:
"""Get status of all services."""
status = {}
for name, service in self._services.items():
try:
if hasattr(service, 'get_status'):
status[name] = await service.get_status()
else:
status[name] = {"status": "unknown"}
except Exception as e:
status[name] = {"status": "error", "error": str(e)}
return status
async def get_service_metrics(self) -> Dict[str, Any]:
"""Get metrics from all services."""
metrics = {}
for name, service in self._services.items():
try:
if hasattr(service, 'get_metrics'):
metrics[name] = await service.get_metrics()
elif hasattr(service, 'get_performance_metrics'):
metrics[name] = await service.get_performance_metrics()
except Exception as e:
logger.error(f"Failed to get metrics from {name}: {e}")
metrics[name] = {"error": str(e)}
return metrics
async def get_service_info(self) -> Dict[str, Any]:
"""Get information about all services."""
info = {
"total_services": len(self._services),
"initialized": self._initialized,
"started": self._started,
"background_tasks": len(self._background_tasks),
"services": {}
}
for name, service in self._services.items():
service_info = {
"type": type(service).__name__,
"module": type(service).__module__
}
# Add service-specific info if available
if hasattr(service, 'get_info'):
try:
service_info.update(await service.get_info())
except Exception as e:
service_info["error"] = str(e)
info["services"][name] = service_info
return info
def get_service(self, name: str) -> Optional[Any]:
"""Get a specific service by name."""
return self._services.get(name)
@property
def is_healthy(self) -> bool:
"""Check if all services are healthy."""
return self._initialized and self._started
@asynccontextmanager
async def service_context(self):
"""Context manager for service lifecycle."""
try:
await self.initialize()
await self.start()
yield self
finally:
await self.shutdown()

View File

@@ -0,0 +1,757 @@
"""
Pose estimation service for WiFi-DensePose API
"""
import logging
import asyncio
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
import numpy as np
import torch
from src.config.settings import Settings
from src.config.domains import DomainConfig
from src.core.csi_processor import CSIProcessor
from src.core.phase_sanitizer import PhaseSanitizer
from src.models.densepose_head import DensePoseHead
from src.models.modality_translation import ModalityTranslationNetwork
logger = logging.getLogger(__name__)
class PoseService:
"""Service for pose estimation operations."""
def __init__(self, settings: Settings, domain_config: DomainConfig):
"""Initialize pose service."""
self.settings = settings
self.domain_config = domain_config
self.logger = logging.getLogger(__name__)
# Initialize components
self.csi_processor = None
self.phase_sanitizer = None
self.densepose_model = None
self.modality_translator = None
# Service state
self.is_initialized = False
self.is_running = False
self.last_error = None
# Processing statistics
self.stats = {
"total_processed": 0,
"successful_detections": 0,
"failed_detections": 0,
"average_confidence": 0.0,
"processing_time_ms": 0.0
}
async def initialize(self):
"""Initialize the pose service."""
try:
self.logger.info("Initializing pose service...")
# Initialize CSI processor
csi_config = {
'buffer_size': self.settings.csi_buffer_size,
'sampling_rate': getattr(self.settings, 'csi_sampling_rate', 1000),
'window_size': getattr(self.settings, 'csi_window_size', 512),
'overlap': getattr(self.settings, 'csi_overlap', 0.5),
'noise_threshold': getattr(self.settings, 'csi_noise_threshold', 0.1),
'human_detection_threshold': getattr(self.settings, 'csi_human_detection_threshold', 0.8),
'smoothing_factor': getattr(self.settings, 'csi_smoothing_factor', 0.9),
'max_history_size': getattr(self.settings, 'csi_max_history_size', 500),
'num_subcarriers': 56,
'num_antennas': 3
}
self.csi_processor = CSIProcessor(config=csi_config)
# Initialize phase sanitizer
phase_config = {
'unwrapping_method': 'numpy',
'outlier_threshold': 3.0,
'smoothing_window': 5,
'enable_outlier_removal': True,
'enable_smoothing': True,
'enable_noise_filtering': True,
'noise_threshold': getattr(self.settings, 'csi_noise_threshold', 0.1)
}
self.phase_sanitizer = PhaseSanitizer(config=phase_config)
# Initialize models if not mocking
if not self.settings.mock_pose_data:
await self._initialize_models()
else:
self.logger.info("Using mock pose data for development")
self.is_initialized = True
self.logger.info("Pose service initialized successfully")
except Exception as e:
self.last_error = str(e)
self.logger.error(f"Failed to initialize pose service: {e}")
raise
async def _initialize_models(self):
"""Initialize neural network models."""
try:
# Initialize DensePose model
if self.settings.pose_model_path:
self.densepose_model = DensePoseHead()
# Load model weights if path is provided
# model_state = torch.load(self.settings.pose_model_path)
# self.densepose_model.load_state_dict(model_state)
self.logger.info("DensePose model loaded")
else:
self.logger.warning("No pose model path provided, using default model")
self.densepose_model = DensePoseHead()
# Initialize modality translation
config = {
'input_channels': 64, # CSI data channels
'hidden_channels': [128, 256, 512],
'output_channels': 256, # Visual feature channels
'use_attention': True
}
self.modality_translator = ModalityTranslationNetwork(config)
# Set models to evaluation mode
self.densepose_model.eval()
self.modality_translator.eval()
except Exception as e:
self.logger.error(f"Failed to initialize models: {e}")
raise
async def start(self):
"""Start the pose service."""
if not self.is_initialized:
await self.initialize()
self.is_running = True
self.logger.info("Pose service started")
async def stop(self):
"""Stop the pose service."""
self.is_running = False
self.logger.info("Pose service stopped")
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""Process CSI data and estimate poses."""
if not self.is_running:
raise RuntimeError("Pose service is not running")
start_time = datetime.now()
try:
# Process CSI data
processed_csi = await self._process_csi(csi_data, metadata)
# Estimate poses
poses = await self._estimate_poses(processed_csi, metadata)
# Update statistics
processing_time = (datetime.now() - start_time).total_seconds() * 1000
self._update_stats(poses, processing_time)
return {
"timestamp": start_time.isoformat(),
"poses": poses,
"metadata": metadata,
"processing_time_ms": processing_time,
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
}
except Exception as e:
self.last_error = str(e)
self.stats["failed_detections"] += 1
self.logger.error(f"Error processing CSI data: {e}")
raise
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
"""Process raw CSI data."""
# Convert raw data to CSIData format
from src.hardware.csi_extractor import CSIData
# Create CSIData object with proper fields
# For mock data, create amplitude and phase from input
if csi_data.ndim == 1:
amplitude = np.abs(csi_data)
phase = np.angle(csi_data) if np.iscomplexobj(csi_data) else np.zeros_like(csi_data)
else:
amplitude = csi_data
phase = np.zeros_like(csi_data)
csi_data_obj = CSIData(
timestamp=metadata.get("timestamp", datetime.now()),
amplitude=amplitude,
phase=phase,
frequency=metadata.get("frequency", 5.0), # 5 GHz default
bandwidth=metadata.get("bandwidth", 20.0), # 20 MHz default
num_subcarriers=metadata.get("num_subcarriers", 56),
num_antennas=metadata.get("num_antennas", 3),
snr=metadata.get("snr", 20.0), # 20 dB default
metadata=metadata
)
# Process CSI data
try:
detection_result = await self.csi_processor.process_csi_data(csi_data_obj)
# Add to history for temporal analysis
self.csi_processor.add_to_history(csi_data_obj)
# Extract amplitude data for pose estimation
if detection_result and detection_result.features:
amplitude_data = detection_result.features.amplitude_mean
# Apply phase sanitization if we have phase data
if hasattr(detection_result.features, 'phase_difference'):
phase_data = detection_result.features.phase_difference
sanitized_phase = self.phase_sanitizer.sanitize(phase_data)
# Combine amplitude and phase data
return np.concatenate([amplitude_data, sanitized_phase])
return amplitude_data
except Exception as e:
self.logger.warning(f"CSI processing failed, using raw data: {e}")
return csi_data
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Estimate poses from processed CSI data."""
if self.settings.mock_pose_data:
return self._generate_mock_poses()
try:
# Convert CSI data to tensor
csi_tensor = torch.from_numpy(csi_data).float()
# Add batch dimension if needed
if len(csi_tensor.shape) == 2:
csi_tensor = csi_tensor.unsqueeze(0)
# Translate modality (CSI to visual-like features)
with torch.no_grad():
visual_features = self.modality_translator(csi_tensor)
# Estimate poses using DensePose
pose_outputs = self.densepose_model(visual_features)
# Convert outputs to pose detections
poses = self._parse_pose_outputs(pose_outputs)
# Filter by confidence threshold
filtered_poses = [
pose for pose in poses
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
]
# Limit number of persons
if len(filtered_poses) > self.settings.pose_max_persons:
filtered_poses = sorted(
filtered_poses,
key=lambda x: x.get("confidence", 0.0),
reverse=True
)[:self.settings.pose_max_persons]
return filtered_poses
except Exception as e:
self.logger.error(f"Error in pose estimation: {e}")
return []
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
"""Parse neural network outputs into pose detections."""
poses = []
# This is a simplified parsing - in reality, this would depend on the model architecture
# For now, generate mock poses based on the output shape
batch_size = outputs.shape[0]
for i in range(batch_size):
# Extract pose information (mock implementation)
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
pose = {
"person_id": i,
"confidence": confidence,
"keypoints": self._generate_keypoints(),
"bounding_box": self._generate_bounding_box(),
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
"timestamp": datetime.now().isoformat()
}
poses.append(pose)
return poses
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
"""Generate mock pose data for development."""
import random
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
poses = []
for i in range(num_persons):
confidence = random.uniform(0.3, 0.95)
pose = {
"person_id": i,
"confidence": confidence,
"keypoints": self._generate_keypoints(),
"bounding_box": self._generate_bounding_box(),
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
"timestamp": datetime.now().isoformat()
}
poses.append(pose)
return poses
def _generate_keypoints(self) -> List[Dict[str, Any]]:
"""Generate keypoints for a person."""
import random
keypoint_names = [
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"
]
keypoints = []
for name in keypoint_names:
keypoints.append({
"name": name,
"x": random.uniform(0.1, 0.9),
"y": random.uniform(0.1, 0.9),
"confidence": random.uniform(0.5, 0.95)
})
return keypoints
def _generate_bounding_box(self) -> Dict[str, float]:
"""Generate bounding box for a person."""
import random
x = random.uniform(0.1, 0.6)
y = random.uniform(0.1, 0.6)
width = random.uniform(0.2, 0.4)
height = random.uniform(0.3, 0.5)
return {
"x": x,
"y": y,
"width": width,
"height": height
}
def _classify_activity(self, features: torch.Tensor) -> str:
"""Classify activity from features."""
# Simple mock classification
import random
activities = ["standing", "sitting", "walking", "lying", "unknown"]
return random.choice(activities)
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
"""Update processing statistics."""
self.stats["total_processed"] += 1
if poses:
self.stats["successful_detections"] += 1
confidences = [pose.get("confidence", 0.0) for pose in poses]
avg_confidence = sum(confidences) / len(confidences)
# Update running average
total = self.stats["successful_detections"]
current_avg = self.stats["average_confidence"]
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
else:
self.stats["failed_detections"] += 1
# Update processing time (running average)
total = self.stats["total_processed"]
current_avg = self.stats["processing_time_ms"]
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
async def get_status(self) -> Dict[str, Any]:
"""Get service status."""
return {
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
"initialized": self.is_initialized,
"running": self.is_running,
"last_error": self.last_error,
"statistics": self.stats.copy(),
"configuration": {
"mock_data": self.settings.mock_pose_data,
"confidence_threshold": self.settings.pose_confidence_threshold,
"max_persons": self.settings.pose_max_persons,
"batch_size": self.settings.pose_processing_batch_size
}
}
async def get_metrics(self) -> Dict[str, Any]:
"""Get service metrics."""
return {
"pose_service": {
"total_processed": self.stats["total_processed"],
"successful_detections": self.stats["successful_detections"],
"failed_detections": self.stats["failed_detections"],
"success_rate": (
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
),
"average_confidence": self.stats["average_confidence"],
"average_processing_time_ms": self.stats["processing_time_ms"]
}
}
async def reset(self):
"""Reset service state."""
self.stats = {
"total_processed": 0,
"successful_detections": 0,
"failed_detections": 0,
"average_confidence": 0.0,
"processing_time_ms": 0.0
}
self.last_error = None
self.logger.info("Pose service reset")
# API endpoint methods
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
include_keypoints=True, include_segmentation=False):
"""Estimate poses with API parameters."""
try:
# Generate mock CSI data for estimation
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
metadata = {
"timestamp": datetime.now(),
"zone_ids": zone_ids or ["zone_1"],
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
"max_persons": max_persons or self.settings.pose_max_persons
}
# Process the data
result = await self.process_csi_data(mock_csi, metadata)
# Format for API response
persons = []
for i, pose in enumerate(result["poses"]):
person = {
"person_id": str(pose["person_id"]),
"confidence": pose["confidence"],
"bounding_box": pose["bounding_box"],
"zone_id": zone_ids[0] if zone_ids else "zone_1",
"activity": pose["activity"],
"timestamp": datetime.fromisoformat(pose["timestamp"])
}
if include_keypoints:
person["keypoints"] = pose["keypoints"]
if include_segmentation:
person["segmentation"] = {"mask": "mock_segmentation_data"}
persons.append(person)
# Zone summary
zone_summary = {}
for zone_id in (zone_ids or ["zone_1"]):
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
return {
"timestamp": datetime.now(),
"frame_id": f"frame_{int(datetime.now().timestamp())}",
"persons": persons,
"zone_summary": zone_summary,
"processing_time_ms": result["processing_time_ms"],
"metadata": {"mock_data": self.settings.mock_pose_data}
}
except Exception as e:
self.logger.error(f"Error in estimate_poses: {e}")
raise
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
include_keypoints=True, include_segmentation=False):
"""Analyze pose data with custom parameters."""
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
include_keypoints, include_segmentation)
async def get_zone_occupancy(self, zone_id: str):
"""Get current occupancy for a specific zone."""
try:
# Mock occupancy data
import random
count = random.randint(0, 5)
persons = []
for i in range(count):
persons.append({
"person_id": f"person_{i}",
"confidence": random.uniform(0.7, 0.95),
"activity": random.choice(["standing", "sitting", "walking"])
})
return {
"count": count,
"max_occupancy": 10,
"persons": persons,
"timestamp": datetime.now()
}
except Exception as e:
self.logger.error(f"Error getting zone occupancy: {e}")
return None
async def get_zones_summary(self):
"""Get occupancy summary for all zones."""
try:
import random
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
zone_data = {}
total_persons = 0
active_zones = 0
for zone_id in zones:
count = random.randint(0, 3)
zone_data[zone_id] = {
"occupancy": count,
"max_occupancy": 10,
"status": "active" if count > 0 else "inactive"
}
total_persons += count
if count > 0:
active_zones += 1
return {
"total_persons": total_persons,
"zones": zone_data,
"active_zones": active_zones
}
except Exception as e:
self.logger.error(f"Error getting zones summary: {e}")
raise
async def get_historical_data(self, start_time, end_time, zone_ids=None,
aggregation_interval=300, include_raw_data=False):
"""Get historical pose estimation data."""
try:
# Mock historical data
import random
from datetime import timedelta
current_time = start_time
aggregated_data = []
raw_data = [] if include_raw_data else None
while current_time < end_time:
# Generate aggregated data point
data_point = {
"timestamp": current_time,
"total_persons": random.randint(0, 8),
"zones": {}
}
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
data_point["zones"][zone_id] = {
"occupancy": random.randint(0, 3),
"avg_confidence": random.uniform(0.7, 0.95)
}
aggregated_data.append(data_point)
# Generate raw data if requested
if include_raw_data:
for _ in range(random.randint(0, 5)):
raw_data.append({
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
"person_id": f"person_{random.randint(1, 10)}",
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
"confidence": random.uniform(0.5, 0.95),
"activity": random.choice(["standing", "sitting", "walking"])
})
current_time += timedelta(seconds=aggregation_interval)
return {
"aggregated_data": aggregated_data,
"raw_data": raw_data,
"total_records": len(aggregated_data)
}
except Exception as e:
self.logger.error(f"Error getting historical data: {e}")
raise
async def get_recent_activities(self, zone_id=None, limit=10):
"""Get recently detected activities."""
try:
import random
activities = []
for i in range(limit):
activity = {
"activity_id": f"activity_{i}",
"person_id": f"person_{random.randint(1, 5)}",
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
"confidence": random.uniform(0.6, 0.95),
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
"duration_seconds": random.randint(10, 300)
}
activities.append(activity)
return activities
except Exception as e:
self.logger.error(f"Error getting recent activities: {e}")
raise
async def is_calibrating(self):
"""Check if calibration is in progress."""
return False # Mock implementation
async def start_calibration(self):
"""Start calibration process."""
import uuid
calibration_id = str(uuid.uuid4())
self.logger.info(f"Started calibration: {calibration_id}")
return calibration_id
async def run_calibration(self, calibration_id):
"""Run calibration process."""
self.logger.info(f"Running calibration: {calibration_id}")
# Mock calibration process
await asyncio.sleep(5)
self.logger.info(f"Calibration completed: {calibration_id}")
async def get_calibration_status(self):
"""Get current calibration status."""
return {
"is_calibrating": False,
"calibration_id": None,
"progress_percent": 100,
"current_step": "completed",
"estimated_remaining_minutes": 0,
"last_calibration": datetime.now() - timedelta(hours=1)
}
async def get_statistics(self, start_time, end_time):
"""Get pose estimation statistics."""
try:
import random
# Mock statistics
total_detections = random.randint(100, 1000)
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
return {
"total_detections": total_detections,
"successful_detections": successful_detections,
"failed_detections": total_detections - successful_detections,
"success_rate": successful_detections / total_detections,
"average_confidence": random.uniform(0.75, 0.90),
"average_processing_time_ms": random.uniform(50, 200),
"unique_persons": random.randint(5, 20),
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
"activity_distribution": {
"standing": random.uniform(0.3, 0.5),
"sitting": random.uniform(0.2, 0.4),
"walking": random.uniform(0.1, 0.3),
"lying": random.uniform(0.0, 0.1)
}
}
except Exception as e:
self.logger.error(f"Error getting statistics: {e}")
raise
async def process_segmentation_data(self, frame_id):
"""Process segmentation data in background."""
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
# Mock background processing
await asyncio.sleep(2)
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
# WebSocket streaming methods
async def get_current_pose_data(self):
"""Get current pose data for streaming."""
try:
# Generate current pose data
result = await self.estimate_poses()
# Format data by zones for WebSocket streaming
zone_data = {}
# Group persons by zone
for person in result["persons"]:
zone_id = person.get("zone_id", "zone_1")
if zone_id not in zone_data:
zone_data[zone_id] = {
"pose": {
"persons": [],
"count": 0
},
"confidence": 0.0,
"activity": None,
"metadata": {
"frame_id": result["frame_id"],
"processing_time_ms": result["processing_time_ms"]
}
}
zone_data[zone_id]["pose"]["persons"].append(person)
zone_data[zone_id]["pose"]["count"] += 1
# Update zone confidence (average)
current_confidence = zone_data[zone_id]["confidence"]
person_confidence = person.get("confidence", 0.0)
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
# Set activity if not already set
if not zone_data[zone_id]["activity"] and person.get("activity"):
zone_data[zone_id]["activity"] = person["activity"]
return zone_data
except Exception as e:
self.logger.error(f"Error getting current pose data: {e}")
# Return empty zone data on error
return {}
# Health check methods
async def health_check(self):
"""Perform health check."""
try:
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
return {
"status": status,
"message": self.last_error if self.last_error else "Service is running normally",
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
"metrics": {
"total_processed": self.stats["total_processed"],
"success_rate": (
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
),
"average_processing_time_ms": self.stats["processing_time_ms"]
}
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}"
}
async def is_ready(self):
"""Check if service is ready."""
return self.is_initialized and self.is_running

View File

@@ -0,0 +1,397 @@
"""
Real-time streaming service for WiFi-DensePose API
"""
import logging
import asyncio
import json
from typing import Dict, List, Optional, Any, Set
from datetime import datetime
from collections import deque
import numpy as np
from fastapi import WebSocket
from src.config.settings import Settings
from src.config.domains import DomainConfig
logger = logging.getLogger(__name__)
class StreamService:
"""Service for real-time data streaming."""
def __init__(self, settings: Settings, domain_config: DomainConfig):
"""Initialize stream service."""
self.settings = settings
self.domain_config = domain_config
self.logger = logging.getLogger(__name__)
# WebSocket connections
self.connections: Set[WebSocket] = set()
self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}
# Stream buffers
self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)
# Service state
self.is_running = False
self.last_error = None
# Streaming statistics
self.stats = {
"active_connections": 0,
"total_connections": 0,
"messages_sent": 0,
"messages_failed": 0,
"data_points_streamed": 0,
"average_latency_ms": 0.0
}
# Background tasks
self.streaming_task = None
async def initialize(self):
"""Initialize the stream service."""
self.logger.info("Stream service initialized")
async def start(self):
"""Start the stream service."""
if self.is_running:
return
self.is_running = True
self.logger.info("Stream service started")
# Start background streaming task
if self.settings.enable_real_time_processing:
self.streaming_task = asyncio.create_task(self._streaming_loop())
async def stop(self):
"""Stop the stream service."""
self.is_running = False
# Cancel background task
if self.streaming_task:
self.streaming_task.cancel()
try:
await self.streaming_task
except asyncio.CancelledError:
pass
# Close all connections
await self._close_all_connections()
self.logger.info("Stream service stopped")
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
"""Add a new WebSocket connection."""
try:
await websocket.accept()
self.connections.add(websocket)
self.connection_metadata[websocket] = metadata or {}
self.stats["active_connections"] = len(self.connections)
self.stats["total_connections"] += 1
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
# Send initial data if available
await self._send_initial_data(websocket)
except Exception as e:
self.logger.error(f"Error adding WebSocket connection: {e}")
raise
async def remove_connection(self, websocket: WebSocket):
"""Remove a WebSocket connection."""
try:
if websocket in self.connections:
self.connections.remove(websocket)
self.connection_metadata.pop(websocket, None)
self.stats["active_connections"] = len(self.connections)
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
except Exception as e:
self.logger.error(f"Error removing WebSocket connection: {e}")
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
"""Broadcast pose data to all connected clients."""
if not self.is_running:
return
# Add to buffer
self.pose_buffer.append({
"type": "pose_data",
"timestamp": datetime.now().isoformat(),
"data": pose_data
})
# Broadcast to all connections
await self._broadcast_message({
"type": "pose_update",
"timestamp": datetime.now().isoformat(),
"data": pose_data
})
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
"""Broadcast CSI data to all connected clients."""
if not self.is_running:
return
# Convert numpy array to list for JSON serialization
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
# Add to buffer
self.csi_buffer.append({
"type": "csi_data",
"timestamp": datetime.now().isoformat(),
"data": csi_list,
"metadata": metadata
})
# Broadcast to all connections
await self._broadcast_message({
"type": "csi_update",
"timestamp": datetime.now().isoformat(),
"data": csi_list,
"metadata": metadata
})
async def broadcast_system_status(self, status_data: Dict[str, Any]):
"""Broadcast system status to all connected clients."""
if not self.is_running:
return
await self._broadcast_message({
"type": "system_status",
"timestamp": datetime.now().isoformat(),
"data": status_data
})
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
"""Send message to a specific connection."""
try:
if websocket in self.connections:
await websocket.send_text(json.dumps(message))
self.stats["messages_sent"] += 1
except Exception as e:
self.logger.error(f"Error sending message to connection: {e}")
self.stats["messages_failed"] += 1
await self.remove_connection(websocket)
async def _broadcast_message(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients."""
if not self.connections:
return
disconnected = set()
for websocket in self.connections.copy():
try:
await websocket.send_text(json.dumps(message))
self.stats["messages_sent"] += 1
except Exception as e:
self.logger.warning(f"Failed to send message to connection: {e}")
self.stats["messages_failed"] += 1
disconnected.add(websocket)
# Remove disconnected clients
for websocket in disconnected:
await self.remove_connection(websocket)
if message.get("type") in ["pose_update", "csi_update"]:
self.stats["data_points_streamed"] += 1
async def _send_initial_data(self, websocket: WebSocket):
"""Send initial data to a new connection."""
try:
# Send recent pose data
if self.pose_buffer:
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
await self.send_to_connection(websocket, {
"type": "initial_poses",
"timestamp": datetime.now().isoformat(),
"data": recent_poses
})
# Send recent CSI data
if self.csi_buffer:
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
await self.send_to_connection(websocket, {
"type": "initial_csi",
"timestamp": datetime.now().isoformat(),
"data": recent_csi
})
# Send service status
status = await self.get_status()
await self.send_to_connection(websocket, {
"type": "service_status",
"timestamp": datetime.now().isoformat(),
"data": status
})
except Exception as e:
self.logger.error(f"Error sending initial data: {e}")
async def _streaming_loop(self):
"""Background streaming loop for periodic updates."""
try:
while self.is_running:
# Send periodic heartbeat
if self.connections:
await self._broadcast_message({
"type": "heartbeat",
"timestamp": datetime.now().isoformat(),
"active_connections": len(self.connections)
})
# Wait for next iteration
await asyncio.sleep(self.settings.websocket_ping_interval)
except asyncio.CancelledError:
self.logger.info("Streaming loop cancelled")
except Exception as e:
self.logger.error(f"Error in streaming loop: {e}")
self.last_error = str(e)
async def _close_all_connections(self):
"""Close all WebSocket connections."""
disconnected = []
for websocket in self.connections.copy():
try:
await websocket.close()
disconnected.append(websocket)
except Exception as e:
self.logger.warning(f"Error closing connection: {e}")
disconnected.append(websocket)
# Clear all connections
for websocket in disconnected:
await self.remove_connection(websocket)
async def get_status(self) -> Dict[str, Any]:
"""Get service status."""
return {
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
"running": self.is_running,
"last_error": self.last_error,
"connections": {
"active": len(self.connections),
"total": self.stats["total_connections"]
},
"buffers": {
"pose_buffer_size": len(self.pose_buffer),
"csi_buffer_size": len(self.csi_buffer),
"max_buffer_size": self.settings.stream_buffer_size
},
"statistics": self.stats.copy(),
"configuration": {
"stream_fps": self.settings.stream_fps,
"buffer_size": self.settings.stream_buffer_size,
"ping_interval": self.settings.websocket_ping_interval,
"timeout": self.settings.websocket_timeout
}
}
async def get_metrics(self) -> Dict[str, Any]:
"""Get service metrics."""
total_messages = self.stats["messages_sent"] + self.stats["messages_failed"]
success_rate = self.stats["messages_sent"] / max(1, total_messages)
return {
"stream_service": {
"active_connections": self.stats["active_connections"],
"total_connections": self.stats["total_connections"],
"messages_sent": self.stats["messages_sent"],
"messages_failed": self.stats["messages_failed"],
"message_success_rate": success_rate,
"data_points_streamed": self.stats["data_points_streamed"],
"average_latency_ms": self.stats["average_latency_ms"]
}
}
async def get_connection_info(self) -> List[Dict[str, Any]]:
"""Get information about active connections."""
connections_info = []
for websocket in self.connections:
metadata = self.connection_metadata.get(websocket, {})
connection_info = {
"id": id(websocket),
"connected_at": metadata.get("connected_at", "unknown"),
"user_agent": metadata.get("user_agent", "unknown"),
"ip_address": metadata.get("ip_address", "unknown"),
"subscription_types": metadata.get("subscription_types", [])
}
connections_info.append(connection_info)
return connections_info
async def reset(self):
"""Reset service state."""
# Clear buffers
self.pose_buffer.clear()
self.csi_buffer.clear()
# Reset statistics
self.stats = {
"active_connections": len(self.connections),
"total_connections": 0,
"messages_sent": 0,
"messages_failed": 0,
"data_points_streamed": 0,
"average_latency_ms": 0.0
}
self.last_error = None
self.logger.info("Stream service reset")
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
"""Get data from buffers."""
if buffer_type == "pose":
return list(self.pose_buffer)[-limit:]
elif buffer_type == "csi":
return list(self.csi_buffer)[-limit:]
else:
return []
@property
def is_active(self) -> bool:
"""Check if stream service is active."""
return self.is_running
async def health_check(self) -> Dict[str, Any]:
"""Perform health check."""
try:
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
return {
"status": status,
"message": self.last_error if self.last_error else "Stream service is running normally",
"active_connections": len(self.connections),
"metrics": {
"messages_sent": self.stats["messages_sent"],
"messages_failed": self.stats["messages_failed"],
"data_points_streamed": self.stats["data_points_streamed"]
}
}
except Exception as e:
return {
"status": "unhealthy",
"message": f"Health check failed: {str(e)}"
}
async def is_ready(self) -> bool:
"""Check if service is ready."""
return self.is_running

612
v1/src/tasks/backup.py Normal file
View File

@@ -0,0 +1,612 @@
"""
Backup tasks for WiFi-DensePose API
"""
import asyncio
import logging
import os
import shutil
import gzip
import json
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.database.connection import get_database_manager
from src.database.models import Device, Session, CSIData, PoseDetection, SystemMetric, AuditLog
from src.logger import get_logger
logger = get_logger(__name__)
class BackupTask:
"""Base class for backup tasks."""
def __init__(self, name: str, settings: Settings):
self.name = name
self.settings = settings
self.enabled = True
self.last_run = None
self.run_count = 0
self.error_count = 0
self.backup_dir = Path(settings.backup_directory)
self.backup_dir.mkdir(parents=True, exist_ok=True)
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute the backup task."""
raise NotImplementedError
async def run(self, session: AsyncSession) -> Dict[str, Any]:
"""Run the backup task with error handling."""
start_time = datetime.utcnow()
try:
logger.info(f"Starting backup task: {self.name}")
result = await self.execute_backup(session)
self.last_run = start_time
self.run_count += 1
logger.info(
f"Backup task {self.name} completed: "
f"backed up {result.get('backup_size_mb', 0):.2f} MB"
)
return {
"task": self.name,
"status": "success",
"start_time": start_time.isoformat(),
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
**result
}
except Exception as e:
self.error_count += 1
logger.error(f"Backup task {self.name} failed: {e}", exc_info=True)
return {
"task": self.name,
"status": "error",
"start_time": start_time.isoformat(),
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
"error": str(e),
"backup_size_mb": 0
}
def get_stats(self) -> Dict[str, Any]:
"""Get task statistics."""
return {
"name": self.name,
"enabled": self.enabled,
"last_run": self.last_run.isoformat() if self.last_run else None,
"run_count": self.run_count,
"error_count": self.error_count,
"backup_directory": str(self.backup_dir),
}
def _get_backup_filename(self, prefix: str, extension: str = ".gz") -> str:
"""Generate backup filename with timestamp."""
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
return f"{prefix}_{timestamp}{extension}"
def _get_file_size_mb(self, file_path: Path) -> float:
"""Get file size in MB."""
if file_path.exists():
return file_path.stat().st_size / (1024 * 1024)
return 0.0
def _cleanup_old_backups(self, pattern: str, retention_days: int):
"""Clean up old backup files."""
if retention_days <= 0:
return
cutoff_date = datetime.utcnow() - timedelta(days=retention_days)
for backup_file in self.backup_dir.glob(pattern):
if backup_file.stat().st_mtime < cutoff_date.timestamp():
try:
backup_file.unlink()
logger.debug(f"Deleted old backup: {backup_file}")
except Exception as e:
logger.warning(f"Failed to delete old backup {backup_file}: {e}")
class DatabaseBackup(BackupTask):
"""Full database backup using pg_dump."""
def __init__(self, settings: Settings):
super().__init__("database_backup", settings)
self.retention_days = settings.database_backup_retention_days
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute database backup."""
backup_filename = self._get_backup_filename("database_full", ".sql.gz")
backup_path = self.backup_dir / backup_filename
# Build pg_dump command
pg_dump_cmd = [
"pg_dump",
"--verbose",
"--no-password",
"--format=custom",
"--compress=9",
"--file", str(backup_path),
]
# Add connection parameters
if self.settings.database_url:
pg_dump_cmd.append(self.settings.database_url)
else:
pg_dump_cmd.extend([
"--host", self.settings.db_host,
"--port", str(self.settings.db_port),
"--username", self.settings.db_user,
"--dbname", self.settings.db_name,
])
# Set environment variables
env = os.environ.copy()
if self.settings.db_password:
env["PGPASSWORD"] = self.settings.db_password
# Execute pg_dump
process = await asyncio.create_subprocess_exec(
*pg_dump_cmd,
env=env,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
error_msg = stderr.decode() if stderr else "Unknown pg_dump error"
raise Exception(f"pg_dump failed: {error_msg}")
backup_size_mb = self._get_file_size_mb(backup_path)
# Clean up old backups
self._cleanup_old_backups("database_full_*.sql.gz", self.retention_days)
return {
"backup_file": backup_filename,
"backup_path": str(backup_path),
"backup_size_mb": backup_size_mb,
"retention_days": self.retention_days,
}
class ConfigurationBackup(BackupTask):
"""Backup configuration files and settings."""
def __init__(self, settings: Settings):
super().__init__("configuration_backup", settings)
self.retention_days = settings.config_backup_retention_days
self.config_files = [
"src/config/settings.py",
".env",
"pyproject.toml",
"docker-compose.yml",
"Dockerfile",
]
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute configuration backup."""
backup_filename = self._get_backup_filename("configuration", ".tar.gz")
backup_path = self.backup_dir / backup_filename
# Create temporary directory for config files
temp_dir = self.backup_dir / "temp_config"
temp_dir.mkdir(exist_ok=True)
try:
copied_files = []
# Copy configuration files
for config_file in self.config_files:
source_path = Path(config_file)
if source_path.exists():
dest_path = temp_dir / source_path.name
shutil.copy2(source_path, dest_path)
copied_files.append(config_file)
# Create settings dump
settings_dump = {
"backup_timestamp": datetime.utcnow().isoformat(),
"environment": self.settings.environment,
"debug": self.settings.debug,
"version": self.settings.version,
"database_settings": {
"db_host": self.settings.db_host,
"db_port": self.settings.db_port,
"db_name": self.settings.db_name,
"db_pool_size": self.settings.db_pool_size,
},
"redis_settings": {
"redis_enabled": self.settings.redis_enabled,
"redis_host": self.settings.redis_host,
"redis_port": self.settings.redis_port,
"redis_db": self.settings.redis_db,
},
"monitoring_settings": {
"monitoring_interval_seconds": self.settings.monitoring_interval_seconds,
"cleanup_interval_seconds": self.settings.cleanup_interval_seconds,
},
}
settings_file = temp_dir / "settings_dump.json"
with open(settings_file, 'w') as f:
json.dump(settings_dump, f, indent=2)
# Create tar.gz archive
tar_cmd = [
"tar", "-czf", str(backup_path),
"-C", str(temp_dir),
"."
]
process = await asyncio.create_subprocess_exec(
*tar_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
error_msg = stderr.decode() if stderr else "Unknown tar error"
raise Exception(f"tar failed: {error_msg}")
backup_size_mb = self._get_file_size_mb(backup_path)
# Clean up old backups
self._cleanup_old_backups("configuration_*.tar.gz", self.retention_days)
return {
"backup_file": backup_filename,
"backup_path": str(backup_path),
"backup_size_mb": backup_size_mb,
"copied_files": copied_files,
"retention_days": self.retention_days,
}
finally:
# Clean up temporary directory
if temp_dir.exists():
shutil.rmtree(temp_dir)
class DataExportBackup(BackupTask):
"""Export specific data tables to JSON format."""
def __init__(self, settings: Settings):
super().__init__("data_export_backup", settings)
self.retention_days = settings.data_export_retention_days
self.export_batch_size = 1000
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute data export backup."""
backup_filename = self._get_backup_filename("data_export", ".json.gz")
backup_path = self.backup_dir / backup_filename
export_data = {
"backup_timestamp": datetime.utcnow().isoformat(),
"export_version": "1.0",
"tables": {}
}
# Export devices
devices_data = await self._export_table_data(session, Device, "devices")
export_data["tables"]["devices"] = devices_data
# Export sessions
sessions_data = await self._export_table_data(session, Session, "sessions")
export_data["tables"]["sessions"] = sessions_data
# Export recent CSI data (last 7 days)
recent_date = datetime.utcnow() - timedelta(days=7)
csi_query = select(CSIData).where(CSIData.created_at >= recent_date)
csi_data = await self._export_query_data(session, csi_query, "csi_data")
export_data["tables"]["csi_data_recent"] = csi_data
# Export recent pose detections (last 7 days)
pose_query = select(PoseDetection).where(PoseDetection.created_at >= recent_date)
pose_data = await self._export_query_data(session, pose_query, "pose_detections")
export_data["tables"]["pose_detections_recent"] = pose_data
# Write compressed JSON
with gzip.open(backup_path, 'wt', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, default=str)
backup_size_mb = self._get_file_size_mb(backup_path)
# Clean up old backups
self._cleanup_old_backups("data_export_*.json.gz", self.retention_days)
total_records = sum(
table_data["record_count"]
for table_data in export_data["tables"].values()
)
return {
"backup_file": backup_filename,
"backup_path": str(backup_path),
"backup_size_mb": backup_size_mb,
"total_records": total_records,
"tables_exported": list(export_data["tables"].keys()),
"retention_days": self.retention_days,
}
async def _export_table_data(self, session: AsyncSession, model_class, table_name: str) -> Dict[str, Any]:
"""Export all data from a table."""
query = select(model_class)
return await self._export_query_data(session, query, table_name)
async def _export_query_data(self, session: AsyncSession, query, table_name: str) -> Dict[str, Any]:
"""Export data from a query."""
result = await session.execute(query)
records = result.scalars().all()
exported_records = []
for record in records:
if hasattr(record, 'to_dict'):
exported_records.append(record.to_dict())
else:
# Fallback for records without to_dict method
record_dict = {}
for column in record.__table__.columns:
value = getattr(record, column.name)
if isinstance(value, datetime):
value = value.isoformat()
record_dict[column.name] = value
exported_records.append(record_dict)
return {
"table_name": table_name,
"record_count": len(exported_records),
"export_timestamp": datetime.utcnow().isoformat(),
"records": exported_records,
}
class LogsBackup(BackupTask):
"""Backup application logs."""
def __init__(self, settings: Settings):
super().__init__("logs_backup", settings)
self.retention_days = settings.logs_backup_retention_days
self.logs_directory = Path(settings.log_directory)
async def execute_backup(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute logs backup."""
if not self.logs_directory.exists():
return {
"backup_file": None,
"backup_path": None,
"backup_size_mb": 0,
"message": "Logs directory does not exist",
}
backup_filename = self._get_backup_filename("logs", ".tar.gz")
backup_path = self.backup_dir / backup_filename
# Create tar.gz archive of logs
tar_cmd = [
"tar", "-czf", str(backup_path),
"-C", str(self.logs_directory.parent),
self.logs_directory.name
]
process = await asyncio.create_subprocess_exec(
*tar_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
error_msg = stderr.decode() if stderr else "Unknown tar error"
raise Exception(f"tar failed: {error_msg}")
backup_size_mb = self._get_file_size_mb(backup_path)
# Count log files
log_files = list(self.logs_directory.glob("*.log*"))
# Clean up old backups
self._cleanup_old_backups("logs_*.tar.gz", self.retention_days)
return {
"backup_file": backup_filename,
"backup_path": str(backup_path),
"backup_size_mb": backup_size_mb,
"log_files_count": len(log_files),
"retention_days": self.retention_days,
}
class BackupManager:
"""Manager for all backup tasks."""
def __init__(self, settings: Settings):
self.settings = settings
self.db_manager = get_database_manager(settings)
self.tasks = self._initialize_tasks()
self.running = False
self.last_run = None
self.run_count = 0
self.total_backup_size = 0
def _initialize_tasks(self) -> List[BackupTask]:
"""Initialize all backup tasks."""
tasks = [
DatabaseBackup(self.settings),
ConfigurationBackup(self.settings),
DataExportBackup(self.settings),
LogsBackup(self.settings),
]
# Filter enabled tasks
enabled_tasks = [task for task in tasks if task.enabled]
logger.info(f"Initialized {len(enabled_tasks)} backup tasks")
return enabled_tasks
async def run_all_tasks(self) -> Dict[str, Any]:
"""Run all backup tasks."""
if self.running:
return {"status": "already_running", "message": "Backup already in progress"}
self.running = True
start_time = datetime.utcnow()
try:
logger.info("Starting backup tasks")
results = []
total_backup_size = 0
async with self.db_manager.get_async_session() as session:
for task in self.tasks:
if not task.enabled:
continue
result = await task.run(session)
results.append(result)
total_backup_size += result.get("backup_size_mb", 0)
self.last_run = start_time
self.run_count += 1
self.total_backup_size += total_backup_size
duration = (datetime.utcnow() - start_time).total_seconds()
logger.info(
f"Backup tasks completed: created {total_backup_size:.2f} MB "
f"in {duration:.2f} seconds"
)
return {
"status": "completed",
"start_time": start_time.isoformat(),
"duration_seconds": duration,
"total_backup_size_mb": total_backup_size,
"task_results": results,
}
except Exception as e:
logger.error(f"Backup tasks failed: {e}", exc_info=True)
return {
"status": "error",
"start_time": start_time.isoformat(),
"duration_seconds": (datetime.utcnow() - start_time).total_seconds(),
"error": str(e),
"total_backup_size_mb": 0,
}
finally:
self.running = False
async def run_task(self, task_name: str) -> Dict[str, Any]:
"""Run a specific backup task."""
task = next((t for t in self.tasks if t.name == task_name), None)
if not task:
return {
"status": "error",
"error": f"Task '{task_name}' not found",
"available_tasks": [t.name for t in self.tasks]
}
if not task.enabled:
return {
"status": "error",
"error": f"Task '{task_name}' is disabled"
}
async with self.db_manager.get_async_session() as session:
return await task.run(session)
def get_stats(self) -> Dict[str, Any]:
"""Get backup manager statistics."""
return {
"manager": {
"running": self.running,
"last_run": self.last_run.isoformat() if self.last_run else None,
"run_count": self.run_count,
"total_backup_size_mb": self.total_backup_size,
},
"tasks": [task.get_stats() for task in self.tasks],
}
def list_backups(self) -> Dict[str, List[Dict[str, Any]]]:
"""List all backup files."""
backup_files = {}
for task in self.tasks:
task_backups = []
# Define patterns for each task type
patterns = {
"database_backup": "database_full_*.sql.gz",
"configuration_backup": "configuration_*.tar.gz",
"data_export_backup": "data_export_*.json.gz",
"logs_backup": "logs_*.tar.gz",
}
pattern = patterns.get(task.name, f"{task.name}_*")
for backup_file in task.backup_dir.glob(pattern):
stat = backup_file.stat()
task_backups.append({
"filename": backup_file.name,
"path": str(backup_file),
"size_mb": stat.st_size / (1024 * 1024),
"created_at": datetime.fromtimestamp(stat.st_mtime).isoformat(),
})
# Sort by creation time (newest first)
task_backups.sort(key=lambda x: x["created_at"], reverse=True)
backup_files[task.name] = task_backups
return backup_files
# Global backup manager instance
_backup_manager: Optional[BackupManager] = None
def get_backup_manager(settings: Settings) -> BackupManager:
"""Get backup manager instance."""
global _backup_manager
if _backup_manager is None:
_backup_manager = BackupManager(settings)
return _backup_manager
async def run_periodic_backup(settings: Settings):
"""Run periodic backup tasks."""
backup_manager = get_backup_manager(settings)
while True:
try:
await backup_manager.run_all_tasks()
# Wait for next backup interval
await asyncio.sleep(settings.backup_interval_seconds)
except asyncio.CancelledError:
logger.info("Periodic backup cancelled")
break
except Exception as e:
logger.error(f"Periodic backup error: {e}", exc_info=True)
# Wait before retrying
await asyncio.sleep(300) # 5 minutes

598
v1/src/tasks/cleanup.py Normal file
View File

@@ -0,0 +1,598 @@
"""
Periodic cleanup tasks for WiFi-DensePose API
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager
from sqlalchemy import delete, select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.database.connection import get_database_manager
from src.database.models import (
CSIData, PoseDetection, SystemMetric, AuditLog, Session, Device
)
from src.logger import get_logger
logger = get_logger(__name__)
class CleanupTask:
"""Base class for cleanup tasks."""
def __init__(self, name: str, settings: Settings):
self.name = name
self.settings = settings
self.enabled = True
self.last_run = None
self.run_count = 0
self.error_count = 0
self.total_cleaned = 0
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute the cleanup task."""
raise NotImplementedError
async def run(self, session: AsyncSession) -> Dict[str, Any]:
"""Run the cleanup task with error handling."""
start_time = datetime.utcnow()
try:
logger.info(f"Starting cleanup task: {self.name}")
result = await self.execute(session)
self.last_run = start_time
self.run_count += 1
if result.get("cleaned_count", 0) > 0:
self.total_cleaned += result["cleaned_count"]
logger.info(
f"Cleanup task {self.name} completed: "
f"cleaned {result['cleaned_count']} items"
)
else:
logger.debug(f"Cleanup task {self.name} completed: no items to clean")
return {
"task": self.name,
"status": "success",
"start_time": start_time.isoformat(),
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
**result
}
except Exception as e:
self.error_count += 1
logger.error(f"Cleanup task {self.name} failed: {e}", exc_info=True)
return {
"task": self.name,
"status": "error",
"start_time": start_time.isoformat(),
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
"error": str(e),
"cleaned_count": 0
}
def get_stats(self) -> Dict[str, Any]:
"""Get task statistics."""
return {
"name": self.name,
"enabled": self.enabled,
"last_run": self.last_run.isoformat() if self.last_run else None,
"run_count": self.run_count,
"error_count": self.error_count,
"total_cleaned": self.total_cleaned,
}
class OldCSIDataCleanup(CleanupTask):
"""Cleanup old CSI data records."""
def __init__(self, settings: Settings):
super().__init__("old_csi_data_cleanup", settings)
self.retention_days = settings.csi_data_retention_days
self.batch_size = settings.cleanup_batch_size
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute CSI data cleanup."""
if self.retention_days <= 0:
return {"cleaned_count": 0, "message": "CSI data retention disabled"}
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
# Count records to be deleted
count_query = select(func.count(CSIData.id)).where(
CSIData.created_at < cutoff_date
)
total_count = await session.scalar(count_query)
if total_count == 0:
return {"cleaned_count": 0, "message": "No old CSI data to clean"}
# Delete in batches
cleaned_count = 0
while cleaned_count < total_count:
# Get batch of IDs to delete
id_query = select(CSIData.id).where(
CSIData.created_at < cutoff_date
).limit(self.batch_size)
result = await session.execute(id_query)
ids_to_delete = [row[0] for row in result.fetchall()]
if not ids_to_delete:
break
# Delete batch
delete_query = delete(CSIData).where(CSIData.id.in_(ids_to_delete))
await session.execute(delete_query)
await session.commit()
batch_size = len(ids_to_delete)
cleaned_count += batch_size
logger.debug(f"Deleted {batch_size} CSI data records (total: {cleaned_count})")
# Small delay to avoid overwhelming the database
await asyncio.sleep(0.1)
return {
"cleaned_count": cleaned_count,
"retention_days": self.retention_days,
"cutoff_date": cutoff_date.isoformat()
}
class OldPoseDetectionCleanup(CleanupTask):
"""Cleanup old pose detection records."""
def __init__(self, settings: Settings):
super().__init__("old_pose_detection_cleanup", settings)
self.retention_days = settings.pose_detection_retention_days
self.batch_size = settings.cleanup_batch_size
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute pose detection cleanup."""
if self.retention_days <= 0:
return {"cleaned_count": 0, "message": "Pose detection retention disabled"}
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
# Count records to be deleted
count_query = select(func.count(PoseDetection.id)).where(
PoseDetection.created_at < cutoff_date
)
total_count = await session.scalar(count_query)
if total_count == 0:
return {"cleaned_count": 0, "message": "No old pose detections to clean"}
# Delete in batches
cleaned_count = 0
while cleaned_count < total_count:
# Get batch of IDs to delete
id_query = select(PoseDetection.id).where(
PoseDetection.created_at < cutoff_date
).limit(self.batch_size)
result = await session.execute(id_query)
ids_to_delete = [row[0] for row in result.fetchall()]
if not ids_to_delete:
break
# Delete batch
delete_query = delete(PoseDetection).where(PoseDetection.id.in_(ids_to_delete))
await session.execute(delete_query)
await session.commit()
batch_size = len(ids_to_delete)
cleaned_count += batch_size
logger.debug(f"Deleted {batch_size} pose detection records (total: {cleaned_count})")
# Small delay to avoid overwhelming the database
await asyncio.sleep(0.1)
return {
"cleaned_count": cleaned_count,
"retention_days": self.retention_days,
"cutoff_date": cutoff_date.isoformat()
}
class OldMetricsCleanup(CleanupTask):
"""Cleanup old system metrics."""
def __init__(self, settings: Settings):
super().__init__("old_metrics_cleanup", settings)
self.retention_days = settings.metrics_retention_days
self.batch_size = settings.cleanup_batch_size
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute metrics cleanup."""
if self.retention_days <= 0:
return {"cleaned_count": 0, "message": "Metrics retention disabled"}
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
# Count records to be deleted
count_query = select(func.count(SystemMetric.id)).where(
SystemMetric.created_at < cutoff_date
)
total_count = await session.scalar(count_query)
if total_count == 0:
return {"cleaned_count": 0, "message": "No old metrics to clean"}
# Delete in batches
cleaned_count = 0
while cleaned_count < total_count:
# Get batch of IDs to delete
id_query = select(SystemMetric.id).where(
SystemMetric.created_at < cutoff_date
).limit(self.batch_size)
result = await session.execute(id_query)
ids_to_delete = [row[0] for row in result.fetchall()]
if not ids_to_delete:
break
# Delete batch
delete_query = delete(SystemMetric).where(SystemMetric.id.in_(ids_to_delete))
await session.execute(delete_query)
await session.commit()
batch_size = len(ids_to_delete)
cleaned_count += batch_size
logger.debug(f"Deleted {batch_size} metric records (total: {cleaned_count})")
# Small delay to avoid overwhelming the database
await asyncio.sleep(0.1)
return {
"cleaned_count": cleaned_count,
"retention_days": self.retention_days,
"cutoff_date": cutoff_date.isoformat()
}
class OldAuditLogCleanup(CleanupTask):
"""Cleanup old audit logs."""
def __init__(self, settings: Settings):
super().__init__("old_audit_log_cleanup", settings)
self.retention_days = settings.audit_log_retention_days
self.batch_size = settings.cleanup_batch_size
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute audit log cleanup."""
if self.retention_days <= 0:
return {"cleaned_count": 0, "message": "Audit log retention disabled"}
cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
# Count records to be deleted
count_query = select(func.count(AuditLog.id)).where(
AuditLog.created_at < cutoff_date
)
total_count = await session.scalar(count_query)
if total_count == 0:
return {"cleaned_count": 0, "message": "No old audit logs to clean"}
# Delete in batches
cleaned_count = 0
while cleaned_count < total_count:
# Get batch of IDs to delete
id_query = select(AuditLog.id).where(
AuditLog.created_at < cutoff_date
).limit(self.batch_size)
result = await session.execute(id_query)
ids_to_delete = [row[0] for row in result.fetchall()]
if not ids_to_delete:
break
# Delete batch
delete_query = delete(AuditLog).where(AuditLog.id.in_(ids_to_delete))
await session.execute(delete_query)
await session.commit()
batch_size = len(ids_to_delete)
cleaned_count += batch_size
logger.debug(f"Deleted {batch_size} audit log records (total: {cleaned_count})")
# Small delay to avoid overwhelming the database
await asyncio.sleep(0.1)
return {
"cleaned_count": cleaned_count,
"retention_days": self.retention_days,
"cutoff_date": cutoff_date.isoformat()
}
class OrphanedSessionCleanup(CleanupTask):
"""Cleanup orphaned sessions (sessions without associated data)."""
def __init__(self, settings: Settings):
super().__init__("orphaned_session_cleanup", settings)
self.orphan_threshold_days = settings.orphaned_session_threshold_days
self.batch_size = settings.cleanup_batch_size
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute orphaned session cleanup."""
if self.orphan_threshold_days <= 0:
return {"cleaned_count": 0, "message": "Orphaned session cleanup disabled"}
cutoff_date = datetime.utcnow() - timedelta(days=self.orphan_threshold_days)
# Find sessions that are old and have no associated CSI data or pose detections
orphaned_sessions_query = select(Session.id).where(
and_(
Session.created_at < cutoff_date,
Session.status.in_(["completed", "failed", "cancelled"]),
~Session.id.in_(select(CSIData.session_id).where(CSIData.session_id.isnot(None))),
~Session.id.in_(select(PoseDetection.session_id))
)
)
result = await session.execute(orphaned_sessions_query)
orphaned_ids = [row[0] for row in result.fetchall()]
if not orphaned_ids:
return {"cleaned_count": 0, "message": "No orphaned sessions to clean"}
# Delete orphaned sessions
delete_query = delete(Session).where(Session.id.in_(orphaned_ids))
await session.execute(delete_query)
await session.commit()
cleaned_count = len(orphaned_ids)
return {
"cleaned_count": cleaned_count,
"orphan_threshold_days": self.orphan_threshold_days,
"cutoff_date": cutoff_date.isoformat()
}
class InvalidDataCleanup(CleanupTask):
"""Cleanup invalid or corrupted data records."""
def __init__(self, settings: Settings):
super().__init__("invalid_data_cleanup", settings)
self.batch_size = settings.cleanup_batch_size
async def execute(self, session: AsyncSession) -> Dict[str, Any]:
"""Execute invalid data cleanup."""
total_cleaned = 0
# Clean invalid CSI data
invalid_csi_query = select(CSIData.id).where(
or_(
CSIData.is_valid == False,
CSIData.amplitude == None,
CSIData.phase == None,
CSIData.frequency <= 0,
CSIData.bandwidth <= 0,
CSIData.num_subcarriers <= 0
)
)
result = await session.execute(invalid_csi_query)
invalid_csi_ids = [row[0] for row in result.fetchall()]
if invalid_csi_ids:
delete_query = delete(CSIData).where(CSIData.id.in_(invalid_csi_ids))
await session.execute(delete_query)
total_cleaned += len(invalid_csi_ids)
logger.debug(f"Deleted {len(invalid_csi_ids)} invalid CSI data records")
# Clean invalid pose detections
invalid_pose_query = select(PoseDetection.id).where(
or_(
PoseDetection.is_valid == False,
PoseDetection.person_count < 0,
and_(
PoseDetection.detection_confidence.isnot(None),
or_(
PoseDetection.detection_confidence < 0,
PoseDetection.detection_confidence > 1
)
)
)
)
result = await session.execute(invalid_pose_query)
invalid_pose_ids = [row[0] for row in result.fetchall()]
if invalid_pose_ids:
delete_query = delete(PoseDetection).where(PoseDetection.id.in_(invalid_pose_ids))
await session.execute(delete_query)
total_cleaned += len(invalid_pose_ids)
logger.debug(f"Deleted {len(invalid_pose_ids)} invalid pose detection records")
await session.commit()
return {
"cleaned_count": total_cleaned,
"invalid_csi_count": len(invalid_csi_ids) if invalid_csi_ids else 0,
"invalid_pose_count": len(invalid_pose_ids) if invalid_pose_ids else 0,
}
class CleanupManager:
"""Manager for all cleanup tasks."""
def __init__(self, settings: Settings):
self.settings = settings
self.db_manager = get_database_manager(settings)
self.tasks = self._initialize_tasks()
self.running = False
self.last_run = None
self.run_count = 0
self.total_cleaned = 0
def _initialize_tasks(self) -> List[CleanupTask]:
"""Initialize all cleanup tasks."""
tasks = [
OldCSIDataCleanup(self.settings),
OldPoseDetectionCleanup(self.settings),
OldMetricsCleanup(self.settings),
OldAuditLogCleanup(self.settings),
OrphanedSessionCleanup(self.settings),
InvalidDataCleanup(self.settings),
]
# Filter enabled tasks
enabled_tasks = [task for task in tasks if task.enabled]
logger.info(f"Initialized {len(enabled_tasks)} cleanup tasks")
return enabled_tasks
async def run_all_tasks(self) -> Dict[str, Any]:
"""Run all cleanup tasks."""
if self.running:
return {"status": "already_running", "message": "Cleanup already in progress"}
self.running = True
start_time = datetime.utcnow()
try:
logger.info("Starting cleanup tasks")
results = []
total_cleaned = 0
async with self.db_manager.get_async_session() as session:
for task in self.tasks:
if not task.enabled:
continue
result = await task.run(session)
results.append(result)
total_cleaned += result.get("cleaned_count", 0)
self.last_run = start_time
self.run_count += 1
self.total_cleaned += total_cleaned
duration = (datetime.utcnow() - start_time).total_seconds()
logger.info(
f"Cleanup tasks completed: cleaned {total_cleaned} items "
f"in {duration:.2f} seconds"
)
return {
"status": "completed",
"start_time": start_time.isoformat(),
"duration_seconds": duration,
"total_cleaned": total_cleaned,
"task_results": results,
}
except Exception as e:
logger.error(f"Cleanup tasks failed: {e}", exc_info=True)
return {
"status": "error",
"start_time": start_time.isoformat(),
"duration_seconds": (datetime.utcnow() - start_time).total_seconds(),
"error": str(e),
"total_cleaned": 0,
}
finally:
self.running = False
async def run_task(self, task_name: str) -> Dict[str, Any]:
"""Run a specific cleanup task."""
task = next((t for t in self.tasks if t.name == task_name), None)
if not task:
return {
"status": "error",
"error": f"Task '{task_name}' not found",
"available_tasks": [t.name for t in self.tasks]
}
if not task.enabled:
return {
"status": "error",
"error": f"Task '{task_name}' is disabled"
}
async with self.db_manager.get_async_session() as session:
return await task.run(session)
def get_stats(self) -> Dict[str, Any]:
"""Get cleanup manager statistics."""
return {
"manager": {
"running": self.running,
"last_run": self.last_run.isoformat() if self.last_run else None,
"run_count": self.run_count,
"total_cleaned": self.total_cleaned,
},
"tasks": [task.get_stats() for task in self.tasks],
}
def enable_task(self, task_name: str) -> bool:
"""Enable a specific task."""
task = next((t for t in self.tasks if t.name == task_name), None)
if task:
task.enabled = True
return True
return False
def disable_task(self, task_name: str) -> bool:
"""Disable a specific task."""
task = next((t for t in self.tasks if t.name == task_name), None)
if task:
task.enabled = False
return True
return False
# Global cleanup manager instance
_cleanup_manager: Optional[CleanupManager] = None
def get_cleanup_manager(settings: Settings) -> CleanupManager:
"""Get cleanup manager instance."""
global _cleanup_manager
if _cleanup_manager is None:
_cleanup_manager = CleanupManager(settings)
return _cleanup_manager
async def run_periodic_cleanup(settings: Settings):
"""Run periodic cleanup tasks."""
cleanup_manager = get_cleanup_manager(settings)
while True:
try:
await cleanup_manager.run_all_tasks()
# Wait for next cleanup interval
await asyncio.sleep(settings.cleanup_interval_seconds)
except asyncio.CancelledError:
logger.info("Periodic cleanup cancelled")
break
except Exception as e:
logger.error(f"Periodic cleanup error: {e}", exc_info=True)
# Wait before retrying
await asyncio.sleep(60)

773
v1/src/tasks/monitoring.py Normal file
View File

@@ -0,0 +1,773 @@
"""
Monitoring tasks for WiFi-DensePose API
"""
import asyncio
import logging
import psutil
import time
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager
from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.database.connection import get_database_manager
from src.database.models import SystemMetric, Device, Session, CSIData, PoseDetection
from src.logger import get_logger
logger = get_logger(__name__)
class MonitoringTask:
"""Base class for monitoring tasks."""
def __init__(self, name: str, settings: Settings):
self.name = name
self.settings = settings
self.enabled = True
self.last_run = None
self.run_count = 0
self.error_count = 0
self.interval_seconds = 60 # Default interval
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
"""Collect metrics for this task."""
raise NotImplementedError
async def run(self, session: AsyncSession) -> Dict[str, Any]:
"""Run the monitoring task with error handling."""
start_time = datetime.utcnow()
try:
logger.debug(f"Starting monitoring task: {self.name}")
metrics = await self.collect_metrics(session)
# Store metrics in database
for metric_data in metrics:
metric = SystemMetric(
metric_name=metric_data["name"],
metric_type=metric_data["type"],
value=metric_data["value"],
unit=metric_data.get("unit"),
labels=metric_data.get("labels"),
tags=metric_data.get("tags"),
source=metric_data.get("source", self.name),
component=metric_data.get("component"),
description=metric_data.get("description"),
meta_data=metric_data.get("metadata"),
)
session.add(metric)
await session.commit()
self.last_run = start_time
self.run_count += 1
logger.debug(f"Monitoring task {self.name} completed: collected {len(metrics)} metrics")
return {
"task": self.name,
"status": "success",
"start_time": start_time.isoformat(),
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
"metrics_collected": len(metrics),
}
except Exception as e:
self.error_count += 1
logger.error(f"Monitoring task {self.name} failed: {e}", exc_info=True)
return {
"task": self.name,
"status": "error",
"start_time": start_time.isoformat(),
"duration_ms": (datetime.utcnow() - start_time).total_seconds() * 1000,
"error": str(e),
"metrics_collected": 0,
}
def get_stats(self) -> Dict[str, Any]:
"""Get task statistics."""
return {
"name": self.name,
"enabled": self.enabled,
"interval_seconds": self.interval_seconds,
"last_run": self.last_run.isoformat() if self.last_run else None,
"run_count": self.run_count,
"error_count": self.error_count,
}
class SystemResourceMonitoring(MonitoringTask):
"""Monitor system resources (CPU, memory, disk, network)."""
def __init__(self, settings: Settings):
super().__init__("system_resources", settings)
self.interval_seconds = settings.system_monitoring_interval
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
"""Collect system resource metrics."""
metrics = []
timestamp = datetime.utcnow()
# CPU metrics
cpu_percent = psutil.cpu_percent(interval=1)
cpu_count = psutil.cpu_count()
cpu_freq = psutil.cpu_freq()
metrics.extend([
{
"name": "system_cpu_usage_percent",
"type": "gauge",
"value": cpu_percent,
"unit": "percent",
"component": "cpu",
"description": "CPU usage percentage",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_cpu_count",
"type": "gauge",
"value": cpu_count,
"unit": "count",
"component": "cpu",
"description": "Number of CPU cores",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
if cpu_freq:
metrics.append({
"name": "system_cpu_frequency_mhz",
"type": "gauge",
"value": cpu_freq.current,
"unit": "mhz",
"component": "cpu",
"description": "Current CPU frequency",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Memory metrics
memory = psutil.virtual_memory()
swap = psutil.swap_memory()
metrics.extend([
{
"name": "system_memory_total_bytes",
"type": "gauge",
"value": memory.total,
"unit": "bytes",
"component": "memory",
"description": "Total system memory",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_memory_used_bytes",
"type": "gauge",
"value": memory.used,
"unit": "bytes",
"component": "memory",
"description": "Used system memory",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_memory_available_bytes",
"type": "gauge",
"value": memory.available,
"unit": "bytes",
"component": "memory",
"description": "Available system memory",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_memory_usage_percent",
"type": "gauge",
"value": memory.percent,
"unit": "percent",
"component": "memory",
"description": "Memory usage percentage",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_swap_total_bytes",
"type": "gauge",
"value": swap.total,
"unit": "bytes",
"component": "memory",
"description": "Total swap memory",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_swap_used_bytes",
"type": "gauge",
"value": swap.used,
"unit": "bytes",
"component": "memory",
"description": "Used swap memory",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
# Disk metrics
disk_usage = psutil.disk_usage('/')
disk_io = psutil.disk_io_counters()
metrics.extend([
{
"name": "system_disk_total_bytes",
"type": "gauge",
"value": disk_usage.total,
"unit": "bytes",
"component": "disk",
"description": "Total disk space",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_disk_used_bytes",
"type": "gauge",
"value": disk_usage.used,
"unit": "bytes",
"component": "disk",
"description": "Used disk space",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_disk_free_bytes",
"type": "gauge",
"value": disk_usage.free,
"unit": "bytes",
"component": "disk",
"description": "Free disk space",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_disk_usage_percent",
"type": "gauge",
"value": (disk_usage.used / disk_usage.total) * 100,
"unit": "percent",
"component": "disk",
"description": "Disk usage percentage",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
if disk_io:
metrics.extend([
{
"name": "system_disk_read_bytes_total",
"type": "counter",
"value": disk_io.read_bytes,
"unit": "bytes",
"component": "disk",
"description": "Total bytes read from disk",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_disk_write_bytes_total",
"type": "counter",
"value": disk_io.write_bytes,
"unit": "bytes",
"component": "disk",
"description": "Total bytes written to disk",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
# Network metrics
network_io = psutil.net_io_counters()
if network_io:
metrics.extend([
{
"name": "system_network_bytes_sent_total",
"type": "counter",
"value": network_io.bytes_sent,
"unit": "bytes",
"component": "network",
"description": "Total bytes sent over network",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_network_bytes_recv_total",
"type": "counter",
"value": network_io.bytes_recv,
"unit": "bytes",
"component": "network",
"description": "Total bytes received over network",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_network_packets_sent_total",
"type": "counter",
"value": network_io.packets_sent,
"unit": "count",
"component": "network",
"description": "Total packets sent over network",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "system_network_packets_recv_total",
"type": "counter",
"value": network_io.packets_recv,
"unit": "count",
"component": "network",
"description": "Total packets received over network",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
return metrics
class DatabaseMonitoring(MonitoringTask):
"""Monitor database performance and statistics."""
def __init__(self, settings: Settings):
super().__init__("database", settings)
self.interval_seconds = settings.database_monitoring_interval
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
"""Collect database metrics."""
metrics = []
timestamp = datetime.utcnow()
# Get database connection stats
db_manager = get_database_manager(self.settings)
connection_stats = await db_manager.get_connection_stats()
# PostgreSQL connection metrics
if "postgresql" in connection_stats:
pg_stats = connection_stats["postgresql"]
metrics.extend([
{
"name": "database_connections_total",
"type": "gauge",
"value": pg_stats.get("total_connections", 0),
"unit": "count",
"component": "postgresql",
"description": "Total database connections",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "database_connections_active",
"type": "gauge",
"value": pg_stats.get("checked_out", 0),
"unit": "count",
"component": "postgresql",
"description": "Active database connections",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "database_connections_available",
"type": "gauge",
"value": pg_stats.get("available_connections", 0),
"unit": "count",
"component": "postgresql",
"description": "Available database connections",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
# Redis connection metrics
if "redis" in connection_stats and not connection_stats["redis"].get("error"):
redis_stats = connection_stats["redis"]
metrics.extend([
{
"name": "redis_connections_active",
"type": "gauge",
"value": redis_stats.get("connected_clients", 0),
"unit": "count",
"component": "redis",
"description": "Active Redis connections",
"metadata": {"timestamp": timestamp.isoformat()}
},
{
"name": "redis_connections_blocked",
"type": "gauge",
"value": redis_stats.get("blocked_clients", 0),
"unit": "count",
"component": "redis",
"description": "Blocked Redis connections",
"metadata": {"timestamp": timestamp.isoformat()}
}
])
# Table row counts
table_counts = await self._get_table_counts(session)
for table_name, count in table_counts.items():
metrics.append({
"name": f"database_table_rows_{table_name}",
"type": "gauge",
"value": count,
"unit": "count",
"component": "postgresql",
"description": f"Number of rows in {table_name} table",
"metadata": {"timestamp": timestamp.isoformat(), "table": table_name}
})
return metrics
async def _get_table_counts(self, session: AsyncSession) -> Dict[str, int]:
"""Get row counts for all tables."""
counts = {}
# Count devices
result = await session.execute(select(func.count(Device.id)))
counts["devices"] = result.scalar() or 0
# Count sessions
result = await session.execute(select(func.count(Session.id)))
counts["sessions"] = result.scalar() or 0
# Count CSI data
result = await session.execute(select(func.count(CSIData.id)))
counts["csi_data"] = result.scalar() or 0
# Count pose detections
result = await session.execute(select(func.count(PoseDetection.id)))
counts["pose_detections"] = result.scalar() or 0
# Count system metrics
result = await session.execute(select(func.count(SystemMetric.id)))
counts["system_metrics"] = result.scalar() or 0
return counts
class ApplicationMonitoring(MonitoringTask):
"""Monitor application-specific metrics."""
def __init__(self, settings: Settings):
super().__init__("application", settings)
self.interval_seconds = settings.application_monitoring_interval
self.start_time = datetime.utcnow()
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
"""Collect application metrics."""
metrics = []
timestamp = datetime.utcnow()
# Application uptime
uptime_seconds = (timestamp - self.start_time).total_seconds()
metrics.append({
"name": "application_uptime_seconds",
"type": "gauge",
"value": uptime_seconds,
"unit": "seconds",
"component": "application",
"description": "Application uptime in seconds",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Active sessions count
active_sessions_query = select(func.count(Session.id)).where(
Session.status == "active"
)
result = await session.execute(active_sessions_query)
active_sessions = result.scalar() or 0
metrics.append({
"name": "application_active_sessions",
"type": "gauge",
"value": active_sessions,
"unit": "count",
"component": "application",
"description": "Number of active sessions",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Active devices count
active_devices_query = select(func.count(Device.id)).where(
Device.status == "active"
)
result = await session.execute(active_devices_query)
active_devices = result.scalar() or 0
metrics.append({
"name": "application_active_devices",
"type": "gauge",
"value": active_devices,
"unit": "count",
"component": "application",
"description": "Number of active devices",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Recent data processing metrics (last hour)
one_hour_ago = timestamp - timedelta(hours=1)
# Recent CSI data count
recent_csi_query = select(func.count(CSIData.id)).where(
CSIData.created_at >= one_hour_ago
)
result = await session.execute(recent_csi_query)
recent_csi_count = result.scalar() or 0
metrics.append({
"name": "application_csi_data_hourly",
"type": "gauge",
"value": recent_csi_count,
"unit": "count",
"component": "application",
"description": "CSI data records created in the last hour",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Recent pose detections count
recent_pose_query = select(func.count(PoseDetection.id)).where(
PoseDetection.created_at >= one_hour_ago
)
result = await session.execute(recent_pose_query)
recent_pose_count = result.scalar() or 0
metrics.append({
"name": "application_pose_detections_hourly",
"type": "gauge",
"value": recent_pose_count,
"unit": "count",
"component": "application",
"description": "Pose detections created in the last hour",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Processing status metrics
processing_statuses = ["pending", "processing", "completed", "failed"]
for status in processing_statuses:
status_query = select(func.count(CSIData.id)).where(
CSIData.processing_status == status
)
result = await session.execute(status_query)
status_count = result.scalar() or 0
metrics.append({
"name": f"application_csi_processing_{status}",
"type": "gauge",
"value": status_count,
"unit": "count",
"component": "application",
"description": f"CSI data records with {status} processing status",
"metadata": {"timestamp": timestamp.isoformat(), "status": status}
})
return metrics
class PerformanceMonitoring(MonitoringTask):
"""Monitor performance metrics and response times."""
def __init__(self, settings: Settings):
super().__init__("performance", settings)
self.interval_seconds = settings.performance_monitoring_interval
self.response_times = []
self.error_counts = {}
async def collect_metrics(self, session: AsyncSession) -> List[Dict[str, Any]]:
"""Collect performance metrics."""
metrics = []
timestamp = datetime.utcnow()
# Database query performance test
start_time = time.time()
test_query = select(func.count(Device.id))
await session.execute(test_query)
db_response_time = (time.time() - start_time) * 1000 # Convert to milliseconds
metrics.append({
"name": "performance_database_query_time_ms",
"type": "gauge",
"value": db_response_time,
"unit": "milliseconds",
"component": "database",
"description": "Database query response time",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Average response time (if we have data)
if self.response_times:
avg_response_time = sum(self.response_times) / len(self.response_times)
metrics.append({
"name": "performance_avg_response_time_ms",
"type": "gauge",
"value": avg_response_time,
"unit": "milliseconds",
"component": "api",
"description": "Average API response time",
"metadata": {"timestamp": timestamp.isoformat()}
})
# Clear old response times (keep only recent ones)
self.response_times = self.response_times[-100:] # Keep last 100
# Error rates
for error_type, count in self.error_counts.items():
metrics.append({
"name": f"performance_errors_{error_type}_total",
"type": "counter",
"value": count,
"unit": "count",
"component": "api",
"description": f"Total {error_type} errors",
"metadata": {"timestamp": timestamp.isoformat(), "error_type": error_type}
})
return metrics
def record_response_time(self, response_time_ms: float):
"""Record an API response time."""
self.response_times.append(response_time_ms)
def record_error(self, error_type: str):
"""Record an error occurrence."""
self.error_counts[error_type] = self.error_counts.get(error_type, 0) + 1
class MonitoringManager:
"""Manager for all monitoring tasks."""
def __init__(self, settings: Settings):
self.settings = settings
self.db_manager = get_database_manager(settings)
self.tasks = self._initialize_tasks()
self.running = False
self.last_run = None
self.run_count = 0
def _initialize_tasks(self) -> List[MonitoringTask]:
"""Initialize all monitoring tasks."""
tasks = [
SystemResourceMonitoring(self.settings),
DatabaseMonitoring(self.settings),
ApplicationMonitoring(self.settings),
PerformanceMonitoring(self.settings),
]
# Filter enabled tasks
enabled_tasks = [task for task in tasks if task.enabled]
logger.info(f"Initialized {len(enabled_tasks)} monitoring tasks")
return enabled_tasks
async def run_all_tasks(self) -> Dict[str, Any]:
"""Run all monitoring tasks."""
if self.running:
return {"status": "already_running", "message": "Monitoring already in progress"}
self.running = True
start_time = datetime.utcnow()
try:
logger.debug("Starting monitoring tasks")
results = []
total_metrics = 0
async with self.db_manager.get_async_session() as session:
for task in self.tasks:
if not task.enabled:
continue
result = await task.run(session)
results.append(result)
total_metrics += result.get("metrics_collected", 0)
self.last_run = start_time
self.run_count += 1
duration = (datetime.utcnow() - start_time).total_seconds()
logger.debug(
f"Monitoring tasks completed: collected {total_metrics} metrics "
f"in {duration:.2f} seconds"
)
return {
"status": "completed",
"start_time": start_time.isoformat(),
"duration_seconds": duration,
"total_metrics": total_metrics,
"task_results": results,
}
except Exception as e:
logger.error(f"Monitoring tasks failed: {e}", exc_info=True)
return {
"status": "error",
"start_time": start_time.isoformat(),
"duration_seconds": (datetime.utcnow() - start_time).total_seconds(),
"error": str(e),
"total_metrics": 0,
}
finally:
self.running = False
async def run_task(self, task_name: str) -> Dict[str, Any]:
"""Run a specific monitoring task."""
task = next((t for t in self.tasks if t.name == task_name), None)
if not task:
return {
"status": "error",
"error": f"Task '{task_name}' not found",
"available_tasks": [t.name for t in self.tasks]
}
if not task.enabled:
return {
"status": "error",
"error": f"Task '{task_name}' is disabled"
}
async with self.db_manager.get_async_session() as session:
return await task.run(session)
def get_stats(self) -> Dict[str, Any]:
"""Get monitoring manager statistics."""
return {
"manager": {
"running": self.running,
"last_run": self.last_run.isoformat() if self.last_run else None,
"run_count": self.run_count,
},
"tasks": [task.get_stats() for task in self.tasks],
}
def get_performance_task(self) -> Optional[PerformanceMonitoring]:
"""Get the performance monitoring task for recording metrics."""
return next((t for t in self.tasks if isinstance(t, PerformanceMonitoring)), None)
# Global monitoring manager instance
_monitoring_manager: Optional[MonitoringManager] = None
def get_monitoring_manager(settings: Settings) -> MonitoringManager:
"""Get monitoring manager instance."""
global _monitoring_manager
if _monitoring_manager is None:
_monitoring_manager = MonitoringManager(settings)
return _monitoring_manager
async def run_periodic_monitoring(settings: Settings):
"""Run periodic monitoring tasks."""
monitoring_manager = get_monitoring_manager(settings)
while True:
try:
await monitoring_manager.run_all_tasks()
# Wait for next monitoring interval
await asyncio.sleep(settings.monitoring_interval_seconds)
except asyncio.CancelledError:
logger.info("Periodic monitoring cancelled")
break
except Exception as e:
logger.error(f"Periodic monitoring error: {e}", exc_info=True)
# Wait before retrying
await asyncio.sleep(30)