feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
7
v1/src/api/__init__.py
Normal file
7
v1/src/api/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
WiFi-DensePose FastAPI application package
|
||||
"""
|
||||
|
||||
# API package - routers and dependencies are imported by app.py
|
||||
|
||||
__all__ = []
|
||||
447
v1/src/api/dependencies.py
Normal file
447
v1/src/api/dependencies.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
Dependency injection for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from functools import lru_cache
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.config.domains import get_domain_config
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.stream_service import StreamService
|
||||
from src.services.hardware_service import HardwareService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Security scheme for JWT authentication
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# Service dependencies
|
||||
@lru_cache()
def get_pose_service() -> PoseService:
    """Return the process-wide PoseService, constructing it lazily.

    @lru_cache() turns this factory into a singleton accessor: the service
    is built once from the live settings and domain configuration, then the
    same instance is returned on every later call.
    """
    return PoseService(
        settings=get_settings(),
        domain_config=get_domain_config(),
    )
|
||||
|
||||
|
||||
@lru_cache()
def get_stream_service() -> StreamService:
    """Return the process-wide StreamService, constructing it lazily.

    Cached by @lru_cache() so every caller shares one instance built from
    the current settings and domain configuration.
    """
    return StreamService(
        settings=get_settings(),
        domain_config=get_domain_config(),
    )
|
||||
|
||||
|
||||
@lru_cache()
def get_hardware_service() -> HardwareService:
    """Return the process-wide HardwareService, constructing it lazily.

    Cached by @lru_cache() so every caller shares one instance built from
    the current settings and domain configuration.
    """
    return HardwareService(
        settings=get_settings(),
        domain_config=get_domain_config(),
    )
|
||||
|
||||
|
||||
# Authentication dependencies
|
||||
async def get_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict[str, Any]]:
    """Resolve the caller's identity, or None for anonymous access.

    Resolution order:
      1. Authentication disabled -> always anonymous (None).
      2. Identity already attached to request.state by the auth middleware.
      3. No bearer credentials supplied -> anonymous.
      4. Development environment -> canned developer identity (JWT
         validation is not implemented yet).
      5. Otherwise -> 401.
    """
    settings = get_settings()

    if not settings.enable_authentication:
        return None

    # Trust an identity the middleware already established on this request.
    middleware_user = getattr(request.state, 'user', None)
    if middleware_user:
        return middleware_user

    if not credentials:
        return None

    # TODO(review): real JWT validation; development gets a mock identity.
    if settings.is_development:
        return {
            "id": "dev-user",
            "username": "developer",
            "email": "dev@example.com",
            "is_admin": True,
            "permissions": ["read", "write", "admin"]
        }

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication not implemented",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
||||
|
||||
|
||||
async def get_current_active_user(
    current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Require an authenticated, active user; otherwise raise 401/403.

    401 when no identity could be resolved; 403 when the identity exists
    but carries is_active == False (missing flag counts as active).
    """
    if not current_user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Absent "is_active" defaults to active.
    if not current_user.get("is_active", True):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Inactive user"
        )

    return current_user
|
||||
|
||||
|
||||
async def get_admin_user(
    current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
    """Require an active user that also carries the is_admin flag.

    Non-admins receive 403; the underlying dependency already handled the
    401 (unauthenticated) and inactive cases.
    """
    if current_user.get("is_admin", False):
        return current_user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin privileges required"
    )
|
||||
|
||||
|
||||
# Permission dependencies
|
||||
def require_permission(permission: str):
    """Build a dependency that admits only users holding *permission*.

    Admins bypass the check entirely; everyone else must list the
    permission in their "permissions" claim or receive a 403.
    """

    async def check_permission(
        current_user: Dict[str, Any] = Depends(get_current_active_user)
    ) -> Dict[str, Any]:
        """Reject the request unless the user may perform the action."""
        # Admins implicitly hold every permission.
        if current_user.get("is_admin", False):
            return current_user

        if permission not in current_user.get("permissions", []):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Permission '{permission}' required"
            )

        return current_user

    return check_permission
|
||||
|
||||
|
||||
# Zone access dependencies
|
||||
async def validate_zone_access(
    zone_id: str,
    current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> str:
    """Ensure the zone exists, is enabled, and is visible to the caller.

    Returns the validated zone id.  Anonymous callers are limited only by
    the zone's existence/enabled flags; authenticated non-admin users must
    additionally have the zone listed in their "zones" claim (an empty
    claim means unrestricted).
    """
    zone = get_domain_config().get_zone(zone_id)

    if not zone:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Zone '{zone_id}' not found"
        )

    if not zone.enabled:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Zone '{zone_id}' is disabled"
        )

    # Per-user restrictions apply only when an identity is present and the
    # user is not an admin (admins see every zone).
    if current_user and not current_user.get("is_admin", False):
        allowed_zones = current_user.get("zones", [])
        if allowed_zones and zone_id not in allowed_zones:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access denied to zone '{zone_id}'"
            )

    return zone_id
|
||||
|
||||
|
||||
# Router access dependencies
|
||||
async def validate_router_access(
    router_id: str,
    current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> str:
    """Ensure the router exists, is enabled, and is visible to the caller.

    Mirrors validate_zone_access: 404 for unknown ids, 403 for disabled
    routers, and a per-user "routers" claim check for authenticated
    non-admins (an empty claim means unrestricted).
    """
    router = get_domain_config().get_router(router_id)

    if not router:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Router '{router_id}' not found"
        )

    if not router.enabled:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Router '{router_id}' is disabled"
        )

    # Admins see every router; other users are filtered by their claim.
    if current_user and not current_user.get("is_admin", False):
        allowed_routers = current_user.get("routers", [])
        if allowed_routers and router_id not in allowed_routers:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access denied to router '{router_id}'"
            )

    return router_id
|
||||
|
||||
|
||||
# Service health dependencies
|
||||
async def check_service_health(
    request: Request,
    service_name: str
) -> bool:
    """Verify that the named service exists on app.state and reports healthy.

    Returns True on success.  Raises 400 for unknown service names, 503
    when the service is missing, reports a non-"healthy" status, or the
    probe itself blows up.
    """
    # Map public service names to their app.state attribute.
    state_attr = {
        "pose": "pose_service",
        "stream": "stream_service",
        "hardware": "hardware_service",
    }.get(service_name)

    try:
        if state_attr is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unknown service: {service_name}"
            )

        service = getattr(request.app.state, state_attr, None)
        if not service:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Service '{service_name}' not available"
            )

        status_info = await service.get_status()
        if status_info.get("status") != "healthy":
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Service '{service_name}' is unhealthy: {status_info.get('error', 'Unknown error')}"
            )

        return True

    except HTTPException:
        # Deliberate rejections pass through untouched.
        raise
    except Exception as e:
        # Anything else (probe crash, attribute error) becomes a 503.
        logger.error(f"Error checking service health for {service_name}: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Service '{service_name}' health check failed"
        )
|
||||
|
||||
|
||||
# Rate limiting dependencies
|
||||
async def check_rate_limit(
    request: Request,
    current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> bool:
    """Placeholder rate-limit hook; always admits the request today.

    Actual enforcement happens in RateLimitMiddleware before any
    dependency runs.  This hook exists so routes can declare the
    dependency now and gain per-user checks later.
    """
    settings = get_settings()

    if settings.enable_rate_limiting:
        # Middleware already enforced the limits for this request.
        return True

    # Rate limiting disabled entirely — nothing to check.
    return True
|
||||
|
||||
|
||||
# Configuration dependencies
|
||||
def get_zone_config(zone_id: str = Depends(validate_zone_access)):
    """Fetch the configuration object of an already-validated zone."""
    return get_domain_config().get_zone(zone_id)
|
||||
|
||||
|
||||
def get_router_config(router_id: str = Depends(validate_router_access)):
    """Fetch the configuration object of an already-validated router."""
    return get_domain_config().get_router(router_id)
|
||||
|
||||
|
||||
# Pagination dependencies
|
||||
class PaginationParams:
    """Validated page/size query parameters with derived offset/limit.

    Attributes:
        page: 1-based page number.
        size: items per page.
        offset: 0-based index of the first item on the page.
        limit: alias of size, convenient for LIMIT/OFFSET queries.
    """

    def __init__(
        self,
        page: int = 1,
        size: int = 20,
        max_size: int = 100
    ):
        # Each rule is checked in order; the first violation wins.
        for violated, detail in (
            (page < 1, "Page must be >= 1"),
            (size < 1, "Size must be >= 1"),
            (size > max_size, f"Size must be <= {max_size}"),
        ):
            if violated:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=detail
                )

        self.page = page
        self.size = size
        self.offset = (page - 1) * size
        self.limit = size
|
||||
|
||||
|
||||
def get_pagination_params(
    page: int = 1,
    size: int = 20
) -> PaginationParams:
    """Dependency wrapper turning query params into PaginationParams."""
    return PaginationParams(page=page, size=size)
|
||||
|
||||
|
||||
# Query filter dependencies
|
||||
class QueryFilters:
    """Optional time-range / confidence / activity filters for queries.

    All fields default to None (filter not applied).  min_confidence, when
    supplied, must be a probability in [0.0, 1.0] or a 400 is raised.
    """

    def __init__(
        self,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        min_confidence: Optional[float] = None,
        activity: Optional[str] = None
    ):
        self.start_time = start_time
        self.end_time = end_time
        self.min_confidence = min_confidence
        self.activity = activity

        # Confidence, when given, must lie in the closed unit interval.
        if min_confidence is not None and not 0.0 <= min_confidence <= 1.0:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="min_confidence must be between 0.0 and 1.0"
            )
|
||||
|
||||
|
||||
def get_query_filters(
    start_time: Optional[str] = None,
    end_time: Optional[str] = None,
    min_confidence: Optional[float] = None,
    activity: Optional[str] = None
) -> QueryFilters:
    """Dependency wrapper turning query params into QueryFilters."""
    return QueryFilters(
        start_time=start_time,
        end_time=end_time,
        min_confidence=min_confidence,
        activity=activity,
    )
|
||||
|
||||
|
||||
# WebSocket dependencies
|
||||
async def get_websocket_user(
    websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Resolve a WebSocket client's identity from its token, if any.

    Token validation is not implemented yet: when authentication is
    disabled everyone is anonymous; development environments get a canned
    read-only identity; production currently returns anonymous as well.
    """
    settings = get_settings()

    if not settings.enable_authentication:
        return None

    if settings.is_development:
        return {
            "id": "ws-user",
            "username": "websocket_user",
            "is_admin": False,
            "permissions": ["read"]
        }

    # TODO(review): validate websocket_token once real auth lands.
    return None
|
||||
|
||||
|
||||
async def get_current_user_ws(
    websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Alias of get_websocket_user for WebSocket route dependencies."""
    return await get_websocket_user(websocket_token)
|
||||
|
||||
|
||||
# Authentication requirement dependencies
|
||||
async def require_auth(
    current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
    """Hard authentication gate: forwards the already-validated user."""
    return current_user
|
||||
|
||||
|
||||
# Development dependencies
|
||||
async def development_only():
    """Hide an endpoint outside development by answering 404 (not 403,
    so production callers cannot even learn the endpoint exists)."""
    if not get_settings().is_development:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Endpoint not available in production"
        )

    return True
|
||||
421
v1/src/api/main.py
Normal file
421
v1/src/api/main.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
FastAPI application for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.config.domains import get_domain_config
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.api.websocket.pose_stream import PoseStreamHandler
|
||||
|
||||
# Configure logging
|
||||
settings = get_settings()
|
||||
logging.config.dictConfig(settings.get_logging_config())
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Bring services up before serving and tear them down afterwards.

    The finally clause runs on both normal shutdown and failed startup,
    so partially-initialized services still get cleaned up.
    """
    logger.info("Starting WiFi-DensePose API...")

    try:
        await initialize_services(app)
        await start_background_tasks(app)
        logger.info("WiFi-DensePose API started successfully")

        yield

    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        raise
    finally:
        logger.info("Shutting down WiFi-DensePose API...")
        await cleanup_services(app)
        logger.info("WiFi-DensePose API shutdown complete")
|
||||
|
||||
|
||||
async def initialize_services(app: FastAPI):
    """Construct and initialize all services, then stash them on app.state.

    Order matters: hardware first, then pose, then stream — each service's
    initialize() runs before the next service is created, matching their
    runtime dependency chain.
    """
    try:
        hardware_service = get_hardware_service()
        await hardware_service.initialize()

        pose_service = get_pose_service()
        await pose_service.initialize()

        stream_service = get_stream_service()
        await stream_service.initialize()

        # Bridges pose/stream output onto active WebSocket connections.
        pose_stream_handler = PoseStreamHandler(
            connection_manager=connection_manager,
            pose_service=pose_service,
            stream_service=stream_service
        )

        # app.state makes the singletons reachable from request handlers.
        app.state.hardware_service = hardware_service
        app.state.pose_service = pose_service
        app.state.stream_service = stream_service
        app.state.pose_stream_handler = pose_stream_handler

        logger.info("Services initialized successfully")

    except Exception as e:
        logger.error(f"Failed to initialize services: {e}")
        raise
|
||||
|
||||
|
||||
async def start_background_tasks(app: FastAPI):
    """Kick off long-running work once services are initialized."""
    try:
        await app.state.pose_service.start()
        logger.info("Pose service started")

        # Streaming loop only runs when real-time processing is enabled.
        if settings.enable_real_time_processing:
            await app.state.pose_stream_handler.start_streaming()

        logger.info("Background tasks started")

    except Exception as e:
        logger.error(f"Failed to start background tasks: {e}")
        raise
|
||||
|
||||
|
||||
async def cleanup_services(app: FastAPI):
    """Best-effort shutdown of streaming, connections, and services.

    Uses hasattr guards because startup may have failed part-way and left
    app.state only partially populated.  Never raises: shutdown errors are
    logged and swallowed.
    """
    try:
        if hasattr(app.state, 'pose_stream_handler'):
            await app.state.pose_stream_handler.shutdown()

        await connection_manager.shutdown()

        if hasattr(app.state, 'stream_service'):
            await app.state.stream_service.shutdown()

        if hasattr(app.state, 'pose_service'):
            await app.state.pose_service.stop()

        if hasattr(app.state, 'hardware_service'):
            await app.state.hardware_service.shutdown()

        logger.info("Services cleaned up successfully")

    except Exception as e:
        logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
# API docs are only exposed outside production.
_docs_enabled = not settings.is_production

app = FastAPI(
    title=settings.app_name,
    version=settings.version,
    description="WiFi-based human pose estimation and activity recognition API",
    docs_url=settings.docs_url if _docs_enabled else None,
    redoc_url=settings.redoc_url if _docs_enabled else None,
    openapi_url=settings.openapi_url if _docs_enabled else None,
    lifespan=lifespan
)

# Middleware registration order matters: rate limiting runs outermost,
# then authentication, then CORS, then (in production) host filtering.
if settings.enable_rate_limiting:
    app.add_middleware(RateLimitMiddleware)

if settings.enable_authentication:
    app.add_middleware(AuthMiddleware)

cors_config = settings.get_cors_config()
app.add_middleware(CORSMiddleware, **cors_config)

if settings.is_production:
    app.add_middleware(TrustedHostMiddleware, allowed_hosts=settings.allowed_hosts)
|
||||
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP exceptions as the API's structured JSON error envelope."""
    envelope = {
        "error": {
            "code": exc.status_code,
            "message": exc.detail,
            "type": "http_error"
        }
    }
    return JSONResponse(status_code=exc.status_code, content=envelope)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Render request-body/query validation failures as a 422 envelope,
    including pydantic's per-field error details."""
    envelope = {
        "error": {
            "code": 422,
            "message": "Validation error",
            "type": "validation_error",
            "details": exc.errors()
        }
    }
    return JSONResponse(status_code=422, content=envelope)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the traceback, hide details from clients."""
    logger.error(f"Unhandled exception: {exc}", exc_info=True)

    envelope = {
        "error": {
            "code": 500,
            "message": "Internal server error",
            "type": "internal_error"
        }
    }
    return JSONResponse(status_code=500, content=envelope)
|
||||
|
||||
|
||||
# Middleware for request logging
|
||||
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every HTTP request with method, path, status and wall time.

    Also attaches an ``X-Process-Time`` header so clients can observe
    server-side latency.
    """
    # get_event_loop() is deprecated inside coroutines (Python 3.10+);
    # get_running_loop() is the supported way to reach the current loop.
    # Hoisting the loop lookup also avoids doing it twice per request.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    # Process request
    response = await call_next(request)

    # Calculate processing time
    process_time = loop.time() - start_time

    # Log request
    logger.info(
        f"{request.method} {request.url.path} - "
        f"Status: {response.status_code} - "
        f"Time: {process_time:.3f}s"
    )

    # Add processing time header
    response.headers["X-Process-Time"] = str(process_time)

    return response
|
||||
|
||||
|
||||
# Include routers
|
||||
# Route groups: health checks mount at the root; pose and stream live
# under the configured API prefix.
for _router, _prefix, _tags in (
    (health.router, "/health", ["Health"]),
    (pose.router, f"{settings.api_prefix}/pose", ["Pose Estimation"]),
    (stream.router, f"{settings.api_prefix}/stream", ["Streaming"]),
):
    app.include_router(_router, prefix=_prefix, tags=_tags)
|
||||
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
async def root():
    """Top-level service descriptor: identity, environment, feature flags."""
    feature_flags = {
        "authentication": settings.enable_authentication,
        "rate_limiting": settings.enable_rate_limiting,
        "websockets": settings.enable_websockets,
        "real_time_processing": settings.enable_real_time_processing
    }
    return {
        "name": settings.app_name,
        "version": settings.version,
        "environment": settings.environment,
        "docs_url": settings.docs_url,
        "api_prefix": settings.api_prefix,
        "features": feature_flags
    }
|
||||
|
||||
|
||||
# API information endpoint
|
||||
@app.get(f"{settings.api_prefix}/info")
async def api_info():
    """Detailed API descriptor: identity, configured entities, flags, limits."""
    domain_config = get_domain_config()

    api_block = {
        "name": settings.app_name,
        "version": settings.version,
        "environment": settings.environment,
        "prefix": settings.api_prefix
    }
    configuration_block = {
        "zones": len(domain_config.zones),
        "routers": len(domain_config.routers),
        "pose_models": len(domain_config.pose_models)
    }
    feature_block = {
        "authentication": settings.enable_authentication,
        "rate_limiting": settings.enable_rate_limiting,
        "websockets": settings.enable_websockets,
        "real_time_processing": settings.enable_real_time_processing,
        "historical_data": settings.enable_historical_data
    }
    limits_block = {
        "rate_limit_requests": settings.rate_limit_requests,
        "rate_limit_window": settings.rate_limit_window,
        "max_websocket_connections": domain_config.streaming.max_connections
    }

    return {
        "api": api_block,
        "configuration": configuration_block,
        "features": feature_block,
        "limits": limits_block
    }
|
||||
|
||||
|
||||
# Status endpoint
|
||||
@app.get(f"{settings.api_prefix}/status")
async def api_status(request: Request):
    """Aggregate health of the API, its services, streaming, and sockets.

    Never raises: on failure it degrades to an error payload so the status
    endpoint itself stays reachable.
    """
    try:
        state = request.app.state
        hardware_service = getattr(state, 'hardware_service', None)
        pose_service = getattr(state, 'pose_service', None)
        stream_service = getattr(state, 'stream_service', None)
        pose_stream_handler = getattr(state, 'pose_stream_handler', None)

        # Services missing from app.state report as unavailable rather
        # than failing the whole status call.
        payload = {
            "api": {
                "status": "healthy",
                "uptime": "unknown",
                "version": settings.version
            },
            "services": {
                "hardware": await hardware_service.get_status() if hardware_service else {"status": "unavailable"},
                "pose": await pose_service.get_status() if pose_service else {"status": "unavailable"},
                "stream": await stream_service.get_status() if stream_service else {"status": "unavailable"}
            },
            "streaming": pose_stream_handler.get_stream_status() if pose_stream_handler else {"is_streaming": False},
            "connections": await connection_manager.get_connection_stats()
        }

        return payload

    except Exception as e:
        logger.error(f"Error getting API status: {e}")
        return {
            "api": {
                "status": "error",
                "error": str(e)
            }
        }
|
||||
|
||||
|
||||
# Metrics endpoint (if enabled)
|
||||
# Metrics endpoint is registered only when metrics collection is enabled.
if settings.metrics_enabled:
    @app.get(f"{settings.api_prefix}/metrics")
    async def api_metrics(request: Request):
        """Connection and streaming performance counters."""
        try:
            handler = getattr(request.app.state, 'pose_stream_handler', None)

            metrics = {
                "connections": await connection_manager.get_metrics(),
                "streaming": await handler.get_performance_metrics() if handler else {}
            }

            return metrics

        except Exception as e:
            logger.error(f"Error getting metrics: {e}")
            return {"error": str(e)}
|
||||
|
||||
|
||||
# Development endpoints (only in development)
|
||||
# Debug-only endpoints, registered solely for development builds that
# opt into test endpoints.
if settings.is_development and settings.enable_test_endpoints:
    @app.get(f"{settings.api_prefix}/dev/config")
    async def dev_config():
        """Dump live settings and domain configuration (development only)."""
        return {
            "settings": settings.dict(),
            "domain_config": get_domain_config().to_dict()
        }

    @app.post(f"{settings.api_prefix}/dev/reset")
    async def dev_reset(request: Request):
        """Reset hardware and pose services to a clean state (development only)."""
        try:
            # Hardware first, then pose — mirrors initialization order.
            for attr in ('hardware_service', 'pose_service'):
                service = getattr(request.app.state, attr, None)
                if service:
                    await service.reset()

            return {"message": "Services reset successfully"}

        except Exception as e:
            logger.error(f"Error resetting services: {e}")
            return {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Local entry point; deployments normally launch uvicorn externally.
    import uvicorn

    uvicorn.run(
        "src.api.main:app",
        host=settings.host,
        port=settings.port,
        reload=settings.reload,
        # Auto-reload is incompatible with multiple workers.
        workers=1 if settings.reload else settings.workers,
        log_level=settings.log_level.lower()
    )
|
||||
8
v1/src/api/middleware/__init__.py
Normal file
8
v1/src/api/middleware/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
FastAPI middleware package
|
||||
"""
|
||||
|
||||
from .auth import AuthMiddleware
|
||||
from .rate_limit import RateLimitMiddleware
|
||||
|
||||
__all__ = ["AuthMiddleware", "RateLimitMiddleware"]
|
||||
322
v1/src/api/middleware/auth.py
Normal file
322
v1/src/api/middleware/auth.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
JWT Authentication middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT Authentication middleware."""
|
||||
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self.settings = get_settings()
|
||||
|
||||
# Paths that don't require authentication
|
||||
self.public_paths = {
|
||||
"/",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/health",
|
||||
"/ready",
|
||||
"/live",
|
||||
"/version",
|
||||
"/metrics"
|
||||
}
|
||||
|
||||
# Paths that require authentication
|
||||
self.protected_paths = {
|
||||
"/api/v1/pose/analyze",
|
||||
"/api/v1/pose/calibrate",
|
||||
"/api/v1/pose/historical",
|
||||
"/api/v1/stream/start",
|
||||
"/api/v1/stream/stop",
|
||||
"/api/v1/stream/clients",
|
||||
"/api/v1/stream/broadcast"
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""Process request through authentication middleware."""
|
||||
|
||||
# Skip authentication for public paths
|
||||
if self._is_public_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract and validate token
|
||||
token = self._extract_token(request)
|
||||
|
||||
if token:
|
||||
try:
|
||||
# Verify token and add user info to request state
|
||||
user_data = await self._verify_token(token)
|
||||
request.state.user = user_data
|
||||
request.state.authenticated = True
|
||||
|
||||
logger.debug(f"Authenticated user: {user_data.get('id')}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Token validation failed: {e}")
|
||||
|
||||
# For protected paths, return 401
|
||||
if self._is_protected_path(request.url.path):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": {
|
||||
"code": 401,
|
||||
"message": "Invalid or expired token",
|
||||
"type": "authentication_error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# For other paths, continue without authentication
|
||||
request.state.user = None
|
||||
request.state.authenticated = False
|
||||
else:
|
||||
# No token provided
|
||||
if self._is_protected_path(request.url.path):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": {
|
||||
"code": 401,
|
||||
"message": "Authentication required",
|
||||
"type": "authentication_error"
|
||||
}
|
||||
},
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
request.state.user = None
|
||||
request.state.authenticated = False
|
||||
|
||||
# Continue with request processing
|
||||
response = await call_next(request)
|
||||
|
||||
# Add authentication headers to response
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
response.headers["X-User-ID"] = request.state.user.get("id", "")
|
||||
response.headers["X-Authenticated"] = "true"
|
||||
else:
|
||||
response.headers["X-Authenticated"] = "false"
|
||||
|
||||
return response
|
||||
|
||||
def _is_public_path(self, path: str) -> bool:
|
||||
"""Check if path is public (doesn't require authentication)."""
|
||||
# Exact match
|
||||
if path in self.public_paths:
|
||||
return True
|
||||
|
||||
# Pattern matching for public paths
|
||||
public_patterns = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/pose/current", # Allow anonymous access to current pose data
|
||||
"/api/v1/pose/zones/", # Allow anonymous access to zone data
|
||||
"/api/v1/pose/activities", # Allow anonymous access to activities
|
||||
"/api/v1/pose/stats", # Allow anonymous access to stats
|
||||
"/api/v1/stream/status" # Allow anonymous access to stream status
|
||||
]
|
||||
|
||||
for pattern in public_patterns:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_protected_path(self, path: str) -> bool:
|
||||
"""Check if path requires authentication."""
|
||||
# Exact match
|
||||
if path in self.protected_paths:
|
||||
return True
|
||||
|
||||
# Pattern matching for protected paths
|
||||
protected_patterns = [
|
||||
"/api/v1/pose/analyze",
|
||||
"/api/v1/pose/calibrate",
|
||||
"/api/v1/pose/historical",
|
||||
"/api/v1/stream/start",
|
||||
"/api/v1/stream/stop",
|
||||
"/api/v1/stream/clients",
|
||||
"/api/v1/stream/broadcast"
|
||||
]
|
||||
|
||||
for pattern in protected_patterns:
|
||||
if path.startswith(pattern):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
    """Pull a bearer token from the request, if one is present.

    Lookup order: ``Authorization: Bearer`` header, then the ``token``
    query parameter (used by WebSocket connections), then the
    ``access_token`` cookie. Returns None when nothing is found.
    """
    header_value = request.headers.get("authorization")
    if header_value and header_value.startswith("Bearer "):
        return header_value.split(" ")[1]

    # Fall back to query parameter, then cookie, in that order.
    for candidate in (
        request.query_params.get("token"),
        request.cookies.get("access_token"),
    ):
        if candidate:
            return candidate

    return None
|
||||
|
||||
async def _verify_token(self, token: str) -> Dict[str, Any]:
    """Verify a JWT and return the decoded user payload.

    Args:
        token: Raw JWT string extracted from the request.

    Returns:
        Dict with user identity, permission, and token-lifetime fields.

    Raises:
        ValueError: if the token fails signature verification, lacks a
            subject claim, or has expired.
    """
    try:
        # Decode and verify the signature with the configured secret/algorithm.
        payload = jwt.decode(
            token,
            self.settings.secret_key,
            algorithms=[self.settings.jwt_algorithm]
        )

        # Extract user information; "sub" is the mandatory subject claim.
        user_id = payload.get("sub")
        if not user_id:
            raise ValueError("Token missing user ID")

        # Defensive expiry check on top of jwt.decode's own validation.
        # BUG FIX: the previous code compared datetime.utcnow() (UTC)
        # against datetime.fromtimestamp(exp), which interprets the POSIX
        # timestamp in the host's *local* timezone — on any non-UTC host
        # that skews expiry by the UTC offset. utcfromtimestamp keeps
        # both sides in UTC.
        exp = payload.get("exp")
        if exp and datetime.utcnow() > datetime.utcfromtimestamp(exp):
            raise ValueError("Token expired")

        # Build user object from the remaining claims.
        user_data = {
            "id": user_id,
            "username": payload.get("username"),
            "email": payload.get("email"),
            "is_admin": payload.get("is_admin", False),
            "permissions": payload.get("permissions", []),
            "accessible_zones": payload.get("accessible_zones", []),
            "token_issued_at": payload.get("iat"),
            "token_expires_at": payload.get("exp"),
            "session_id": payload.get("session_id")
        }

        return user_data

    except JWTError as e:
        raise ValueError(f"JWT validation failed: {e}")
    except Exception as e:
        # Also re-wraps the ValueErrors raised above, matching the
        # original behavior ("Token verification error: Token expired").
        raise ValueError(f"Token verification error: {e}")
|
||||
|
||||
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
    """Emit a structured log entry for an authentication-related event.

    Failure-type events are logged at WARNING so they stand out for
    security monitoring; everything else is logged at INFO.
    """
    entry = {
        "event_type": event_type,
        "timestamp": datetime.utcnow().isoformat(),
        "client_ip": request.client.host if request.client else "unknown",
        "user_agent": request.headers.get("user-agent", "unknown"),
        "path": request.url.path,
        "method": request.method,
    }
    if details:
        entry.update(details)

    failure_events = ("authentication_failed", "token_expired", "invalid_token")
    emit = logger.warning if event_type in failure_events else logger.info
    emit(f"Auth event: {entry}")
|
||||
|
||||
|
||||
class TokenBlacklist:
    """In-memory store of revoked tokens, backing logout functionality.

    NOTE: cleanup simply clears the entire set once per interval, so a
    revoked token can become accepted again after a purge; a production
    implementation should expire entries based on each token's lifetime.
    """

    def __init__(self):
        self._blacklisted_tokens = set()
        self._cleanup_interval = 3600  # purge at most once per hour
        self._last_cleanup = datetime.utcnow()

    def add_token(self, token: str):
        """Mark *token* as revoked."""
        self._blacklisted_tokens.add(token)
        self._cleanup_if_needed()

    def is_blacklisted(self, token: str) -> bool:
        """Return True when *token* has been revoked (and not yet purged)."""
        self._cleanup_if_needed()
        return token in self._blacklisted_tokens

    def _cleanup_if_needed(self):
        """Wholesale-clear the blacklist once the cleanup interval elapses."""
        now = datetime.utcnow()
        elapsed = (now - self._last_cleanup).total_seconds()
        if elapsed > self._cleanup_interval:
            # Placeholder strategy: clear everything rather than tracking
            # per-token expiry (see class docstring).
            self._blacklisted_tokens.clear()
            self._last_cleanup = now
|
||||
|
||||
|
||||
# Global token blacklist instance (process-wide singleton; state is not
# shared across workers or restarts).
token_blacklist = TokenBlacklist()
|
||||
|
||||
|
||||
class SecurityHeaders:
    """Helpers for attaching standard security headers to API responses."""

    @staticmethod
    def add_security_headers(response: Response) -> Response:
        """Set browser-hardening headers on *response* and return it."""
        hardening = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            # CSP allows same-origin content plus inline scripts/styles,
            # data: images, and WebSocket connections.
            "Content-Security-Policy": (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline'; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data:; "
                "connect-src 'self' ws: wss:;"
            ),
        }
        for header_name, header_value in hardening.items():
            response.headers[header_name] = header_value
        return response
|
||||
|
||||
|
||||
class APIKeyAuth:
    """API-key authentication for service-to-service communication.

    Keys live in an in-process dict mapping each key string to an
    arbitrary service-info mapping.
    """

    def __init__(self, api_keys: Dict[str, Dict[str, Any]] = None):
        # Default to an empty registry when no keys are supplied.
        self.api_keys = api_keys or {}

    def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Return the service info for *api_key*, or None if unknown."""
        return self.api_keys.get(api_key)

    def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
        """Register *api_key* with its associated service info."""
        self.api_keys[api_key] = service_info

    def revoke_api_key(self, api_key: str):
        """Remove *api_key* from the registry; unknown keys are a no-op."""
        self.api_keys.pop(api_key, None)
|
||||
|
||||
|
||||
# Global API key auth instance (process-wide singleton registry).
api_key_auth = APIKeyAuth()
|
||||
429
v1/src/api/middleware/rate_limit.py
Normal file
429
v1/src/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Rate limiting middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware with sliding window algorithm.

    Per-client request timestamps are kept in in-process deques and
    pruned on every check, giving a true sliding window rather than
    fixed buckets. All state is process-local; for multi-worker
    deployments use the Redis-backed storage classes defined below.
    """

    def __init__(self, app):
        super().__init__(app)
        self.settings = get_settings()

        # Rate limit storage (in production, use Redis).
        # request_counts maps "<client_id>:<limit_type>" -> deque of timestamps.
        self.request_counts = defaultdict(lambda: deque())
        # blocked_clients maps client_id -> unix time at which the block expires.
        self.blocked_clients = {}

        # Rate limit configurations per user type.
        # NOTE(review): the "burst" values are stored but never read by
        # _check_limit — confirm whether burst handling was intended.
        self.rate_limits = {
            "anonymous": {
                "requests": self.settings.rate_limit_requests,
                "window": self.settings.rate_limit_window,
                "burst": 10  # Allow burst of 10 requests
            },
            "authenticated": {
                "requests": self.settings.rate_limit_authenticated_requests,
                "window": self.settings.rate_limit_window,
                "burst": 50
            },
            "admin": {
                "requests": 10000,  # Very high limit for admins
                "window": self.settings.rate_limit_window,
                "burst": 100
            }
        }

        # Path-specific rate limits (stricter caps for expensive endpoints).
        self.path_limits = {
            "/api/v1/pose/current": {"requests": 60, "window": 60},  # 1 per second
            "/api/v1/pose/analyze": {"requests": 10, "window": 60},  # 10 per minute
            "/api/v1/pose/calibrate": {"requests": 1, "window": 300},  # 1 per 5 minutes
            "/api/v1/stream/start": {"requests": 5, "window": 60},  # 5 per minute
            "/api/v1/stream/stop": {"requests": 5, "window": 60},  # 5 per minute
        }

        # Exempt paths from rate limiting (health and monitoring probes).
        self.exempt_paths = {
            "/health",
            "/ready",
            "/live",
            "/version",
            "/metrics"
        }

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting middleware.

        Order of operations: exemption check, temporary-block check,
        limit check, then record the request and forward it, attaching
        X-RateLimit-* headers to the response.
        """

        # Skip rate limiting for exempt paths
        if self._is_exempt_path(request.url.path):
            return await call_next(request)

        # Get client identifier
        client_id = self._get_client_id(request)

        # Check if client is temporarily blocked
        if self._is_client_blocked(client_id):
            return self._create_rate_limit_response(
                "Client temporarily blocked due to excessive requests"
            )

        # Get user type for rate limiting
        user_type = self._get_user_type(request)

        # Check rate limits
        rate_limit_result = self._check_rate_limits(
            client_id,
            request.url.path,
            user_type
        )

        if not rate_limit_result["allowed"]:
            # Log rate limit violation
            self._log_rate_limit_violation(request, client_id, rate_limit_result)

            # Check if client should be temporarily blocked.
            # NOTE(review): _check_limit never produces a "violations" key,
            # so this branch is currently unreachable — per-client violation
            # counting would need to be tracked for auto-blocking to fire.
            if rate_limit_result.get("violations", 0) > 5:
                self._block_client(client_id, duration=300)  # 5 minutes

            return self._create_rate_limit_response(
                rate_limit_result["message"],
                retry_after=rate_limit_result.get("retry_after", 60)
            )

        # Record the request
        self._record_request(client_id, request.url.path)

        # Process request
        response = await call_next(request)

        # Add rate limit headers
        self._add_rate_limit_headers(response, client_id, user_type)

        return response

    def _is_exempt_path(self, path: str) -> bool:
        """Check if path is exempt from rate limiting (exact match only)."""
        return path in self.exempt_paths

    def _get_client_id(self, request: Request) -> str:
        """Get unique client identifier for rate limiting.

        Authenticated requests are keyed by user id; anonymous requests
        fall back to IP plus a short user-agent hash.
        """
        # Try to get user ID from request state (set by auth middleware)
        if hasattr(request.state, 'user') and request.state.user:
            return f"user:{request.state.user['id']}"

        # Fall back to IP address
        client_ip = request.client.host if request.client else "unknown"

        # Include user agent for better identification.
        # NOTE(review): hash() is salted per process (PYTHONHASHSEED), so
        # this suffix is not stable across restarts or workers.
        user_agent = request.headers.get("user-agent", "")
        user_agent_hash = str(hash(user_agent))[:8]

        return f"ip:{client_ip}:{user_agent_hash}"

    def _get_user_type(self, request: Request) -> str:
        """Determine user type ("admin", "authenticated", "anonymous")."""
        if hasattr(request.state, 'user') and request.state.user:
            if request.state.user.get("is_admin", False):
                return "admin"
            return "authenticated"
        return "anonymous"

    def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
        """Check if request is within rate limits.

        Applies the general per-user-type limit first, then any
        path-specific limit; the first failing check wins.
        """
        now = time.time()

        # Get applicable rate limits
        general_limit = self.rate_limits[user_type]
        path_limit = self.path_limits.get(path)

        # Check general rate limit
        general_result = self._check_limit(
            client_id,
            "general",
            general_limit["requests"],
            general_limit["window"],
            now
        )

        if not general_result["allowed"]:
            return general_result

        # Check path-specific rate limit if exists
        if path_limit:
            path_result = self._check_limit(
                client_id,
                f"path:{path}",
                path_limit["requests"],
                path_limit["window"],
                now
            )

            if not path_result["allowed"]:
                return path_result

        return {"allowed": True}

    def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
        """Check specific rate limit using sliding window.

        Prunes timestamps older than the window, then compares the
        remaining count against the limit. Returns a dict with at least
        "allowed"; on rejection it also carries "message", "retry_after",
        and the current count/limit/window for headers and logging.
        """
        key = f"{client_id}:{limit_type}"
        requests = self.request_counts[key]

        # Remove old requests outside the window
        cutoff = now - window
        while requests and requests[0] <= cutoff:
            requests.popleft()

        # Check if limit exceeded
        if len(requests) >= max_requests:
            # Calculate retry after time: when the oldest recorded request
            # ages out of the window, one slot frees up.
            oldest_request = requests[0] if requests else now
            retry_after = int(oldest_request + window - now) + 1

            return {
                "allowed": False,
                "message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
                "retry_after": retry_after,
                "current_count": len(requests),
                "limit": max_requests,
                "window": window
            }

        return {
            "allowed": True,
            "current_count": len(requests),
            "limit": max_requests,
            "window": window
        }

    def _record_request(self, client_id: str, path: str):
        """Record a request timestamp for rate limiting."""
        now = time.time()

        # Record general request
        general_key = f"{client_id}:general"
        self.request_counts[general_key].append(now)

        # Record path-specific request if path has specific limits
        if path in self.path_limits:
            path_key = f"{client_id}:path:{path}"
            self.request_counts[path_key].append(now)

    def _is_client_blocked(self, client_id: str) -> bool:
        """Check if client is temporarily blocked, expiring stale blocks."""
        if client_id in self.blocked_clients:
            block_until = self.blocked_clients[client_id]
            if time.time() < block_until:
                return True
            else:
                # Block expired, remove it
                del self.blocked_clients[client_id]
        return False

    def _block_client(self, client_id: str, duration: int):
        """Temporarily block a client for *duration* seconds."""
        self.blocked_clients[client_id] = time.time() + duration
        logger.warning(f"Client {client_id} blocked for {duration} seconds due to rate limit violations")

    def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
        """Create a 429 rate-limit-exceeded response with Retry-After."""
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": 429,
                    "message": message,
                    "type": "rate_limit_exceeded"
                }
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": "Exceeded",
                "X-RateLimit-Remaining": "0"
            }
        )

    def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
        """Add X-RateLimit-* headers to a successful response.

        Header errors are logged and swallowed so they never fail the
        request itself.
        """
        try:
            general_limit = self.rate_limits[user_type]
            general_key = f"{client_id}:general"
            current_requests = len(self.request_counts[general_key])

            remaining = max(0, general_limit["requests"] - current_requests)

            response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Window"] = str(general_limit["window"])

            # Add reset time: when the oldest in-window request ages out.
            if self.request_counts[general_key]:
                oldest_request = self.request_counts[general_key][0]
                reset_time = int(oldest_request + general_limit["window"])
                response.headers["X-RateLimit-Reset"] = str(reset_time)

        except Exception as e:
            logger.error(f"Error adding rate limit headers: {e}")

    def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
        """Log rate limit violations for monitoring."""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "unknown")

        log_data = {
            "event_type": "rate_limit_violation",
            "timestamp": datetime.utcnow().isoformat(),
            "client_id": client_id,
            "client_ip": client_ip,
            "user_agent": user_agent,
            "path": request.url.path,
            "method": request.method,
            "current_count": result.get("current_count"),
            "limit": result.get("limit"),
            "window": result.get("window")
        }

        logger.warning(f"Rate limit violation: {log_data}")

    def cleanup_old_data(self):
        """Clean up old rate limiting data (call periodically).

        Prunes timestamps older than one hour, removes empty deques, and
        drops expired client blocks. There is no background scheduler
        visible here; the caller is responsible for invoking this.
        """
        now = time.time()
        cutoff = now - 3600  # Keep data for 1 hour

        # Clean up request counts
        for key in list(self.request_counts.keys()):
            requests = self.request_counts[key]
            while requests and requests[0] <= cutoff:
                requests.popleft()

            # Remove empty deques
            if not requests:
                del self.request_counts[key]

        # Clean up expired blocks
        expired_blocks = [
            client_id for client_id, block_until in self.blocked_clients.items()
            if now >= block_until
        ]

        for client_id in expired_blocks:
            del self.blocked_clients[client_id]
||||
|
||||
class AdaptiveRateLimit:
    """Scales rate limits up or down based on recent system load.

    Keeps a rolling window of combined CPU/memory load samples and
    derives a multiplier applied to base limits: heavy load shrinks
    them, light load relaxes them slightly.
    """

    def __init__(self):
        self.base_limits = {}
        self.current_multiplier = 1.0
        # One sample per update; 60 samples covers roughly one minute.
        self.load_history = deque(maxlen=60)

    def update_system_load(self, cpu_percent: float, memory_percent: float):
        """Record a load sample and recompute the limit multiplier."""
        self.load_history.append((cpu_percent + memory_percent) / 2)

        # Require at least ten samples before adapting.
        if len(self.load_history) < 10:
            return

        avg_load = sum(self.load_history) / len(self.load_history)
        if avg_load > 80:
            multiplier = 0.5   # heavy load: halve the limits
        elif avg_load > 60:
            multiplier = 0.7   # elevated load: reduce limits by 30%
        elif avg_load < 30:
            multiplier = 1.2   # light load: allow 20% more
        else:
            multiplier = 1.0   # normal load: base limits
        self.current_multiplier = multiplier

    def get_adjusted_limit(self, base_limit: int) -> int:
        """Return *base_limit* scaled by the current multiplier (minimum 1)."""
        return max(1, int(base_limit * self.current_multiplier))
|
||||
|
||||
|
||||
class RateLimitStorage:
    """Abstract interface for rate limit storage (Redis implementation).

    Concrete backends (e.g. RedisRateLimitStorage below) supply the
    persistence; every method here raises NotImplementedError.
    """

    async def get_count(self, key: str, window: int) -> int:
        """Get current request count for key within window."""
        raise NotImplementedError

    async def increment(self, key: str, window: int) -> int:
        """Increment request count and return new count."""
        raise NotImplementedError

    async def is_blocked(self, client_id: str) -> bool:
        """Check if client is blocked."""
        raise NotImplementedError

    async def block_client(self, client_id: str, duration: int):
        """Block client for duration seconds."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class RedisRateLimitStorage(RateLimitStorage):
|
||||
"""Redis-based rate limit storage for production use."""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
self.redis = redis_client
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count using Redis sliding window."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Count current entries
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[1]
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count using Redis."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Add current request
|
||||
pipeline.zadd(key, {str(now): now})
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Set expiration
|
||||
pipeline.expire(key, window + 1)
|
||||
|
||||
# Get count
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[3]
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
return await self.redis.exists(block_key)
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
await self.redis.setex(block_key, duration, "1")
|
||||
|
||||
|
||||
# Global adaptive rate limiter instance (process-wide singleton).
adaptive_rate_limit = AdaptiveRateLimit()
|
||||
7
v1/src/api/routers/__init__.py
Normal file
7
v1/src/api/routers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
API routers package
|
||||
"""
|
||||
|
||||
from . import pose, stream, health
|
||||
|
||||
__all__ = ["pose", "stream", "health"]
|
||||
419
v1/src/api/routers/health.py
Normal file
419
v1/src/api/routers/health.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Health check API endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
import psutil
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import get_current_user
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Response models
|
||||
class ComponentHealth(BaseModel):
    """Health status for a system component.

    Serialized inside the /health response, one instance per service
    (hardware, pose, stream).
    """

    name: str = Field(..., description="Component name")
    # NOTE(review): handlers also emit status "unavailable" for services
    # that were never initialized, in addition to the values listed here.
    status: str = Field(..., description="Health status (healthy, degraded, unhealthy)")
    message: Optional[str] = Field(default=None, description="Status message")
    last_check: datetime = Field(..., description="Last health check timestamp")
    uptime_seconds: Optional[float] = Field(default=None, description="Component uptime")
    metrics: Optional[Dict[str, Any]] = Field(default=None, description="Component metrics")
|
||||
|
||||
|
||||
class SystemHealth(BaseModel):
    """Overall system health status.

    Response model for GET /health: an aggregate verdict plus per-component
    detail and host-level metrics.
    """

    status: str = Field(..., description="Overall system status")
    timestamp: datetime = Field(..., description="Health check timestamp")
    uptime_seconds: float = Field(..., description="System uptime")
    components: Dict[str, ComponentHealth] = Field(..., description="Component health status")
    system_metrics: Dict[str, Any] = Field(..., description="System-level metrics")
|
||||
|
||||
|
||||
class ReadinessCheck(BaseModel):
    """System readiness check result.

    Response model for GET /ready: an overall verdict plus the individual
    probe results that were collected.
    """

    ready: bool = Field(..., description="Whether system is ready to serve requests")
    timestamp: datetime = Field(..., description="Readiness check timestamp")
    checks: Dict[str, bool] = Field(..., description="Individual readiness checks")
    message: str = Field(..., description="Readiness status message")
|
||||
|
||||
|
||||
# Health check endpoints
|
||||
async def _component_health(service, display_name: str, timestamp, overall_status: str):
    """Probe one service and return (ComponentHealth, updated overall status)."""
    if service is None:
        # Service was never initialized on app.state.
        # Bug fix: a missing service previously set overall_status to
        # "degraded" unconditionally, downgrading an earlier "unhealthy"
        # verdict; now we only degrade from "healthy".
        if overall_status == "healthy":
            overall_status = "degraded"
        component = ComponentHealth(
            name=display_name,
            status="unavailable",
            message="Service not initialized",
            last_check=timestamp,
        )
        return component, overall_status

    try:
        health = await service.health_check()
        component = ComponentHealth(
            name=display_name,
            status=health["status"],
            message=health.get("message"),
            last_check=timestamp,
            uptime_seconds=health.get("uptime_seconds"),
            metrics=health.get("metrics"),
        )
        if health["status"] != "healthy":
            # First degradation downgrades the system; any further one
            # (or one on top of an existing degradation) marks it unhealthy.
            overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
    except Exception as e:
        logger.error(f"{display_name} health check failed: {e}")
        component = ComponentHealth(
            name=display_name,
            status="unhealthy",
            message=f"Health check failed: {str(e)}",
            last_check=timestamp,
        )
        overall_status = "unhealthy"

    return component, overall_status


@router.get("/health", response_model=SystemHealth)
async def health_check(request: Request):
    """Comprehensive system health check.

    Probes the hardware, pose, and stream services registered on
    ``app.state``, aggregates their statuses into an overall verdict
    (healthy / degraded / unhealthy), and attaches host-level metrics.

    Raises:
        HTTPException: 500 when the health check itself fails.
    """
    try:
        timestamp = datetime.utcnow()
        components = {}
        overall_status = "healthy"

        # Each service is optional; a missing one degrades the system
        # rather than failing the endpoint.
        service_probes = (
            ("hardware", "Hardware Service", getattr(request.app.state, 'hardware_service', None)),
            ("pose", "Pose Service", getattr(request.app.state, 'pose_service', None)),
            ("stream", "Stream Service", getattr(request.app.state, 'stream_service', None)),
        )
        for key, display_name, service in service_probes:
            components[key], overall_status = await _component_health(
                service, display_name, timestamp, overall_status
            )

        # Host-level CPU/memory/disk/network metrics.
        system_metrics = get_system_metrics()

        # Calculate system uptime (placeholder - would need actual startup time)
        uptime_seconds = 0.0  # TODO: Implement actual uptime tracking

        return SystemHealth(
            status=overall_status,
            timestamp=timestamp,
            uptime_seconds=uptime_seconds,
            components=components,
            system_metrics=system_metrics,
        )

    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Health check failed: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/ready", response_model=ReadinessCheck)
async def readiness_check(request: Request):
    """Check if system is ready to serve requests.

    Collects individual readiness probes (pose/stream services, memory,
    disk) into ``checks``. On any unexpected error this returns a
    not-ready result instead of raising.
    """
    try:
        timestamp = datetime.utcnow()
        checks = {}

        # Check if services are available in app state
        if hasattr(request.app.state, 'pose_service') and request.app.state.pose_service:
            try:
                checks["pose_ready"] = await request.app.state.pose_service.is_ready()
            except Exception as e:
                logger.warning(f"Pose service readiness check failed: {e}")
                checks["pose_ready"] = False
        else:
            checks["pose_ready"] = False

        if hasattr(request.app.state, 'stream_service') and request.app.state.stream_service:
            try:
                checks["stream_ready"] = await request.app.state.stream_service.is_ready()
            except Exception as e:
                logger.warning(f"Stream service readiness check failed: {e}")
                checks["stream_ready"] = False
        else:
            checks["stream_ready"] = False

        # Hardware service check (basic availability)
        checks["hardware_ready"] = True  # Basic readiness - API is responding

        # Check system resources
        checks["memory_available"] = check_memory_availability()
        checks["disk_space_available"] = check_disk_space()

        # Application is ready if at least the basic services are available
        # For now, we'll consider it ready if the API is responding.
        # NOTE(review): ``ready`` is hard-coded True, so the collected
        # ``checks`` never gate readiness and the failed-checks message
        # below is unreachable — confirm whether this is intentional.
        ready = True  # Basic readiness

        message = "System is ready" if ready else "System is not ready"
        if not ready:
            failed_checks = [name for name, status in checks.items() if not status]
            message += f". Failed checks: {', '.join(failed_checks)}"

        return ReadinessCheck(
            ready=ready,
            timestamp=timestamp,
            checks=checks,
            message=message
        )

    except Exception as e:
        logger.error(f"Readiness check failed: {e}")
        return ReadinessCheck(
            ready=False,
            timestamp=datetime.utcnow(),
            checks={},
            message=f"Readiness check failed: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/live")
async def liveness_check():
    """Minimal liveness probe for load balancers; always reports alive."""
    return {
        "status": "alive",
        "timestamp": datetime.utcnow().isoformat(),
    }
|
||||
|
||||
|
||||
@router.get("/metrics")
async def get_health_metrics(
    request: Request,
    current_user: Optional[Dict] = Depends(get_current_user)
):
    """Get detailed system metrics.

    Anonymous callers receive the basic host metrics; authenticated
    callers additionally receive the detailed metric set.

    Raises:
        HTTPException: 500 when metric collection fails.
    """
    try:
        metrics = get_system_metrics()

        # Authenticated users see the extended metrics as well.
        if current_user:
            metrics.update(get_detailed_metrics())

        return {
            "timestamp": datetime.utcnow().isoformat(),
            "metrics": metrics,
        }
    except Exception as e:
        logger.error(f"Error getting system metrics: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get system metrics: {str(e)}",
        )
|
||||
|
||||
|
||||
@router.get("/version")
async def get_version_info():
    """Get application version information."""
    cfg = get_settings()

    info = {
        "name": cfg.app_name,
        "version": cfg.version,
        "environment": cfg.environment,
        "debug": cfg.debug,
        "timestamp": datetime.utcnow().isoformat(),
    }
    return info
|
||||
|
||||
|
||||
def get_system_metrics() -> Dict[str, Any]:
    """Get basic system metrics.

    Returns a dict with "cpu", "memory", "disk" and "network" sections,
    or an empty dict if metric collection fails.
    """
    try:
        # NOTE(review): interval=1 blocks the calling thread for a full
        # second per request — confirm this is acceptable on the request path.
        cpu = {
            "percent": psutil.cpu_percent(interval=1),
            "count": psutil.cpu_count(),
        }

        vm = psutil.virtual_memory()
        memory = {
            "total_gb": round(vm.total / (1024**3), 2),
            "available_gb": round(vm.available / (1024**3), 2),
            "used_gb": round(vm.used / (1024**3), 2),
            "percent": vm.percent,
        }

        usage = psutil.disk_usage('/')
        disk = {
            "total_gb": round(usage.total / (1024**3), 2),
            "free_gb": round(usage.free / (1024**3), 2),
            "used_gb": round(usage.used / (1024**3), 2),
            "percent": round((usage.used / usage.total) * 100, 2),
        }

        counters = psutil.net_io_counters()
        network = {
            "bytes_sent": counters.bytes_sent,
            "bytes_recv": counters.bytes_recv,
            "packets_sent": counters.packets_sent,
            "packets_recv": counters.packets_recv,
        }

        return {
            "cpu": cpu,
            "memory": memory,
            "disk": disk,
            "network": network,
        }

    except Exception as e:
        # Best-effort: metrics are informational, so never propagate here.
        logger.error(f"Error getting system metrics: {e}")
        return {}
|
||||
|
||||
|
||||
def get_detailed_metrics() -> Dict[str, Any]:
    """Get detailed system metrics (requires authentication).

    Always includes a "process" section; adds "load_average" and
    "temperatures" only where the platform exposes them.
    Returns an empty dict if collection fails.
    """
    try:
        proc = psutil.Process()
        detailed: Dict[str, Any] = {
            "process": {
                "pid": proc.pid,
                "cpu_percent": proc.cpu_percent(),
                "memory_mb": round(proc.memory_info().rss / (1024**2), 2),
                "num_threads": proc.num_threads(),
                "create_time": datetime.fromtimestamp(proc.create_time()).isoformat(),
            }
        }

        # Load average is not available on Windows (AttributeError there).
        try:
            one, five, fifteen = psutil.getloadavg()
            detailed["load_average"] = {
                "1min": one,
                "5min": five,
                "15min": fifteen,
            }
        except AttributeError:
            pass

        # Temperature sensors are not exposed on every platform.
        try:
            sensor_data = psutil.sensors_temperatures()
            readings = {
                name: [
                    {"label": entry.label, "current": entry.current}
                    for entry in entries
                ]
                for name, entries in sensor_data.items()
            }
            if readings:
                detailed["temperatures"] = readings
        except AttributeError:
            pass

        return detailed

    except Exception as e:
        # Best-effort: detailed metrics are informational only.
        logger.error(f"Error getting detailed metrics: {e}")
        return {}
|
||||
|
||||
|
||||
def check_memory_availability() -> bool:
    """Check if sufficient memory is available.

    The system is considered ready while memory pressure stays below 90%.
    Any failure to read memory stats counts as "not ready".
    """
    try:
        return psutil.virtual_memory().percent < 90.0
    except Exception:
        return False
|
||||
|
||||
|
||||
def check_disk_space() -> bool:
    """Check if sufficient disk space is available.

    The system is considered ready while at least 1 GB remains free on
    the root volume. Any failure to read disk stats counts as "not ready".
    """
    try:
        usage = psutil.disk_usage('/')
        return usage.free / (1024**3) > 1.0
    except Exception:
        return False
|
||||
420
v1/src/api/routers/pose.py
Normal file
420
v1/src/api/routers/pose.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Pose estimation API endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_pose_service,
|
||||
get_hardware_service,
|
||||
get_current_user,
|
||||
require_auth
|
||||
)
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.hardware_service import HardwareService
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class PoseEstimationRequest(BaseModel):
    """Request model for pose estimation.

    All filters are optional; unset values fall back to service-side
    defaults inside PoseService.
    """

    # None means "all configured zones".
    zone_ids: Optional[List[str]] = Field(
        default=None,
        description="Specific zones to analyze (all zones if not specified)"
    )
    # Detections scoring below this are dropped by the service.
    confidence_threshold: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold for detections"
    )
    max_persons: Optional[int] = Field(
        default=None,
        ge=1,
        le=50,
        description="Maximum number of persons to detect"
    )
    include_keypoints: bool = Field(
        default=True,
        description="Include detailed keypoint data"
    )
    # Segmentation is expensive; off by default.
    include_segmentation: bool = Field(
        default=False,
        description="Include DensePose segmentation masks"
    )
|
||||
|
||||
|
||||
class PersonPose(BaseModel):
    """Person pose data model.

    One detected person in a frame. Optional sections (keypoints,
    segmentation) are populated only when requested.
    """

    person_id: str = Field(..., description="Unique person identifier")
    confidence: float = Field(..., description="Detection confidence score")
    bounding_box: Dict[str, float] = Field(..., description="Person bounding box")
    keypoints: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Body keypoints with coordinates and confidence"
    )
    segmentation: Optional[Dict[str, Any]] = Field(
        default=None,
        description="DensePose segmentation data"
    )
    zone_id: Optional[str] = Field(
        default=None,
        description="Zone where person is detected"
    )
    activity: Optional[str] = Field(
        default=None,
        description="Detected activity"
    )
    timestamp: datetime = Field(..., description="Detection timestamp")
|
||||
|
||||
|
||||
class PoseEstimationResponse(BaseModel):
    """Response model for pose estimation.

    Constructed from the dict returned by PoseService.estimate_poses /
    analyze_with_params (see the /current and /analyze endpoints).
    """

    timestamp: datetime = Field(..., description="Analysis timestamp")
    frame_id: str = Field(..., description="Unique frame identifier")
    persons: List[PersonPose] = Field(..., description="Detected persons")
    zone_summary: Dict[str, int] = Field(..., description="Person count per zone")
    processing_time_ms: float = Field(..., description="Processing time in milliseconds")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
|
||||
|
||||
|
||||
class HistoricalDataRequest(BaseModel):
    """Request model for historical pose data.

    The /historical endpoint additionally enforces that the time window is
    positive and no longer than 7 days.
    """

    start_time: datetime = Field(..., description="Start time for data query")
    end_time: datetime = Field(..., description="End time for data query")
    zone_ids: Optional[List[str]] = Field(
        default=None,
        description="Filter by specific zones"
    )
    # Bucket size for aggregated results: 1 minute to 1 hour, default 5 minutes.
    aggregation_interval: Optional[int] = Field(
        default=300,
        ge=60,
        le=3600,
        description="Aggregation interval in seconds"
    )
    include_raw_data: bool = Field(
        default=False,
        description="Include raw detection data"
    )
|
||||
|
||||
|
||||
# Endpoints
|
||||
@router.get("/current", response_model=PoseEstimationResponse)
async def get_current_pose_estimation(
    request: PoseEstimationRequest = Depends(),
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Optional[Dict] = Depends(get_current_user)
):
    """Get current pose estimation from WiFi signals."""
    try:
        requester = current_user.get('id') if current_user else 'anonymous'
        logger.info(f"Processing pose estimation request from user: {requester}")

        # Delegate the actual estimation to the pose service.
        estimation = await pose_service.estimate_poses(
            zone_ids=request.zone_ids,
            confidence_threshold=request.confidence_threshold,
            max_persons=request.max_persons,
            include_keypoints=request.include_keypoints,
            include_segmentation=request.include_segmentation,
        )

        return PoseEstimationResponse(**estimation)

    except Exception as e:
        logger.error(f"Error in pose estimation: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Pose estimation failed: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=PoseEstimationResponse)
async def analyze_pose_data(
    request: PoseEstimationRequest,
    background_tasks: BackgroundTasks,
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Dict = Depends(require_auth)
):
    """Trigger pose analysis with custom parameters."""
    try:
        logger.info(f"Custom pose analysis requested by user: {current_user['id']}")

        analysis = await pose_service.analyze_with_params(
            zone_ids=request.zone_ids,
            confidence_threshold=request.confidence_threshold,
            max_persons=request.max_persons,
            include_keypoints=request.include_keypoints,
            include_segmentation=request.include_segmentation,
        )

        # Segmentation is heavy, so finish it off-request in the background.
        if request.include_segmentation:
            background_tasks.add_task(
                pose_service.process_segmentation_data,
                analysis["frame_id"],
            )

        return PoseEstimationResponse(**analysis)

    except Exception as e:
        logger.error(f"Error in pose analysis: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Pose analysis failed: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/zones/{zone_id}/occupancy")
async def get_zone_occupancy(
    zone_id: str,
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Optional[Dict] = Depends(get_current_user)
):
    """Get current occupancy for a specific zone.

    Raises 404 for an unknown zone rather than returning an empty payload.
    """
    try:
        snapshot = await pose_service.get_zone_occupancy(zone_id)

        if snapshot is None:
            raise HTTPException(
                status_code=404,
                detail=f"Zone '{zone_id}' not found"
            )

        return {
            "zone_id": zone_id,
            "current_occupancy": snapshot["count"],
            # max_occupancy may be unset for unconstrained zones.
            "max_occupancy": snapshot.get("max_occupancy"),
            "persons": snapshot["persons"],
            "timestamp": snapshot["timestamp"],
        }

    except HTTPException:
        # Let the 404 pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting zone occupancy: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get zone occupancy: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/zones/summary")
async def get_zones_summary(
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Optional[Dict] = Depends(get_current_user)
):
    """Get occupancy summary for all zones."""
    try:
        summary = await pose_service.get_zones_summary()

        response = {"timestamp": datetime.utcnow()}
        # Copy through the summary fields the service is expected to supply.
        for key in ("total_persons", "zones", "active_zones"):
            response[key] = summary[key]
        return response

    except Exception as e:
        logger.error(f"Error getting zones summary: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get zones summary: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/historical")
async def get_historical_data(
    request: HistoricalDataRequest,
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Dict = Depends(require_auth)
):
    """Get historical pose estimation data.

    Rejects non-positive windows and windows longer than 7 days (400).
    """
    try:
        window = request.end_time - request.start_time

        # Guard: the window must be positive...
        if window <= timedelta(0):
            raise HTTPException(
                status_code=400,
                detail="End time must be after start time"
            )

        # ...and bounded, to keep result sets manageable.
        if window > timedelta(days=7):
            raise HTTPException(
                status_code=400,
                detail="Query range cannot exceed 7 days"
            )

        data = await pose_service.get_historical_data(
            start_time=request.start_time,
            end_time=request.end_time,
            zone_ids=request.zone_ids,
            aggregation_interval=request.aggregation_interval,
            include_raw_data=request.include_raw_data,
        )

        return {
            # Echo the query back so clients can correlate responses.
            "query": {
                "start_time": request.start_time,
                "end_time": request.end_time,
                "zone_ids": request.zone_ids,
                "aggregation_interval": request.aggregation_interval
            },
            "data": data["aggregated_data"],
            "raw_data": data.get("raw_data") if request.include_raw_data else None,
            "total_records": data["total_records"]
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting historical data: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get historical data: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/activities")
async def get_detected_activities(
    zone_id: Optional[str] = Query(None, description="Filter by zone ID"),
    limit: int = Query(10, ge=1, le=100, description="Maximum number of activities"),
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Optional[Dict] = Depends(get_current_user)
):
    """Get recently detected activities."""
    try:
        recent = await pose_service.get_recent_activities(
            zone_id=zone_id,
            limit=limit,
        )

        return {
            "activities": recent,
            "total_count": len(recent),
            # Echo the filter so clients can tell scoped from global results.
            "zone_id": zone_id,
        }

    except Exception as e:
        logger.error(f"Error getting activities: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get activities: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/calibrate")
async def calibrate_pose_system(
    background_tasks: BackgroundTasks,
    pose_service: PoseService = Depends(get_pose_service),
    hardware_service: HardwareService = Depends(get_hardware_service),
    current_user: Dict = Depends(require_auth)
):
    """Calibrate the pose estimation system.

    Returns 409 if a calibration is already running; otherwise starts a
    background calibration run and returns its id immediately.
    """
    try:
        logger.info(f"Pose system calibration initiated by user: {current_user['id']}")

        # Only one calibration run may be active at a time.
        if await pose_service.is_calibrating():
            raise HTTPException(
                status_code=409,
                detail="Calibration already in progress"
            )

        calibration_id = await pose_service.start_calibration()

        # The run itself happens off-request in the background.
        background_tasks.add_task(
            pose_service.run_calibration,
            calibration_id,
        )

        return {
            "calibration_id": calibration_id,
            "status": "started",
            "estimated_duration_minutes": 5,
            "message": "Calibration process started"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error starting calibration: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start calibration: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/calibration/status")
async def get_calibration_status(
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Dict = Depends(require_auth)
):
    """Get current calibration status."""
    try:
        state = await pose_service.get_calibration_status()

        # Only "is_calibrating" is mandatory; everything else is optional
        # and may be absent between runs.
        return {
            "is_calibrating": state["is_calibrating"],
            "calibration_id": state.get("calibration_id"),
            "progress_percent": state.get("progress_percent", 0),
            "current_step": state.get("current_step"),
            "estimated_remaining_minutes": state.get("estimated_remaining_minutes"),
            "last_calibration": state.get("last_calibration"),
        }

    except Exception as e:
        logger.error(f"Error getting calibration status: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get calibration status: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/stats")
async def get_pose_statistics(
    hours: int = Query(24, ge=1, le=168, description="Hours of data to analyze"),
    pose_service: PoseService = Depends(get_pose_service),
    current_user: Optional[Dict] = Depends(get_current_user)
):
    """Get pose estimation statistics over a trailing window of `hours`."""
    try:
        # Window ends "now" and reaches back the requested number of hours.
        end_time = datetime.utcnow()
        start_time = end_time - timedelta(hours=hours)

        stats = await pose_service.get_statistics(
            start_time=start_time,
            end_time=end_time,
        )

        return {
            "period": {
                "start_time": start_time,
                "end_time": end_time,
                "hours": hours
            },
            "statistics": stats,
        }

    except Exception as e:
        logger.error(f"Error getting statistics: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get statistics: {str(e)}"
        )
|
||||
468
v1/src/api/routers/stream.py
Normal file
468
v1/src/api/routers/stream.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
WebSocket streaming API endpoints
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.dependencies import (
|
||||
get_stream_service,
|
||||
get_pose_service,
|
||||
get_current_user_ws,
|
||||
require_auth
|
||||
)
|
||||
from src.api.websocket.connection_manager import ConnectionManager
|
||||
from src.services.stream_service import StreamService
|
||||
from src.services.pose_service import PoseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize connection manager
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class StreamSubscriptionRequest(BaseModel):
    """Request model for stream subscription.

    NOTE(review): stream_types values are not validated against a known
    set here — presumably checked downstream; confirm.
    """

    # None means "subscribe to all zones".
    zone_ids: Optional[List[str]] = Field(
        default=None,
        description="Zones to subscribe to (all zones if not specified)"
    )
    stream_types: List[str] = Field(
        default=["pose_data"],
        description="Types of data to stream"
    )
    min_confidence: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold for streaming"
    )
    # Server-side throttle cap for push frequency.
    max_fps: int = Field(
        default=30,
        ge=1,
        le=60,
        description="Maximum frames per second"
    )
    include_metadata: bool = Field(
        default=True,
        description="Include metadata in stream"
    )
|
||||
|
||||
|
||||
class StreamStatus(BaseModel):
    """Stream status model returned by GET /status."""

    is_active: bool = Field(..., description="Whether streaming is active")
    connected_clients: int = Field(..., description="Number of connected clients")
    streams: List[Dict[str, Any]] = Field(..., description="Active streams")
    uptime_seconds: float = Field(..., description="Stream uptime in seconds")
|
||||
|
||||
|
||||
# WebSocket endpoints
|
||||
@router.websocket("/pose")
async def websocket_pose_stream(
    websocket: WebSocket,
    zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
    min_confidence: float = Query(0.5, ge=0.0, le=1.0),
    max_fps: int = Query(30, ge=1, le=60),
    token: Optional[str] = Query(None, description="Authentication token")
):
    """WebSocket endpoint for real-time pose data streaming.

    Lifecycle: accept -> (optional) token presence check -> register with the
    connection manager -> confirmation frame -> message loop until disconnect.
    The connection manager does the actual pose-data pushes; this loop only
    services client-originated control messages.
    """
    client_id = None

    try:
        # Accept WebSocket connection
        await websocket.accept()

        # Check authentication if enabled
        # NOTE(review): only the *presence* of a token is checked here, not
        # its validity — confirm the token is verified elsewhere.
        from src.config.settings import get_settings
        settings = get_settings()

        if settings.enable_authentication and not token:
            await websocket.send_json({
                "type": "error",
                "message": "Authentication token required"
            })
            # 1008 = policy violation per RFC 6455.
            await websocket.close(code=1008)
            return

        # Parse zone IDs from the comma-separated query parameter.
        zone_list = None
        if zone_ids:
            zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]

        # Register client with connection manager
        client_id = await connection_manager.connect(
            websocket=websocket,
            stream_type="pose",
            zone_ids=zone_list,
            min_confidence=min_confidence,
            max_fps=max_fps
        )

        logger.info(f"WebSocket client {client_id} connected for pose streaming")

        # Send initial connection confirmation, echoing the effective config.
        await websocket.send_json({
            "type": "connection_established",
            "client_id": client_id,
            "timestamp": datetime.utcnow().isoformat(),
            "config": {
                "zone_ids": zone_list,
                "min_confidence": min_confidence,
                "max_fps": max_fps
            }
        })

        # Keep connection alive and handle incoming messages
        while True:
            try:
                # Wait for client messages (ping, config updates, etc.)
                message = await websocket.receive_text()
                data = json.loads(message)

                await handle_websocket_message(client_id, data, websocket)

            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                # Malformed frame: report and keep the connection open.
                await websocket.send_json({
                    "type": "error",
                    "message": "Invalid JSON format"
                })
            except Exception as e:
                # Unexpected handler failure: report generically, keep serving.
                logger.error(f"Error handling WebSocket message: {e}")
                await websocket.send_json({
                    "type": "error",
                    "message": "Internal server error"
                })

    except WebSocketDisconnect:
        logger.info(f"WebSocket client {client_id} disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        # Always unregister so the manager stops pushing to a dead socket.
        if client_id:
            await connection_manager.disconnect(client_id)
|
||||
|
||||
|
||||
@router.websocket("/events")
async def websocket_events_stream(
    websocket: WebSocket,
    event_types: Optional[str] = Query(None, description="Comma-separated event types"),
    zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
    token: Optional[str] = Query(None, description="Authentication token")
):
    """WebSocket endpoint for real-time event streaming.

    Mirrors /pose but registers with stream_type="events" and filters by
    event type instead of confidence/fps.
    """
    client_id = None

    try:
        await websocket.accept()

        # Check authentication if enabled
        # NOTE(review): as with /pose, only token presence is checked here.
        from src.config.settings import get_settings
        settings = get_settings()

        if settings.enable_authentication and not token:
            await websocket.send_json({
                "type": "error",
                "message": "Authentication token required"
            })
            await websocket.close(code=1008)
            return

        # Parse the comma-separated filter parameters.
        event_list = None
        if event_types:
            event_list = [event.strip() for event in event_types.split(",") if event.strip()]

        zone_list = None
        if zone_ids:
            zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]

        # Register client
        client_id = await connection_manager.connect(
            websocket=websocket,
            stream_type="events",
            zone_ids=zone_list,
            event_types=event_list
        )

        logger.info(f"WebSocket client {client_id} connected for event streaming")

        # Send confirmation, echoing the effective filters.
        await websocket.send_json({
            "type": "connection_established",
            "client_id": client_id,
            "timestamp": datetime.utcnow().isoformat(),
            "config": {
                "event_types": event_list,
                "zone_ids": zone_list
            }
        })

        # Handle messages until the client disconnects.
        # NOTE(review): unlike /pose, errors here are only logged — no error
        # frame is sent back, and malformed JSON is not reported separately.
        while True:
            try:
                message = await websocket.receive_text()
                data = json.loads(message)
                await handle_websocket_message(client_id, data, websocket)
            except WebSocketDisconnect:
                break
            except Exception as e:
                logger.error(f"Error in events WebSocket: {e}")

    except WebSocketDisconnect:
        logger.info(f"Events WebSocket client {client_id} disconnected")
    except Exception as e:
        logger.error(f"Events WebSocket error: {e}")
    finally:
        # Always unregister the client on the way out.
        if client_id:
            await connection_manager.disconnect(client_id)
|
||||
|
||||
|
||||
async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket: WebSocket):
    """Handle incoming WebSocket messages.

    Dispatches on data["type"]: "ping", "update_config", "get_status";
    anything else gets an error frame back.
    """
    message_type = data.get("type")

    if message_type == "ping":
        # Lightweight keepalive round-trip.
        await websocket.send_json({
            "type": "pong",
            "timestamp": datetime.utcnow().isoformat()
        })
        return

    if message_type == "update_config":
        # Apply the new per-client config, then echo it back as confirmation.
        config = data.get("config", {})
        await connection_manager.update_client_config(client_id, config)
        await websocket.send_json({
            "type": "config_updated",
            "timestamp": datetime.utcnow().isoformat(),
            "config": config
        })
        return

    if message_type == "get_status":
        status = await connection_manager.get_client_status(client_id)
        await websocket.send_json({
            "type": "status",
            "timestamp": datetime.utcnow().isoformat(),
            "status": status
        })
        return

    # Unknown message type: report it but keep the connection open.
    await websocket.send_json({
        "type": "error",
        "message": f"Unknown message type: {message_type}"
    })
|
||||
|
||||
|
||||
# HTTP endpoints for stream management
|
||||
@router.get("/status", response_model=StreamStatus)
async def get_stream_status(
    stream_service: StreamService = Depends(get_stream_service)
):
    """Get current streaming status.

    Combines the stream service's own status dict with the connection
    manager's client counts. Missing keys in either source degrade to
    zeros instead of failing the request.
    """
    try:
        status = await stream_service.get_status()
        connections = await connection_manager.get_connection_stats()

        is_running = status.get("running", False)

        # Calculate uptime (simplified for now)
        uptime_seconds = 3600.0 if is_running else 0.0  # Default 1 hour for demo

        # BUG FIX: the previous code wrote
        #   connections.get("total_clients", status["connections"]["active"])
        # which evaluates the subscript default eagerly — a missing
        # "connections" key raised KeyError (-> 500) even when the
        # connection manager had already supplied the count. Fall back
        # lazily and defensively instead.
        client_count = connections.get("total_clients")
        if client_count is None:
            client_count = status.get("connections", {}).get("active", 0)

        return StreamStatus(
            is_active=is_running,
            connected_clients=client_count,
            streams=[{
                "type": "pose_stream",
                "active": is_running,
                # Same defensive lookup for the buffer depth.
                "buffer_size": status.get("buffers", {}).get("pose_buffer_size", 0)
            }],
            uptime_seconds=uptime_seconds
        )

    except Exception as e:
        logger.error(f"Error getting stream status: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get stream status: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/start")
async def start_streaming(
    stream_service: StreamService = Depends(get_stream_service),
    current_user: Dict = Depends(require_auth)
):
    """Start the streaming service.

    Idempotent: a second start request is acknowledged, not treated as an
    error.
    """
    try:
        logger.info(f"Starting streaming service by user: {current_user['id']}")

        already_active = await stream_service.is_active()
        if already_active:
            return JSONResponse(
                status_code=200,
                content={"message": "Streaming service is already active"}
            )

        await stream_service.start()

        return {
            "message": "Streaming service started successfully",
            "timestamp": datetime.utcnow().isoformat()
        }

    except Exception as e:
        logger.error(f"Error starting streaming: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start streaming: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/stop")
async def stop_streaming(
    stream_service: StreamService = Depends(get_stream_service),
    current_user: Dict = Depends(require_auth)
):
    """Stop the streaming service and drop every connected client."""
    try:
        logger.info(f"Stopping streaming service by user: {current_user['id']}")

        # Stop the producer first, then tear down all consumer sockets.
        await stream_service.stop()
        await connection_manager.disconnect_all()

        return {
            "message": "Streaming service stopped successfully",
            "timestamp": datetime.utcnow().isoformat()
        }

    except Exception as e:
        logger.error(f"Error stopping streaming: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to stop streaming: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/clients")
async def get_connected_clients(
    current_user: Dict = Depends(require_auth)
):
    """Get list of connected WebSocket clients."""
    try:
        roster = await connection_manager.get_connected_clients()

        return {
            "total_clients": len(roster),
            "clients": roster,
            "timestamp": datetime.utcnow().isoformat()
        }

    except Exception as e:
        logger.error(f"Error getting connected clients: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get connected clients: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.delete("/clients/{client_id}")
async def disconnect_client(
    client_id: str,
    current_user: Dict = Depends(require_auth)
):
    """Disconnect a specific WebSocket client.

    Returns 404 when the client id is unknown to the connection manager.
    """
    try:
        logger.info(f"Disconnecting client {client_id} by user: {current_user['id']}")

        disconnected = await connection_manager.disconnect(client_id)
        if not disconnected:
            raise HTTPException(
                status_code=404,
                detail=f"Client {client_id} not found"
            )

        return {
            "message": f"Client {client_id} disconnected successfully",
            "timestamp": datetime.utcnow().isoformat()
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error disconnecting client: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to disconnect client: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/broadcast")
async def broadcast_message(
    message: Dict[str, Any],
    stream_type: Optional[str] = Query(None, description="Target stream type"),
    zone_ids: Optional[List[str]] = Query(None, description="Target zone IDs"),
    current_user: Dict = Depends(require_auth)
):
    """Broadcast a message to connected WebSocket clients.

    The payload is stamped with sender id and timestamp so receivers can
    attribute and order it.
    """
    try:
        logger.info(f"Broadcasting message by user: {current_user['id']}")

        payload = dict(message)
        payload["broadcast_timestamp"] = datetime.utcnow().isoformat()
        payload["sender"] = current_user["id"]

        # Delivery is scoped by stream type / zones; None means "all".
        recipients = await connection_manager.broadcast(
            data=payload,
            stream_type=stream_type,
            zone_ids=zone_ids
        )

        return {
            "message": "Broadcast sent successfully",
            "recipients": recipients,
            "timestamp": datetime.utcnow().isoformat()
        }

    except Exception as e:
        logger.error(f"Error broadcasting message: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to broadcast message: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.get("/metrics")
async def get_streaming_metrics():
    """Expose connection-manager performance metrics (no auth required)."""
    try:
        return {
            "metrics": await connection_manager.get_metrics(),
            "timestamp": datetime.utcnow().isoformat(),
        }
    except Exception as e:
        logger.error(f"Error getting streaming metrics: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get streaming metrics: {str(e)}"
        )
|
||||
8
v1/src/api/websocket/__init__.py
Normal file
8
v1/src/api/websocket/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
WebSocket handlers package
|
||||
"""
|
||||
|
||||
from .connection_manager import ConnectionManager
|
||||
from .pose_stream import PoseStreamHandler
|
||||
|
||||
__all__ = ["ConnectionManager", "PoseStreamHandler"]
|
||||
461
v1/src/api/websocket/connection_manager.py
Normal file
461
v1/src/api/websocket/connection_manager.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
WebSocket connection manager for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketConnection:
    """A single WebSocket client plus the metadata needed to route messages to it.

    Tracks the stream type and zone subscriptions used for broadcast
    filtering, along with liveness bookkeeping (last ping time, message
    count, active flag).
    """

    def __init__(
        self,
        websocket: "WebSocket",
        client_id: str,
        stream_type: str,
        zone_ids: Optional[List[str]] = None,
        **config
    ):
        self.websocket = websocket
        self.client_id = client_id
        self.stream_type = stream_type
        # An empty zone list means "subscribed to all zones".
        self.zone_ids = zone_ids or []
        self.config = config
        self.connected_at = datetime.utcnow()
        self.last_ping = datetime.utcnow()
        self.message_count = 0
        self.is_active = True

    async def send_json(self, data: Dict[str, Any]):
        """Send a JSON payload; mark the connection inactive and re-raise on failure."""
        try:
            await self.websocket.send_json(data)
            self.message_count += 1
        except Exception as e:
            logger.error(f"Error sending to client {self.client_id}: {e}")
            self.is_active = False
            raise

    async def send_text(self, message: str):
        """Send a raw text frame; mark the connection inactive and re-raise on failure."""
        try:
            await self.websocket.send_text(message)
            self.message_count += 1
        except Exception as e:
            logger.error(f"Error sending text to client {self.client_id}: {e}")
            self.is_active = False
            raise

    def update_config(self, config: Dict[str, Any]):
        """Merge new settings into the connection config, syncing zone_ids."""
        self.config.update(config)

        # Keep the dedicated zone list in step with the config dict.
        if "zone_ids" in config:
            self.zone_ids = config["zone_ids"] or []

    def matches_filter(
        self,
        stream_type: Optional[str] = None,
        zone_ids: Optional[List[str]] = None,
        **filters
    ) -> bool:
        """Return True when this connection matches ALL of the given filters.

        A connection with an empty zone list listens to every zone, but it
        must still satisfy any additional config filters.
        """
        # Check stream type
        if stream_type and self.stream_type != stream_type:
            return False

        # Zone check: an empty self.zone_ids means "all zones", so only
        # reject when the connection has explicit zones and none overlap.
        # (Bug fix: previously an all-zones connection returned True here
        # immediately, bypassing the additional config filters below.)
        if zone_ids and self.zone_ids:
            if not any(zone in self.zone_ids for zone in zone_ids):
                return False

        # Any extra keyword filters must match the connection's config.
        for key, value in filters.items():
            if key in self.config and self.config[key] != value:
                return False

        return True

    def get_info(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of this connection's state."""
        return {
            "client_id": self.client_id,
            "stream_type": self.stream_type,
            "zone_ids": self.zone_ids,
            "config": self.config,
            "connected_at": self.connected_at.isoformat(),
            "last_ping": self.last_ping.isoformat(),
            "message_count": self.message_count,
            "is_active": self.is_active,
            "uptime_seconds": (datetime.utcnow() - self.connected_at).total_seconds()
        }
|
||||
|
||||
|
||||
class ConnectionManager:
    """Registry and broadcast router for real-time WebSocket connections.

    Connections are indexed three ways — by client id, by stream type, and
    by zone — so broadcasts can find their audience without scanning every
    client. A background task periodically pings clients and evicts stale
    or inactive connections.
    """

    def __init__(self):
        # Primary store: client_id -> WebSocketConnection.
        self.connections: Dict[str, WebSocketConnection] = {}
        # Secondary indexes used for fast broadcast filtering.
        self.connections_by_type: Dict[str, Set[str]] = defaultdict(set)
        self.connections_by_zone: Dict[str, Set[str]] = defaultdict(set)
        self.metrics = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_sent": 0,
            "errors": 0,
            "start_time": datetime.utcnow()
        }
        self._cleanup_task = None
        self._started = False

    async def connect(
        self,
        websocket: "WebSocket",
        stream_type: str,
        zone_ids: Optional[List[str]] = None,
        **config
    ) -> str:
        """Register a new WebSocket connection and return its generated client id."""
        client_id = str(uuid.uuid4())

        try:
            connection = WebSocketConnection(
                websocket=websocket,
                client_id=client_id,
                stream_type=stream_type,
                zone_ids=zone_ids,
                **config
            )

            # Store the connection and add it to the lookup indexes.
            self.connections[client_id] = connection
            self.connections_by_type[stream_type].add(client_id)
            if zone_ids:
                for zone_id in zone_ids:
                    self.connections_by_zone[zone_id].add(client_id)

            self.metrics["total_connections"] += 1
            self.metrics["active_connections"] = len(self.connections)

            logger.info(f"WebSocket client {client_id} connected for {stream_type}")
            return client_id

        except Exception as e:
            logger.error(f"Error connecting WebSocket client: {e}")
            raise

    async def disconnect(self, client_id: str) -> bool:
        """Remove a client and close its socket; return False if unknown."""
        if client_id not in self.connections:
            return False

        try:
            connection = self.connections[client_id]

            # Drop the client from the secondary indexes first.
            self.connections_by_type[connection.stream_type].discard(client_id)
            for zone_id in connection.zone_ids:
                self.connections_by_zone[zone_id].discard(client_id)

            # Close the socket if it still looks alive.
            if connection.is_active:
                try:
                    await connection.websocket.close()
                except Exception:
                    # Connection might already be closed.
                    # (Fix: was a bare `except:` which also swallowed
                    # BaseException such as CancelledError.)
                    pass

            del self.connections[client_id]
            self.metrics["active_connections"] = len(self.connections)

            logger.info(f"WebSocket client {client_id} disconnected")
            return True

        except Exception as e:
            logger.error(f"Error disconnecting client {client_id}: {e}")
            return False

    async def disconnect_all(self):
        """Disconnect every registered client (used during shutdown)."""
        for client_id in list(self.connections.keys()):
            await self.disconnect(client_id)

        logger.info("All WebSocket clients disconnected")

    async def send_to_client(self, client_id: str, data: Dict[str, Any]) -> bool:
        """Send a JSON payload to one client; return False on failure or unknown id."""
        if client_id not in self.connections:
            return False

        connection = self.connections[client_id]

        try:
            await connection.send_json(data)
            self.metrics["messages_sent"] += 1
            return True

        except Exception as e:
            logger.error(f"Error sending to client {client_id}: {e}")
            self.metrics["errors"] += 1

            # Mark connection as inactive; the cleanup task will evict it.
            connection.is_active = False
            return False

    async def broadcast(
        self,
        data: Dict[str, Any],
        stream_type: Optional[str] = None,
        zone_ids: Optional[List[str]] = None,
        **filters
    ) -> int:
        """Send data to every client matching the filters; return the count sent."""
        sent_count = 0
        failed_clients = []

        matching_clients = self._get_matching_clients(
            stream_type=stream_type,
            zone_ids=zone_ids,
            **filters
        )

        for client_id in matching_clients:
            try:
                if await self.send_to_client(client_id, data):
                    sent_count += 1
                else:
                    failed_clients.append(client_id)
            except Exception as e:
                logger.error(f"Error broadcasting to client {client_id}: {e}")
                failed_clients.append(client_id)

        # Evict clients whose sends failed.
        for client_id in failed_clients:
            await self.disconnect(client_id)

        return sent_count

    async def update_client_config(self, client_id: str, config: Dict[str, Any]) -> bool:
        """Apply new config to a client, keeping the zone index in sync."""
        if client_id not in self.connections:
            return False

        connection = self.connections[client_id]
        old_zones = set(connection.zone_ids)

        connection.update_config(config)
        new_zones = set(connection.zone_ids)

        # Re-index only the zones that actually changed.
        for zone_id in old_zones - new_zones:
            self.connections_by_zone[zone_id].discard(client_id)
        for zone_id in new_zones - old_zones:
            self.connections_by_zone[zone_id].add(client_id)

        return True

    async def get_client_status(self, client_id: str) -> Optional[Dict[str, Any]]:
        """Return one client's info dict, or None if unknown."""
        if client_id not in self.connections:
            return None
        return self.connections[client_id].get_info()

    async def get_connected_clients(self) -> List[Dict[str, Any]]:
        """Return info dicts for every connected client."""
        return [conn.get_info() for conn in self.connections.values()]

    async def get_connection_stats(self) -> Dict[str, Any]:
        """Return aggregate counts by type, zone, and activity state."""
        return {
            "total_clients": len(self.connections),
            "clients_by_type": {
                stream_type: len(clients)
                for stream_type, clients in self.connections_by_type.items()
            },
            "clients_by_zone": {
                zone_id: len(clients)
                for zone_id, clients in self.connections_by_zone.items()
                if clients  # Only include zones with active clients
            },
            "active_clients": sum(1 for conn in self.connections.values() if conn.is_active),
            "inactive_clients": sum(1 for conn in self.connections.values() if not conn.is_active)
        }

    async def get_metrics(self) -> Dict[str, Any]:
        """Return raw counters plus derived throughput and error rates."""
        uptime = (datetime.utcnow() - self.metrics["start_time"]).total_seconds()

        return {
            **self.metrics,
            "active_connections": len(self.connections),
            "uptime_seconds": uptime,
            # Guard against division by zero early in the process lifetime.
            "messages_per_second": self.metrics["messages_sent"] / max(uptime, 1),
            "error_rate": self.metrics["errors"] / max(self.metrics["messages_sent"], 1)
        }

    def _get_matching_clients(
        self,
        stream_type: Optional[str] = None,
        zone_ids: Optional[List[str]] = None,
        **filters
    ) -> List[str]:
        """Resolve the list of active client ids matching all given filters."""
        candidates = set(self.connections.keys())

        # Narrow by stream type via the index.
        if stream_type:
            candidates &= self.connections_by_type.get(stream_type, set())

        # Narrow by zone via the index; clients with no zone list subscribe
        # to every zone and are always included.
        if zone_ids:
            zone_clients = set()
            for zone_id in zone_ids:
                zone_clients.update(self.connections_by_zone.get(zone_id, set()))
            zone_clients.update(
                client_id for client_id, conn in self.connections.items()
                if not conn.zone_ids
            )
            candidates &= zone_clients

        # Apply remaining per-connection filters.
        matching_clients = []
        for client_id in candidates:
            connection = self.connections[client_id]
            if connection.is_active and connection.matches_filter(**filters):
                matching_clients.append(client_id)

        return matching_clients

    async def ping_clients(self):
        """Send a ping frame to every client; evict those that fail."""
        ping_data = {
            "type": "ping",
            "timestamp": datetime.utcnow().isoformat()
        }

        failed_clients = []

        # Iterate over a snapshot: the awaits inside the loop allow other
        # coroutines to mutate self.connections mid-iteration, which would
        # raise RuntimeError on a live dict view. (Fix.)
        for client_id, connection in list(self.connections.items()):
            try:
                await connection.send_json(ping_data)
                connection.last_ping = datetime.utcnow()
            except Exception as e:
                logger.warning(f"Ping failed for client {client_id}: {e}")
                failed_clients.append(client_id)

        for client_id in failed_clients:
            await self.disconnect(client_id)

    async def cleanup_inactive_connections(self):
        """Evict connections flagged inactive or silent for over five minutes."""
        now = datetime.utcnow()
        stale_threshold = timedelta(minutes=5)  # 5 minutes without ping

        stale_clients = []

        # Snapshot for the same reason as in ping_clients. (Fix.)
        for client_id, connection in list(self.connections.items()):
            if not connection.is_active:
                stale_clients.append(client_id)
                continue

            if now - connection.last_ping > stale_threshold:
                logger.warning(f"Client {client_id} appears stale, disconnecting")
                stale_clients.append(client_id)

        for client_id in stale_clients:
            await self.disconnect(client_id)

        if stale_clients:
            logger.info(f"Cleaned up {len(stale_clients)} stale connections")

    async def start(self):
        """Start the background cleanup task (idempotent)."""
        if not self._started:
            self._start_cleanup_task()
            self._started = True
            logger.info("Connection manager started")

    def _start_cleanup_task(self):
        """Spawn the periodic cleanup/ping loop if an event loop is running."""
        async def cleanup_loop():
            while True:
                try:
                    await asyncio.sleep(60)  # Run every minute
                    await self.cleanup_inactive_connections()

                    # Send periodic ping roughly every other minute.
                    if datetime.utcnow().minute % 2 == 0:
                        await self.ping_clients()

                except Exception as e:
                    logger.error(f"Error in cleanup task: {e}")

        try:
            self._cleanup_task = asyncio.create_task(cleanup_loop())
        except RuntimeError:
            # No event loop running, will start later
            logger.debug("No event loop running, cleanup task will start later")

    async def shutdown(self):
        """Cancel the cleanup task and disconnect every client."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        await self.disconnect_all()

        logger.info("Connection manager shutdown complete")


# Global connection manager instance shared by the API routers.
connection_manager = ConnectionManager()
|
||||
384
v1/src/api/websocket/pose_stream.py
Normal file
384
v1/src/api/websocket/pose_stream.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Pose streaming WebSocket handler
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import WebSocket
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.api.websocket.connection_manager import ConnectionManager
|
||||
from src.services.pose_service import PoseService
|
||||
from src.services.stream_service import StreamService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoseStreamData(BaseModel):
    """Validated payload for one zone's pose estimate pushed over the stream.

    Confidence is constrained to the [0, 1] range by the field validator;
    activity and metadata are optional and omitted when unknown.
    """

    timestamp: datetime = Field(..., description="Data timestamp")
    zone_id: str = Field(..., description="Zone identifier")
    pose_data: Dict[str, Any] = Field(..., description="Pose estimation data")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    activity: Optional[str] = Field(default=None, description="Detected activity")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
|
||||
|
||||
|
||||
class PoseStreamHandler:
|
||||
"""Handles pose data streaming to WebSocket clients."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_manager: ConnectionManager,
|
||||
pose_service: PoseService,
|
||||
stream_service: StreamService
|
||||
):
|
||||
self.connection_manager = connection_manager
|
||||
self.pose_service = pose_service
|
||||
self.stream_service = stream_service
|
||||
self.is_streaming = False
|
||||
self.stream_task = None
|
||||
self.subscribers = {}
|
||||
self.stream_config = {
|
||||
"fps": 30,
|
||||
"min_confidence": 0.5,
|
||||
"include_metadata": True,
|
||||
"buffer_size": 100
|
||||
}
|
||||
|
||||
async def start_streaming(self):
|
||||
"""Start pose data streaming."""
|
||||
if self.is_streaming:
|
||||
logger.warning("Pose streaming already active")
|
||||
return
|
||||
|
||||
self.is_streaming = True
|
||||
self.stream_task = asyncio.create_task(self._stream_loop())
|
||||
logger.info("Pose streaming started")
|
||||
|
||||
async def stop_streaming(self):
|
||||
"""Stop pose data streaming."""
|
||||
if not self.is_streaming:
|
||||
return
|
||||
|
||||
self.is_streaming = False
|
||||
|
||||
if self.stream_task:
|
||||
self.stream_task.cancel()
|
||||
try:
|
||||
await self.stream_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Pose streaming stopped")
|
||||
|
||||
async def _stream_loop(self):
|
||||
"""Main streaming loop."""
|
||||
try:
|
||||
logger.info("🚀 Starting pose streaming loop")
|
||||
while self.is_streaming:
|
||||
try:
|
||||
# Get current pose data from all zones
|
||||
logger.debug("📡 Getting current pose data...")
|
||||
pose_data = await self.pose_service.get_current_pose_data()
|
||||
logger.debug(f"📊 Received pose data: {pose_data}")
|
||||
|
||||
if pose_data:
|
||||
logger.debug("📤 Broadcasting pose data...")
|
||||
await self._process_and_broadcast_pose_data(pose_data)
|
||||
else:
|
||||
logger.debug("⚠️ No pose data received")
|
||||
|
||||
# Control streaming rate
|
||||
await asyncio.sleep(1.0 / self.stream_config["fps"])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pose streaming loop: {e}")
|
||||
await asyncio.sleep(1.0) # Brief pause on error
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Pose streaming loop cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in pose streaming loop: {e}")
|
||||
finally:
|
||||
logger.info("🛑 Pose streaming loop stopped")
|
||||
self.is_streaming = False
|
||||
|
||||
async def _process_and_broadcast_pose_data(self, raw_pose_data: Dict[str, Any]):
|
||||
"""Process and broadcast pose data to subscribers."""
|
||||
try:
|
||||
# Process data for each zone
|
||||
for zone_id, zone_data in raw_pose_data.items():
|
||||
if not zone_data:
|
||||
continue
|
||||
|
||||
# Create structured pose data
|
||||
pose_stream_data = PoseStreamData(
|
||||
timestamp=datetime.utcnow(),
|
||||
zone_id=zone_id,
|
||||
pose_data=zone_data.get("pose", {}),
|
||||
confidence=zone_data.get("confidence", 0.0),
|
||||
activity=zone_data.get("activity"),
|
||||
metadata=zone_data.get("metadata") if self.stream_config["include_metadata"] else None
|
||||
)
|
||||
|
||||
# Filter by minimum confidence
|
||||
if pose_stream_data.confidence < self.stream_config["min_confidence"]:
|
||||
continue
|
||||
|
||||
# Broadcast to subscribers
|
||||
await self._broadcast_pose_data(pose_stream_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing pose data: {e}")
|
||||
|
||||
async def _broadcast_pose_data(self, pose_data: PoseStreamData):
|
||||
"""Broadcast pose data to matching WebSocket clients."""
|
||||
try:
|
||||
logger.debug(f"📡 Preparing to broadcast pose data for zone {pose_data.zone_id}")
|
||||
|
||||
# Prepare broadcast data
|
||||
broadcast_data = {
|
||||
"type": "pose_data",
|
||||
"timestamp": pose_data.timestamp.isoformat(),
|
||||
"zone_id": pose_data.zone_id,
|
||||
"data": {
|
||||
"pose": pose_data.pose_data,
|
||||
"confidence": pose_data.confidence,
|
||||
"activity": pose_data.activity
|
||||
}
|
||||
}
|
||||
|
||||
# Add metadata if enabled
|
||||
if pose_data.metadata and self.stream_config["include_metadata"]:
|
||||
broadcast_data["metadata"] = pose_data.metadata
|
||||
|
||||
logger.debug(f"📤 Broadcasting data: {broadcast_data}")
|
||||
|
||||
# Broadcast to pose stream subscribers
|
||||
sent_count = await self.connection_manager.broadcast(
|
||||
data=broadcast_data,
|
||||
stream_type="pose",
|
||||
zone_ids=[pose_data.zone_id]
|
||||
)
|
||||
|
||||
logger.info(f"✅ Broadcasted pose data for zone {pose_data.zone_id} to {sent_count} clients")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting pose data: {e}")
|
||||
|
||||
async def handle_client_subscription(
|
||||
self,
|
||||
client_id: str,
|
||||
subscription_config: Dict[str, Any]
|
||||
):
|
||||
"""Handle client subscription configuration."""
|
||||
try:
|
||||
# Store client subscription config
|
||||
self.subscribers[client_id] = {
|
||||
"zone_ids": subscription_config.get("zone_ids", []),
|
||||
"min_confidence": subscription_config.get("min_confidence", 0.5),
|
||||
"max_fps": subscription_config.get("max_fps", 30),
|
||||
"include_metadata": subscription_config.get("include_metadata", True),
|
||||
"stream_types": subscription_config.get("stream_types", ["pose_data"]),
|
||||
"subscribed_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
logger.info(f"Updated subscription for client {client_id}")
|
||||
|
||||
# Send confirmation
|
||||
confirmation = {
|
||||
"type": "subscription_updated",
|
||||
"client_id": client_id,
|
||||
"config": self.subscribers[client_id],
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, confirmation)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling client subscription: {e}")
|
||||
|
||||
async def handle_client_disconnect(self, client_id: str):
|
||||
"""Handle client disconnection."""
|
||||
if client_id in self.subscribers:
|
||||
del self.subscribers[client_id]
|
||||
logger.info(f"Removed subscription for disconnected client {client_id}")
|
||||
|
||||
async def send_historical_data(
|
||||
self,
|
||||
client_id: str,
|
||||
zone_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
limit: int = 100
|
||||
):
|
||||
"""Send historical pose data to client."""
|
||||
try:
|
||||
# Get historical data from pose service
|
||||
historical_data = await self.pose_service.get_historical_data(
|
||||
zone_id=zone_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Send data in chunks to avoid overwhelming the client
|
||||
chunk_size = 10
|
||||
for i in range(0, len(historical_data), chunk_size):
|
||||
chunk = historical_data[i:i + chunk_size]
|
||||
|
||||
message = {
|
||||
"type": "historical_data",
|
||||
"zone_id": zone_id,
|
||||
"chunk_index": i // chunk_size,
|
||||
"total_chunks": (len(historical_data) + chunk_size - 1) // chunk_size,
|
||||
"data": chunk,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, message)
|
||||
|
||||
# Small delay between chunks
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Send completion message
|
||||
completion_message = {
|
||||
"type": "historical_data_complete",
|
||||
"zone_id": zone_id,
|
||||
"total_records": len(historical_data),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, completion_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending historical data: {e}")
|
||||
|
||||
# Send error message to client
|
||||
error_message = {
|
||||
"type": "error",
|
||||
"message": f"Failed to retrieve historical data: {str(e)}",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, error_message)
|
||||
|
||||
async def send_zone_statistics(self, client_id: str, zone_id: str):
|
||||
"""Send zone statistics to client."""
|
||||
try:
|
||||
# Get zone statistics
|
||||
stats = await self.pose_service.get_zone_statistics(zone_id)
|
||||
|
||||
message = {
|
||||
"type": "zone_statistics",
|
||||
"zone_id": zone_id,
|
||||
"statistics": stats,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self.connection_manager.send_to_client(client_id, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending zone statistics: {e}")
|
||||
|
||||
async def broadcast_system_event(self, event_type: str, event_data: Dict[str, Any]):
|
||||
"""Broadcast system events to all connected clients."""
|
||||
try:
|
||||
message = {
|
||||
"type": "system_event",
|
||||
"event_type": event_type,
|
||||
"data": event_data,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Broadcast to all pose stream clients
|
||||
sent_count = await self.connection_manager.broadcast(
|
||||
data=message,
|
||||
stream_type="pose"
|
||||
)
|
||||
|
||||
logger.info(f"Broadcasted system event '{event_type}' to {sent_count} clients")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting system event: {e}")
|
||||
|
||||
async def update_stream_config(self, config: Dict[str, Any]):
|
||||
"""Update streaming configuration."""
|
||||
try:
|
||||
# Validate and update configuration
|
||||
if "fps" in config:
|
||||
fps = max(1, min(60, config["fps"]))
|
||||
self.stream_config["fps"] = fps
|
||||
|
||||
if "min_confidence" in config:
|
||||
confidence = max(0.0, min(1.0, config["min_confidence"]))
|
||||
self.stream_config["min_confidence"] = confidence
|
||||
|
||||
if "include_metadata" in config:
|
||||
self.stream_config["include_metadata"] = bool(config["include_metadata"])
|
||||
|
||||
if "buffer_size" in config:
|
||||
buffer_size = max(10, min(1000, config["buffer_size"]))
|
||||
self.stream_config["buffer_size"] = buffer_size
|
||||
|
||||
logger.info(f"Updated stream configuration: {self.stream_config}")
|
||||
|
||||
# Broadcast configuration update to clients
|
||||
await self.broadcast_system_event("stream_config_updated", {
|
||||
"new_config": self.stream_config
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating stream configuration: {e}")
|
||||
|
||||
def get_stream_status(self) -> Dict[str, Any]:
|
||||
"""Get current streaming status."""
|
||||
return {
|
||||
"is_streaming": self.is_streaming,
|
||||
"config": self.stream_config,
|
||||
"subscriber_count": len(self.subscribers),
|
||||
"subscribers": {
|
||||
client_id: {
|
||||
"zone_ids": sub["zone_ids"],
|
||||
"min_confidence": sub["min_confidence"],
|
||||
"subscribed_at": sub["subscribed_at"].isoformat()
|
||||
}
|
||||
for client_id, sub in self.subscribers.items()
|
||||
}
|
||||
}
|
||||
|
||||
async def get_performance_metrics(self) -> Dict[str, Any]:
|
||||
"""Get streaming performance metrics."""
|
||||
try:
|
||||
# Get connection manager metrics
|
||||
conn_metrics = await self.connection_manager.get_metrics()
|
||||
|
||||
# Get pose service metrics
|
||||
pose_metrics = await self.pose_service.get_performance_metrics()
|
||||
|
||||
return {
|
||||
"streaming": {
|
||||
"is_active": self.is_streaming,
|
||||
"fps": self.stream_config["fps"],
|
||||
"subscriber_count": len(self.subscribers)
|
||||
},
|
||||
"connections": conn_metrics,
|
||||
"pose_service": pose_metrics,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting performance metrics: {e}")
|
||||
return {}
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown pose stream handler."""
|
||||
await self.stop_streaming()
|
||||
self.subscribers.clear()
|
||||
logger.info("Pose stream handler shutdown complete")
|
||||
Reference in New Issue
Block a user