feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

7
v1/src/api/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
WiFi-DensePose FastAPI application package
"""
# API package - routers and dependencies are imported by app.py
__all__ = []

447
v1/src/api/dependencies.py Normal file
View File

@@ -0,0 +1,447 @@
"""
Dependency injection for WiFi-DensePose API
"""
import logging
from typing import Optional, Dict, Any
from functools import lru_cache
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from src.config.settings import get_settings
from src.config.domains import get_domain_config
from src.services.pose_service import PoseService
from src.services.stream_service import StreamService
from src.services.hardware_service import HardwareService
logger = logging.getLogger(__name__)
# Security scheme for JWT authentication
security = HTTPBearer(auto_error=False)
# Service dependencies
@lru_cache()
def get_pose_service() -> PoseService:
"""Get pose service instance."""
settings = get_settings()
domain_config = get_domain_config()
return PoseService(
settings=settings,
domain_config=domain_config
)
@lru_cache()
def get_stream_service() -> StreamService:
"""Get stream service instance."""
settings = get_settings()
domain_config = get_domain_config()
return StreamService(
settings=settings,
domain_config=domain_config
)
@lru_cache()
def get_hardware_service() -> HardwareService:
"""Get hardware service instance."""
settings = get_settings()
domain_config = get_domain_config()
return HardwareService(
settings=settings,
domain_config=domain_config
)
# Authentication dependencies
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict[str, Any]]:
"""Get current authenticated user."""
settings = get_settings()
# Skip authentication if disabled
if not settings.enable_authentication:
return None
# Check if user is already set by middleware
if hasattr(request.state, 'user') and request.state.user:
return request.state.user
# No credentials provided
if not credentials:
return None
# This would normally validate the JWT token
# For now, return a mock user for development
if settings.is_development:
return {
"id": "dev-user",
"username": "developer",
"email": "dev@example.com",
"is_admin": True,
"permissions": ["read", "write", "admin"]
}
# In production, implement proper JWT validation
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication not implemented",
headers={"WWW-Authenticate": "Bearer"},
)
async def get_current_active_user(
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get current active user (required authentication)."""
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"},
)
# Check if user is active
if not current_user.get("is_active", True):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
)
return current_user
async def get_admin_user(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Get current admin user (admin privileges required)."""
if not current_user.get("is_admin", False):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin privileges required"
)
return current_user
# Permission dependencies
def require_permission(permission: str):
"""Dependency factory for permission checking."""
async def check_permission(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Check if user has required permission."""
user_permissions = current_user.get("permissions", [])
# Admin users have all permissions
if current_user.get("is_admin", False):
return current_user
# Check specific permission
if permission not in user_permissions:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{permission}' required"
)
return current_user
return check_permission
# Zone access dependencies
async def validate_zone_access(
zone_id: str,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> str:
"""Validate user access to a specific zone."""
domain_config = get_domain_config()
# Check if zone exists
zone = domain_config.get_zone(zone_id)
if not zone:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Zone '{zone_id}' not found"
)
# Check if zone is enabled
if not zone.enabled:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Zone '{zone_id}' is disabled"
)
# If authentication is enabled, check user access
if current_user:
# Admin users have access to all zones
if current_user.get("is_admin", False):
return zone_id
# Check user's zone permissions
user_zones = current_user.get("zones", [])
if user_zones and zone_id not in user_zones:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to zone '{zone_id}'"
)
return zone_id
# Router access dependencies
async def validate_router_access(
router_id: str,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> str:
"""Validate user access to a specific router."""
domain_config = get_domain_config()
# Check if router exists
router = domain_config.get_router(router_id)
if not router:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Router '{router_id}' not found"
)
# Check if router is enabled
if not router.enabled:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Router '{router_id}' is disabled"
)
# If authentication is enabled, check user access
if current_user:
# Admin users have access to all routers
if current_user.get("is_admin", False):
return router_id
# Check user's router permissions
user_routers = current_user.get("routers", [])
if user_routers and router_id not in user_routers:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Access denied to router '{router_id}'"
)
return router_id
# Service health dependencies
async def check_service_health(
request: Request,
service_name: str
) -> bool:
"""Check if a service is healthy."""
try:
if service_name == "pose":
service = getattr(request.app.state, 'pose_service', None)
elif service_name == "stream":
service = getattr(request.app.state, 'stream_service', None)
elif service_name == "hardware":
service = getattr(request.app.state, 'hardware_service', None)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown service: {service_name}"
)
if not service:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Service '{service_name}' not available"
)
# Check service health
status_info = await service.get_status()
if status_info.get("status") != "healthy":
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Service '{service_name}' is unhealthy: {status_info.get('error', 'Unknown error')}"
)
return True
except HTTPException:
raise
except Exception as e:
logger.error(f"Error checking service health for {service_name}: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Service '{service_name}' health check failed"
)
# Rate limiting dependencies
async def check_rate_limit(
request: Request,
current_user: Optional[Dict[str, Any]] = Depends(get_current_user)
) -> bool:
"""Check rate limiting status."""
settings = get_settings()
# Skip if rate limiting is disabled
if not settings.enable_rate_limiting:
return True
# Rate limiting is handled by middleware
# This dependency can be used for additional checks
return True
# Configuration dependencies
def get_zone_config(zone_id: str = Depends(validate_zone_access)):
"""Get zone configuration."""
domain_config = get_domain_config()
return domain_config.get_zone(zone_id)
def get_router_config(router_id: str = Depends(validate_router_access)):
"""Get router configuration."""
domain_config = get_domain_config()
return domain_config.get_router(router_id)
# Pagination dependencies
class PaginationParams:
"""Pagination parameters."""
def __init__(
self,
page: int = 1,
size: int = 20,
max_size: int = 100
):
if page < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Page must be >= 1"
)
if size < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Size must be >= 1"
)
if size > max_size:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Size must be <= {max_size}"
)
self.page = page
self.size = size
self.offset = (page - 1) * size
self.limit = size
def get_pagination_params(
page: int = 1,
size: int = 20
) -> PaginationParams:
"""Get pagination parameters."""
return PaginationParams(page=page, size=size)
# Query filter dependencies
class QueryFilters:
"""Common query filters."""
def __init__(
self,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
min_confidence: Optional[float] = None,
activity: Optional[str] = None
):
self.start_time = start_time
self.end_time = end_time
self.min_confidence = min_confidence
self.activity = activity
# Validate confidence
if min_confidence is not None:
if not 0.0 <= min_confidence <= 1.0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="min_confidence must be between 0.0 and 1.0"
)
def get_query_filters(
start_time: Optional[str] = None,
end_time: Optional[str] = None,
min_confidence: Optional[float] = None,
activity: Optional[str] = None
) -> QueryFilters:
"""Get query filters."""
return QueryFilters(
start_time=start_time,
end_time=end_time,
min_confidence=min_confidence,
activity=activity
)
# WebSocket dependencies
async def get_websocket_user(
websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get user from WebSocket token."""
settings = get_settings()
# Skip authentication if disabled
if not settings.enable_authentication:
return None
# For development, return mock user
if settings.is_development:
return {
"id": "ws-user",
"username": "websocket_user",
"is_admin": False,
"permissions": ["read"]
}
# In production, implement proper token validation
return None
async def get_current_user_ws(
websocket_token: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Get current user for WebSocket connections."""
return await get_websocket_user(websocket_token)
# Authentication requirement dependencies
async def require_auth(
current_user: Dict[str, Any] = Depends(get_current_active_user)
) -> Dict[str, Any]:
"""Require authentication for endpoint access."""
return current_user
# Development dependencies
async def development_only():
"""Dependency that only allows access in development."""
settings = get_settings()
if not settings.is_development:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Endpoint not available in production"
)
return True

421
v1/src/api/main.py Normal file
View File

@@ -0,0 +1,421 @@
"""
FastAPI application for WiFi-DensePose API
"""
import asyncio
import logging
import logging.config
from contextlib import asynccontextmanager
from typing import Dict, Any
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from src.config.settings import get_settings
from src.config.domains import get_domain_config
from src.api.routers import pose, stream, health
from src.api.middleware.auth import AuthMiddleware
from src.api.middleware.rate_limit import RateLimitMiddleware
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
from src.api.websocket.connection_manager import connection_manager
from src.api.websocket.pose_stream import PoseStreamHandler
# Configure logging
settings = get_settings()
logging.config.dictConfig(settings.get_logging_config())
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager."""
logger.info("Starting WiFi-DensePose API...")
try:
# Initialize services
await initialize_services(app)
# Start background tasks
await start_background_tasks(app)
logger.info("WiFi-DensePose API started successfully")
yield
except Exception as e:
logger.error(f"Failed to start application: {e}")
raise
finally:
# Cleanup on shutdown
logger.info("Shutting down WiFi-DensePose API...")
await cleanup_services(app)
logger.info("WiFi-DensePose API shutdown complete")
async def initialize_services(app: FastAPI):
"""Initialize application services."""
try:
# Initialize hardware service
hardware_service = get_hardware_service()
await hardware_service.initialize()
# Initialize pose service
pose_service = get_pose_service()
await pose_service.initialize()
# Initialize stream service
stream_service = get_stream_service()
await stream_service.initialize()
# Initialize pose stream handler
pose_stream_handler = PoseStreamHandler(
connection_manager=connection_manager,
pose_service=pose_service,
stream_service=stream_service
)
# Store in app state for access in routes
app.state.hardware_service = hardware_service
app.state.pose_service = pose_service
app.state.stream_service = stream_service
app.state.pose_stream_handler = pose_stream_handler
logger.info("Services initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize services: {e}")
raise
async def start_background_tasks(app: FastAPI):
"""Start background tasks."""
try:
# Start pose service
pose_service = app.state.pose_service
await pose_service.start()
logger.info("Pose service started")
# Start pose streaming if enabled
if settings.enable_real_time_processing:
pose_stream_handler = app.state.pose_stream_handler
await pose_stream_handler.start_streaming()
logger.info("Background tasks started")
except Exception as e:
logger.error(f"Failed to start background tasks: {e}")
raise
async def cleanup_services(app: FastAPI):
"""Cleanup services on shutdown."""
try:
# Stop pose streaming
if hasattr(app.state, 'pose_stream_handler'):
await app.state.pose_stream_handler.shutdown()
# Shutdown connection manager
await connection_manager.shutdown()
# Cleanup services
if hasattr(app.state, 'stream_service'):
await app.state.stream_service.shutdown()
if hasattr(app.state, 'pose_service'):
await app.state.pose_service.stop()
if hasattr(app.state, 'hardware_service'):
await app.state.hardware_service.shutdown()
logger.info("Services cleaned up successfully")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
# Create FastAPI application
app = FastAPI(
title=settings.app_name,
version=settings.version,
description="WiFi-based human pose estimation and activity recognition API",
docs_url=settings.docs_url if not settings.is_production else None,
redoc_url=settings.redoc_url if not settings.is_production else None,
openapi_url=settings.openapi_url if not settings.is_production else None,
lifespan=lifespan
)
# Add middleware
if settings.enable_rate_limiting:
app.add_middleware(RateLimitMiddleware)
if settings.enable_authentication:
app.add_middleware(AuthMiddleware)
# Add CORS middleware
cors_config = settings.get_cors_config()
app.add_middleware(
CORSMiddleware,
**cors_config
)
# Add trusted host middleware for production
if settings.is_production:
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=settings.allowed_hosts
)
# Exception handlers
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""Handle HTTP exceptions."""
return JSONResponse(
status_code=exc.status_code,
content={
"error": {
"code": exc.status_code,
"message": exc.detail,
"type": "http_error"
}
}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle request validation errors."""
return JSONResponse(
status_code=422,
content={
"error": {
"code": 422,
"message": "Validation error",
"type": "validation_error",
"details": exc.errors()
}
}
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""Handle general exceptions."""
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"error": {
"code": 500,
"message": "Internal server error",
"type": "internal_error"
}
}
)
# Middleware for request logging
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""Log all requests."""
start_time = asyncio.get_event_loop().time()
# Process request
response = await call_next(request)
# Calculate processing time
process_time = asyncio.get_event_loop().time() - start_time
# Log request
logger.info(
f"{request.method} {request.url.path} - "
f"Status: {response.status_code} - "
f"Time: {process_time:.3f}s"
)
# Add processing time header
response.headers["X-Process-Time"] = str(process_time)
return response
# Include routers
app.include_router(
health.router,
prefix="/health",
tags=["Health"]
)
app.include_router(
pose.router,
prefix=f"{settings.api_prefix}/pose",
tags=["Pose Estimation"]
)
app.include_router(
stream.router,
prefix=f"{settings.api_prefix}/stream",
tags=["Streaming"]
)
# Root endpoint
@app.get("/")
async def root():
"""Root endpoint with API information."""
return {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"docs_url": settings.docs_url,
"api_prefix": settings.api_prefix,
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"real_time_processing": settings.enable_real_time_processing
}
}
# API information endpoint
@app.get(f"{settings.api_prefix}/info")
async def api_info():
"""Get detailed API information."""
domain_config = get_domain_config()
return {
"api": {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"prefix": settings.api_prefix
},
"configuration": {
"zones": len(domain_config.zones),
"routers": len(domain_config.routers),
"pose_models": len(domain_config.pose_models)
},
"features": {
"authentication": settings.enable_authentication,
"rate_limiting": settings.enable_rate_limiting,
"websockets": settings.enable_websockets,
"real_time_processing": settings.enable_real_time_processing,
"historical_data": settings.enable_historical_data
},
"limits": {
"rate_limit_requests": settings.rate_limit_requests,
"rate_limit_window": settings.rate_limit_window,
"max_websocket_connections": domain_config.streaming.max_connections
}
}
# Status endpoint
@app.get(f"{settings.api_prefix}/status")
async def api_status(request: Request):
"""Get current API status."""
try:
# Get services from app state
hardware_service = getattr(request.app.state, 'hardware_service', None)
pose_service = getattr(request.app.state, 'pose_service', None)
stream_service = getattr(request.app.state, 'stream_service', None)
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
# Get service statuses
status = {
"api": {
"status": "healthy",
"uptime": "unknown",
"version": settings.version
},
"services": {
"hardware": await hardware_service.get_status() if hardware_service else {"status": "unavailable"},
"pose": await pose_service.get_status() if pose_service else {"status": "unavailable"},
"stream": await stream_service.get_status() if stream_service else {"status": "unavailable"}
},
"streaming": pose_stream_handler.get_stream_status() if pose_stream_handler else {"is_streaming": False},
"connections": await connection_manager.get_connection_stats()
}
return status
except Exception as e:
logger.error(f"Error getting API status: {e}")
return {
"api": {
"status": "error",
"error": str(e)
}
}
# Metrics endpoint (if enabled)
if settings.metrics_enabled:
@app.get(f"{settings.api_prefix}/metrics")
async def api_metrics(request: Request):
"""Get API metrics."""
try:
# Get services from app state
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
metrics = {
"connections": await connection_manager.get_metrics(),
"streaming": await pose_stream_handler.get_performance_metrics() if pose_stream_handler else {}
}
return metrics
except Exception as e:
logger.error(f"Error getting metrics: {e}")
return {"error": str(e)}
# Development endpoints (only in development)
if settings.is_development and settings.enable_test_endpoints:
@app.get(f"{settings.api_prefix}/dev/config")
async def dev_config():
"""Get current configuration (development only)."""
domain_config = get_domain_config()
return {
"settings": settings.dict(),
"domain_config": domain_config.to_dict()
}
@app.post(f"{settings.api_prefix}/dev/reset")
async def dev_reset(request: Request):
"""Reset services (development only)."""
try:
# Reset services
hardware_service = getattr(request.app.state, 'hardware_service', None)
pose_service = getattr(request.app.state, 'pose_service', None)
if hardware_service:
await hardware_service.reset()
if pose_service:
await pose_service.reset()
return {"message": "Services reset successfully"}
except Exception as e:
logger.error(f"Error resetting services: {e}")
return {"error": str(e)}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"src.api.main:app",
host=settings.host,
port=settings.port,
reload=settings.reload,
workers=settings.workers if not settings.reload else 1,
log_level=settings.log_level.lower()
)

View File

@@ -0,0 +1,8 @@
"""
FastAPI middleware package
"""
from .auth import AuthMiddleware
from .rate_limit import RateLimitMiddleware
__all__ = ["AuthMiddleware", "RateLimitMiddleware"]

View File

@@ -0,0 +1,322 @@
"""
JWT Authentication middleware for WiFi-DensePose API
"""
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from jose import JWTError, jwt
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
class AuthMiddleware(BaseHTTPMiddleware):
"""JWT Authentication middleware."""
def __init__(self, app):
super().__init__(app)
self.settings = get_settings()
# Paths that don't require authentication
self.public_paths = {
"/",
"/docs",
"/redoc",
"/openapi.json",
"/health",
"/ready",
"/live",
"/version",
"/metrics"
}
# Paths that require authentication
self.protected_paths = {
"/api/v1/pose/analyze",
"/api/v1/pose/calibrate",
"/api/v1/pose/historical",
"/api/v1/stream/start",
"/api/v1/stream/stop",
"/api/v1/stream/clients",
"/api/v1/stream/broadcast"
}
async def dispatch(self, request: Request, call_next):
"""Process request through authentication middleware."""
# Skip authentication for public paths
if self._is_public_path(request.url.path):
return await call_next(request)
# Extract and validate token
token = self._extract_token(request)
if token:
try:
# Verify token and add user info to request state
user_data = await self._verify_token(token)
request.state.user = user_data
request.state.authenticated = True
logger.debug(f"Authenticated user: {user_data.get('id')}")
except Exception as e:
logger.warning(f"Token validation failed: {e}")
# For protected paths, return 401
if self._is_protected_path(request.url.path):
return JSONResponse(
status_code=401,
content={
"error": {
"code": 401,
"message": "Invalid or expired token",
"type": "authentication_error"
}
}
)
# For other paths, continue without authentication
request.state.user = None
request.state.authenticated = False
else:
# No token provided
if self._is_protected_path(request.url.path):
return JSONResponse(
status_code=401,
content={
"error": {
"code": 401,
"message": "Authentication required",
"type": "authentication_error"
}
},
headers={"WWW-Authenticate": "Bearer"}
)
request.state.user = None
request.state.authenticated = False
# Continue with request processing
response = await call_next(request)
# Add authentication headers to response
if hasattr(request.state, 'user') and request.state.user:
response.headers["X-User-ID"] = request.state.user.get("id", "")
response.headers["X-Authenticated"] = "true"
else:
response.headers["X-Authenticated"] = "false"
return response
def _is_public_path(self, path: str) -> bool:
"""Check if path is public (doesn't require authentication)."""
# Exact match
if path in self.public_paths:
return True
# Pattern matching for public paths
public_patterns = [
"/health",
"/metrics",
"/api/v1/pose/current", # Allow anonymous access to current pose data
"/api/v1/pose/zones/", # Allow anonymous access to zone data
"/api/v1/pose/activities", # Allow anonymous access to activities
"/api/v1/pose/stats", # Allow anonymous access to stats
"/api/v1/stream/status" # Allow anonymous access to stream status
]
for pattern in public_patterns:
if path.startswith(pattern):
return True
return False
def _is_protected_path(self, path: str) -> bool:
"""Check if path requires authentication."""
# Exact match
if path in self.protected_paths:
return True
# Pattern matching for protected paths
protected_patterns = [
"/api/v1/pose/analyze",
"/api/v1/pose/calibrate",
"/api/v1/pose/historical",
"/api/v1/stream/start",
"/api/v1/stream/stop",
"/api/v1/stream/clients",
"/api/v1/stream/broadcast"
]
for pattern in protected_patterns:
if path.startswith(pattern):
return True
return False
def _extract_token(self, request: Request) -> Optional[str]:
"""Extract JWT token from request."""
# Check Authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
# Check query parameter (for WebSocket connections)
token = request.query_params.get("token")
if token:
return token
# Check cookie
token = request.cookies.get("access_token")
if token:
return token
return None
async def _verify_token(self, token: str) -> Dict[str, Any]:
"""Verify JWT token and return user data."""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.settings.secret_key,
algorithms=[self.settings.jwt_algorithm]
)
# Extract user information
user_id = payload.get("sub")
if not user_id:
raise ValueError("Token missing user ID")
# Check token expiration
exp = payload.get("exp")
if exp and datetime.utcnow() > datetime.fromtimestamp(exp):
raise ValueError("Token expired")
# Build user object
user_data = {
"id": user_id,
"username": payload.get("username"),
"email": payload.get("email"),
"is_admin": payload.get("is_admin", False),
"permissions": payload.get("permissions", []),
"accessible_zones": payload.get("accessible_zones", []),
"token_issued_at": payload.get("iat"),
"token_expires_at": payload.get("exp"),
"session_id": payload.get("session_id")
}
return user_data
except JWTError as e:
raise ValueError(f"JWT validation failed: {e}")
except Exception as e:
raise ValueError(f"Token verification error: {e}")
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
"""Log authentication events for security monitoring."""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
log_data = {
"event_type": event_type,
"timestamp": datetime.utcnow().isoformat(),
"client_ip": client_ip,
"user_agent": user_agent,
"path": request.url.path,
"method": request.method
}
if details:
log_data.update(details)
if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
logger.warning(f"Auth event: {log_data}")
else:
logger.info(f"Auth event: {log_data}")
class TokenBlacklist:
"""Simple in-memory token blacklist for logout functionality."""
def __init__(self):
self._blacklisted_tokens = set()
self._cleanup_interval = 3600 # 1 hour
self._last_cleanup = datetime.utcnow()
def add_token(self, token: str):
"""Add token to blacklist."""
self._blacklisted_tokens.add(token)
self._cleanup_if_needed()
def is_blacklisted(self, token: str) -> bool:
"""Check if token is blacklisted."""
self._cleanup_if_needed()
return token in self._blacklisted_tokens
def _cleanup_if_needed(self):
"""Clean up expired tokens from blacklist."""
now = datetime.utcnow()
if (now - self._last_cleanup).total_seconds() > self._cleanup_interval:
# In a real implementation, you would check token expiration
# For now, we'll just clear old tokens periodically
self._blacklisted_tokens.clear()
self._last_cleanup = now
# Global token blacklist instance
token_blacklist = TokenBlacklist()
class SecurityHeaders:
"""Security headers for API responses."""
@staticmethod
def add_security_headers(response: Response) -> Response:
"""Add security headers to response."""
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data:; "
"connect-src 'self' ws: wss:;"
)
return response
class APIKeyAuth:
"""Alternative API key authentication for service-to-service communication."""
def __init__(self, api_keys: Dict[str, Dict[str, Any]] = None):
self.api_keys = api_keys or {}
def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Verify API key and return associated service info."""
if api_key in self.api_keys:
return self.api_keys[api_key]
return None
def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
"""Add new API key."""
self.api_keys[api_key] = service_info
def revoke_api_key(self, api_key: str):
"""Revoke API key."""
if api_key in self.api_keys:
del self.api_keys[api_key]
# Global API key auth instance
api_key_auth = APIKeyAuth()

View File

@@ -0,0 +1,429 @@
"""
Rate limiting middleware for WiFi-DensePose API
"""
import logging
import time
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from collections import defaultdict, deque
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware with sliding window algorithm."""
def __init__(self, app):
super().__init__(app)
self.settings = get_settings()
# Rate limit storage (in production, use Redis)
self.request_counts = defaultdict(lambda: deque())
self.blocked_clients = {}
# Rate limit configurations
self.rate_limits = {
"anonymous": {
"requests": self.settings.rate_limit_requests,
"window": self.settings.rate_limit_window,
"burst": 10 # Allow burst of 10 requests
},
"authenticated": {
"requests": self.settings.rate_limit_authenticated_requests,
"window": self.settings.rate_limit_window,
"burst": 50
},
"admin": {
"requests": 10000, # Very high limit for admins
"window": self.settings.rate_limit_window,
"burst": 100
}
}
# Path-specific rate limits
self.path_limits = {
"/api/v1/pose/current": {"requests": 60, "window": 60}, # 1 per second
"/api/v1/pose/analyze": {"requests": 10, "window": 60}, # 10 per minute
"/api/v1/pose/calibrate": {"requests": 1, "window": 300}, # 1 per 5 minutes
"/api/v1/stream/start": {"requests": 5, "window": 60}, # 5 per minute
"/api/v1/stream/stop": {"requests": 5, "window": 60}, # 5 per minute
}
# Exempt paths from rate limiting
self.exempt_paths = {
"/health",
"/ready",
"/live",
"/version",
"/metrics"
}
async def dispatch(self, request: Request, call_next):
"""Process request through rate limiting middleware."""
# Skip rate limiting for exempt paths
if self._is_exempt_path(request.url.path):
return await call_next(request)
# Get client identifier
client_id = self._get_client_id(request)
# Check if client is temporarily blocked
if self._is_client_blocked(client_id):
return self._create_rate_limit_response(
"Client temporarily blocked due to excessive requests"
)
# Get user type for rate limiting
user_type = self._get_user_type(request)
# Check rate limits
rate_limit_result = self._check_rate_limits(
client_id,
request.url.path,
user_type
)
if not rate_limit_result["allowed"]:
# Log rate limit violation
self._log_rate_limit_violation(request, client_id, rate_limit_result)
# Check if client should be temporarily blocked
if rate_limit_result.get("violations", 0) > 5:
self._block_client(client_id, duration=300) # 5 minutes
return self._create_rate_limit_response(
rate_limit_result["message"],
retry_after=rate_limit_result.get("retry_after", 60)
)
# Record the request
self._record_request(client_id, request.url.path)
# Process request
response = await call_next(request)
# Add rate limit headers
self._add_rate_limit_headers(response, client_id, user_type)
return response
def _is_exempt_path(self, path: str) -> bool:
"""Check if path is exempt from rate limiting."""
return path in self.exempt_paths
def _get_client_id(self, request: Request) -> str:
"""Get unique client identifier for rate limiting."""
# Try to get user ID from request state (set by auth middleware)
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user['id']}"
# Fall back to IP address
client_ip = request.client.host if request.client else "unknown"
# Include user agent for better identification
user_agent = request.headers.get("user-agent", "")
user_agent_hash = str(hash(user_agent))[:8]
return f"ip:{client_ip}:{user_agent_hash}"
def _get_user_type(self, request: Request) -> str:
"""Determine user type for rate limiting."""
if hasattr(request.state, 'user') and request.state.user:
if request.state.user.get("is_admin", False):
return "admin"
return "authenticated"
return "anonymous"
def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
"""Check if request is within rate limits."""
now = time.time()
# Get applicable rate limits
general_limit = self.rate_limits[user_type]
path_limit = self.path_limits.get(path)
# Check general rate limit
general_result = self._check_limit(
client_id,
"general",
general_limit["requests"],
general_limit["window"],
now
)
if not general_result["allowed"]:
return general_result
# Check path-specific rate limit if exists
if path_limit:
path_result = self._check_limit(
client_id,
f"path:{path}",
path_limit["requests"],
path_limit["window"],
now
)
if not path_result["allowed"]:
return path_result
return {"allowed": True}
def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
"""Check specific rate limit using sliding window."""
key = f"{client_id}:{limit_type}"
requests = self.request_counts[key]
# Remove old requests outside the window
cutoff = now - window
while requests and requests[0] <= cutoff:
requests.popleft()
# Check if limit exceeded
if len(requests) >= max_requests:
# Calculate retry after time
oldest_request = requests[0] if requests else now
retry_after = int(oldest_request + window - now) + 1
return {
"allowed": False,
"message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
"retry_after": retry_after,
"current_count": len(requests),
"limit": max_requests,
"window": window
}
return {
"allowed": True,
"current_count": len(requests),
"limit": max_requests,
"window": window
}
def _record_request(self, client_id: str, path: str):
"""Record a request for rate limiting."""
now = time.time()
# Record general request
general_key = f"{client_id}:general"
self.request_counts[general_key].append(now)
# Record path-specific request if path has specific limits
if path in self.path_limits:
path_key = f"{client_id}:path:{path}"
self.request_counts[path_key].append(now)
def _is_client_blocked(self, client_id: str) -> bool:
"""Check if client is temporarily blocked."""
if client_id in self.blocked_clients:
block_until = self.blocked_clients[client_id]
if time.time() < block_until:
return True
else:
# Block expired, remove it
del self.blocked_clients[client_id]
return False
def _block_client(self, client_id: str, duration: int):
"""Temporarily block a client."""
self.blocked_clients[client_id] = time.time() + duration
logger.warning(f"Client {client_id} blocked for {duration} seconds due to rate limit violations")
def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
"""Create rate limit exceeded response."""
return JSONResponse(
status_code=429,
content={
"error": {
"code": 429,
"message": message,
"type": "rate_limit_exceeded"
}
},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": "Exceeded",
"X-RateLimit-Remaining": "0"
}
)
def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
"""Add rate limit headers to response."""
try:
general_limit = self.rate_limits[user_type]
general_key = f"{client_id}:general"
current_requests = len(self.request_counts[general_key])
remaining = max(0, general_limit["requests"] - current_requests)
response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Window"] = str(general_limit["window"])
# Add reset time
if self.request_counts[general_key]:
oldest_request = self.request_counts[general_key][0]
reset_time = int(oldest_request + general_limit["window"])
response.headers["X-RateLimit-Reset"] = str(reset_time)
except Exception as e:
logger.error(f"Error adding rate limit headers: {e}")
def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
"""Log rate limit violations for monitoring."""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
log_data = {
"event_type": "rate_limit_violation",
"timestamp": datetime.utcnow().isoformat(),
"client_id": client_id,
"client_ip": client_ip,
"user_agent": user_agent,
"path": request.url.path,
"method": request.method,
"current_count": result.get("current_count"),
"limit": result.get("limit"),
"window": result.get("window")
}
logger.warning(f"Rate limit violation: {log_data}")
def cleanup_old_data(self):
"""Clean up old rate limiting data (call periodically)."""
now = time.time()
cutoff = now - 3600 # Keep data for 1 hour
# Clean up request counts
for key in list(self.request_counts.keys()):
requests = self.request_counts[key]
while requests and requests[0] <= cutoff:
requests.popleft()
# Remove empty deques
if not requests:
del self.request_counts[key]
# Clean up expired blocks
expired_blocks = [
client_id for client_id, block_until in self.blocked_clients.items()
if now >= block_until
]
for client_id in expired_blocks:
del self.blocked_clients[client_id]
class AdaptiveRateLimit:
"""Adaptive rate limiting based on system load."""
def __init__(self):
self.base_limits = {}
self.current_multiplier = 1.0
self.load_history = deque(maxlen=60) # Keep 1 minute of load data
def update_system_load(self, cpu_percent: float, memory_percent: float):
"""Update system load metrics."""
load_score = (cpu_percent + memory_percent) / 2
self.load_history.append(load_score)
# Calculate adaptive multiplier
if len(self.load_history) >= 10:
avg_load = sum(self.load_history) / len(self.load_history)
if avg_load > 80:
self.current_multiplier = 0.5 # Reduce limits by 50%
elif avg_load > 60:
self.current_multiplier = 0.7 # Reduce limits by 30%
elif avg_load < 30:
self.current_multiplier = 1.2 # Increase limits by 20%
else:
self.current_multiplier = 1.0 # Normal limits
def get_adjusted_limit(self, base_limit: int) -> int:
"""Get adjusted rate limit based on system load."""
return max(1, int(base_limit * self.current_multiplier))
class RateLimitStorage:
"""Abstract interface for rate limit storage (Redis implementation)."""
async def get_count(self, key: str, window: int) -> int:
"""Get current request count for key within window."""
raise NotImplementedError
async def increment(self, key: str, window: int) -> int:
"""Increment request count and return new count."""
raise NotImplementedError
async def is_blocked(self, client_id: str) -> bool:
"""Check if client is blocked."""
raise NotImplementedError
async def block_client(self, client_id: str, duration: int):
"""Block client for duration seconds."""
raise NotImplementedError
class RedisRateLimitStorage(RateLimitStorage):
"""Redis-based rate limit storage for production use."""
def __init__(self, redis_client):
self.redis = redis_client
async def get_count(self, key: str, window: int) -> int:
"""Get current request count using Redis sliding window."""
now = time.time()
pipeline = self.redis.pipeline()
# Remove old entries
pipeline.zremrangebyscore(key, 0, now - window)
# Count current entries
pipeline.zcard(key)
results = await pipeline.execute()
return results[1]
async def increment(self, key: str, window: int) -> int:
"""Increment request count using Redis."""
now = time.time()
pipeline = self.redis.pipeline()
# Add current request
pipeline.zadd(key, {str(now): now})
# Remove old entries
pipeline.zremrangebyscore(key, 0, now - window)
# Set expiration
pipeline.expire(key, window + 1)
# Get count
pipeline.zcard(key)
results = await pipeline.execute()
return results[3]
async def is_blocked(self, client_id: str) -> bool:
"""Check if client is blocked."""
block_key = f"blocked:{client_id}"
return await self.redis.exists(block_key)
async def block_client(self, client_id: str, duration: int):
"""Block client for duration seconds."""
block_key = f"blocked:{client_id}"
await self.redis.setex(block_key, duration, "1")
# Global adaptive rate limiter instance
adaptive_rate_limit = AdaptiveRateLimit()

View File

@@ -0,0 +1,7 @@
"""
API routers package
"""
from . import pose, stream, health
__all__ = ["pose", "stream", "health"]

View File

@@ -0,0 +1,419 @@
"""
Health check API endpoints
"""
import logging
import psutil
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from src.api.dependencies import get_current_user
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Response models
class ComponentHealth(BaseModel):
"""Health status for a system component."""
name: str = Field(..., description="Component name")
status: str = Field(..., description="Health status (healthy, degraded, unhealthy)")
message: Optional[str] = Field(default=None, description="Status message")
last_check: datetime = Field(..., description="Last health check timestamp")
uptime_seconds: Optional[float] = Field(default=None, description="Component uptime")
metrics: Optional[Dict[str, Any]] = Field(default=None, description="Component metrics")
class SystemHealth(BaseModel):
"""Overall system health status."""
status: str = Field(..., description="Overall system status")
timestamp: datetime = Field(..., description="Health check timestamp")
uptime_seconds: float = Field(..., description="System uptime")
components: Dict[str, ComponentHealth] = Field(..., description="Component health status")
system_metrics: Dict[str, Any] = Field(..., description="System-level metrics")
class ReadinessCheck(BaseModel):
"""System readiness check result."""
ready: bool = Field(..., description="Whether system is ready to serve requests")
timestamp: datetime = Field(..., description="Readiness check timestamp")
checks: Dict[str, bool] = Field(..., description="Individual readiness checks")
message: str = Field(..., description="Readiness status message")
# Health check endpoints
@router.get("/health", response_model=SystemHealth)
async def health_check(request: Request):
"""Comprehensive system health check."""
try:
# Get services from app state
hardware_service = getattr(request.app.state, 'hardware_service', None)
pose_service = getattr(request.app.state, 'pose_service', None)
stream_service = getattr(request.app.state, 'stream_service', None)
timestamp = datetime.utcnow()
components = {}
overall_status = "healthy"
# Check hardware service
if hardware_service:
try:
hw_health = await hardware_service.health_check()
components["hardware"] = ComponentHealth(
name="Hardware Service",
status=hw_health["status"],
message=hw_health.get("message"),
last_check=timestamp,
uptime_seconds=hw_health.get("uptime_seconds"),
metrics=hw_health.get("metrics")
)
if hw_health["status"] != "healthy":
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
except Exception as e:
logger.error(f"Hardware service health check failed: {e}")
components["hardware"] = ComponentHealth(
name="Hardware Service",
status="unhealthy",
message=f"Health check failed: {str(e)}",
last_check=timestamp
)
overall_status = "unhealthy"
else:
components["hardware"] = ComponentHealth(
name="Hardware Service",
status="unavailable",
message="Service not initialized",
last_check=timestamp
)
overall_status = "degraded"
# Check pose service
if pose_service:
try:
pose_health = await pose_service.health_check()
components["pose"] = ComponentHealth(
name="Pose Service",
status=pose_health["status"],
message=pose_health.get("message"),
last_check=timestamp,
uptime_seconds=pose_health.get("uptime_seconds"),
metrics=pose_health.get("metrics")
)
if pose_health["status"] != "healthy":
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
except Exception as e:
logger.error(f"Pose service health check failed: {e}")
components["pose"] = ComponentHealth(
name="Pose Service",
status="unhealthy",
message=f"Health check failed: {str(e)}",
last_check=timestamp
)
overall_status = "unhealthy"
else:
components["pose"] = ComponentHealth(
name="Pose Service",
status="unavailable",
message="Service not initialized",
last_check=timestamp
)
overall_status = "degraded"
# Check stream service
if stream_service:
try:
stream_health = await stream_service.health_check()
components["stream"] = ComponentHealth(
name="Stream Service",
status=stream_health["status"],
message=stream_health.get("message"),
last_check=timestamp,
uptime_seconds=stream_health.get("uptime_seconds"),
metrics=stream_health.get("metrics")
)
if stream_health["status"] != "healthy":
overall_status = "degraded" if overall_status == "healthy" else "unhealthy"
except Exception as e:
logger.error(f"Stream service health check failed: {e}")
components["stream"] = ComponentHealth(
name="Stream Service",
status="unhealthy",
message=f"Health check failed: {str(e)}",
last_check=timestamp
)
overall_status = "unhealthy"
else:
components["stream"] = ComponentHealth(
name="Stream Service",
status="unavailable",
message="Service not initialized",
last_check=timestamp
)
overall_status = "degraded"
# Get system metrics
system_metrics = get_system_metrics()
# Calculate system uptime (placeholder - would need actual startup time)
uptime_seconds = 0.0 # TODO: Implement actual uptime tracking
return SystemHealth(
status=overall_status,
timestamp=timestamp,
uptime_seconds=uptime_seconds,
components=components,
system_metrics=system_metrics
)
except Exception as e:
logger.error(f"Health check failed: {e}")
raise HTTPException(
status_code=500,
detail=f"Health check failed: {str(e)}"
)
@router.get("/ready", response_model=ReadinessCheck)
async def readiness_check(request: Request):
"""Check if system is ready to serve requests."""
try:
timestamp = datetime.utcnow()
checks = {}
# Check if services are available in app state
if hasattr(request.app.state, 'pose_service') and request.app.state.pose_service:
try:
checks["pose_ready"] = await request.app.state.pose_service.is_ready()
except Exception as e:
logger.warning(f"Pose service readiness check failed: {e}")
checks["pose_ready"] = False
else:
checks["pose_ready"] = False
if hasattr(request.app.state, 'stream_service') and request.app.state.stream_service:
try:
checks["stream_ready"] = await request.app.state.stream_service.is_ready()
except Exception as e:
logger.warning(f"Stream service readiness check failed: {e}")
checks["stream_ready"] = False
else:
checks["stream_ready"] = False
# Hardware service check (basic availability)
checks["hardware_ready"] = True # Basic readiness - API is responding
# Check system resources
checks["memory_available"] = check_memory_availability()
checks["disk_space_available"] = check_disk_space()
# Application is ready if at least the basic services are available
# For now, we'll consider it ready if the API is responding
ready = True # Basic readiness
message = "System is ready" if ready else "System is not ready"
if not ready:
failed_checks = [name for name, status in checks.items() if not status]
message += f". Failed checks: {', '.join(failed_checks)}"
return ReadinessCheck(
ready=ready,
timestamp=timestamp,
checks=checks,
message=message
)
except Exception as e:
logger.error(f"Readiness check failed: {e}")
return ReadinessCheck(
ready=False,
timestamp=datetime.utcnow(),
checks={},
message=f"Readiness check failed: {str(e)}"
)
@router.get("/live")
async def liveness_check():
"""Simple liveness check for load balancers."""
return {
"status": "alive",
"timestamp": datetime.utcnow().isoformat()
}
@router.get("/metrics")
async def get_health_metrics(
request: Request,
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get detailed system metrics."""
try:
metrics = get_system_metrics()
# Add additional metrics if authenticated
if current_user:
metrics.update(get_detailed_metrics())
return {
"timestamp": datetime.utcnow().isoformat(),
"metrics": metrics
}
except Exception as e:
logger.error(f"Error getting system metrics: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get system metrics: {str(e)}"
)
@router.get("/version")
async def get_version_info():
"""Get application version information."""
settings = get_settings()
return {
"name": settings.app_name,
"version": settings.version,
"environment": settings.environment,
"debug": settings.debug,
"timestamp": datetime.utcnow().isoformat()
}
def get_system_metrics() -> Dict[str, Any]:
"""Get basic system metrics."""
try:
# CPU metrics
cpu_percent = psutil.cpu_percent(interval=1)
cpu_count = psutil.cpu_count()
# Memory metrics
memory = psutil.virtual_memory()
memory_metrics = {
"total_gb": round(memory.total / (1024**3), 2),
"available_gb": round(memory.available / (1024**3), 2),
"used_gb": round(memory.used / (1024**3), 2),
"percent": memory.percent
}
# Disk metrics
disk = psutil.disk_usage('/')
disk_metrics = {
"total_gb": round(disk.total / (1024**3), 2),
"free_gb": round(disk.free / (1024**3), 2),
"used_gb": round(disk.used / (1024**3), 2),
"percent": round((disk.used / disk.total) * 100, 2)
}
# Network metrics (basic)
network = psutil.net_io_counters()
network_metrics = {
"bytes_sent": network.bytes_sent,
"bytes_recv": network.bytes_recv,
"packets_sent": network.packets_sent,
"packets_recv": network.packets_recv
}
return {
"cpu": {
"percent": cpu_percent,
"count": cpu_count
},
"memory": memory_metrics,
"disk": disk_metrics,
"network": network_metrics
}
except Exception as e:
logger.error(f"Error getting system metrics: {e}")
return {}
def get_detailed_metrics() -> Dict[str, Any]:
"""Get detailed system metrics (requires authentication)."""
try:
# Process metrics
process = psutil.Process()
process_metrics = {
"pid": process.pid,
"cpu_percent": process.cpu_percent(),
"memory_mb": round(process.memory_info().rss / (1024**2), 2),
"num_threads": process.num_threads(),
"create_time": datetime.fromtimestamp(process.create_time()).isoformat()
}
# Load average (Unix-like systems)
load_avg = None
try:
load_avg = psutil.getloadavg()
except AttributeError:
# Windows doesn't have load average
pass
# Temperature sensors (if available)
temperatures = {}
try:
temps = psutil.sensors_temperatures()
for name, entries in temps.items():
temperatures[name] = [
{"label": entry.label, "current": entry.current}
for entry in entries
]
except AttributeError:
# Not available on all systems
pass
detailed = {
"process": process_metrics
}
if load_avg:
detailed["load_average"] = {
"1min": load_avg[0],
"5min": load_avg[1],
"15min": load_avg[2]
}
if temperatures:
detailed["temperatures"] = temperatures
return detailed
except Exception as e:
logger.error(f"Error getting detailed metrics: {e}")
return {}
def check_memory_availability() -> bool:
"""Check if sufficient memory is available."""
try:
memory = psutil.virtual_memory()
# Consider system ready if less than 90% memory is used
return memory.percent < 90.0
except Exception:
return False
def check_disk_space() -> bool:
"""Check if sufficient disk space is available."""
try:
disk = psutil.disk_usage('/')
# Consider system ready if more than 1GB free space
free_gb = disk.free / (1024**3)
return free_gb > 1.0
except Exception:
return False

420
v1/src/api/routers/pose.py Normal file
View File

@@ -0,0 +1,420 @@
"""
Pose estimation API endpoints
"""
import logging
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from src.api.dependencies import (
get_pose_service,
get_hardware_service,
get_current_user,
require_auth
)
from src.services.pose_service import PoseService
from src.services.hardware_service import HardwareService
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
router = APIRouter()
# Request/Response models
class PoseEstimationRequest(BaseModel):
"""Request model for pose estimation."""
zone_ids: Optional[List[str]] = Field(
default=None,
description="Specific zones to analyze (all zones if not specified)"
)
confidence_threshold: Optional[float] = Field(
default=None,
ge=0.0,
le=1.0,
description="Minimum confidence threshold for detections"
)
max_persons: Optional[int] = Field(
default=None,
ge=1,
le=50,
description="Maximum number of persons to detect"
)
include_keypoints: bool = Field(
default=True,
description="Include detailed keypoint data"
)
include_segmentation: bool = Field(
default=False,
description="Include DensePose segmentation masks"
)
class PersonPose(BaseModel):
"""Person pose data model."""
person_id: str = Field(..., description="Unique person identifier")
confidence: float = Field(..., description="Detection confidence score")
bounding_box: Dict[str, float] = Field(..., description="Person bounding box")
keypoints: Optional[List[Dict[str, Any]]] = Field(
default=None,
description="Body keypoints with coordinates and confidence"
)
segmentation: Optional[Dict[str, Any]] = Field(
default=None,
description="DensePose segmentation data"
)
zone_id: Optional[str] = Field(
default=None,
description="Zone where person is detected"
)
activity: Optional[str] = Field(
default=None,
description="Detected activity"
)
timestamp: datetime = Field(..., description="Detection timestamp")
class PoseEstimationResponse(BaseModel):
"""Response model for pose estimation."""
timestamp: datetime = Field(..., description="Analysis timestamp")
frame_id: str = Field(..., description="Unique frame identifier")
persons: List[PersonPose] = Field(..., description="Detected persons")
zone_summary: Dict[str, int] = Field(..., description="Person count per zone")
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
class HistoricalDataRequest(BaseModel):
"""Request model for historical pose data."""
start_time: datetime = Field(..., description="Start time for data query")
end_time: datetime = Field(..., description="End time for data query")
zone_ids: Optional[List[str]] = Field(
default=None,
description="Filter by specific zones"
)
aggregation_interval: Optional[int] = Field(
default=300,
ge=60,
le=3600,
description="Aggregation interval in seconds"
)
include_raw_data: bool = Field(
default=False,
description="Include raw detection data"
)
# Endpoints
@router.get("/current", response_model=PoseEstimationResponse)
async def get_current_pose_estimation(
request: PoseEstimationRequest = Depends(),
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get current pose estimation from WiFi signals."""
try:
logger.info(f"Processing pose estimation request from user: {current_user.get('id') if current_user else 'anonymous'}")
# Get current pose estimation
result = await pose_service.estimate_poses(
zone_ids=request.zone_ids,
confidence_threshold=request.confidence_threshold,
max_persons=request.max_persons,
include_keypoints=request.include_keypoints,
include_segmentation=request.include_segmentation
)
return PoseEstimationResponse(**result)
except Exception as e:
logger.error(f"Error in pose estimation: {e}")
raise HTTPException(
status_code=500,
detail=f"Pose estimation failed: {str(e)}"
)
@router.post("/analyze", response_model=PoseEstimationResponse)
async def analyze_pose_data(
request: PoseEstimationRequest,
background_tasks: BackgroundTasks,
pose_service: PoseService = Depends(get_pose_service),
current_user: Dict = Depends(require_auth)
):
"""Trigger pose analysis with custom parameters."""
try:
logger.info(f"Custom pose analysis requested by user: {current_user['id']}")
# Trigger analysis
result = await pose_service.analyze_with_params(
zone_ids=request.zone_ids,
confidence_threshold=request.confidence_threshold,
max_persons=request.max_persons,
include_keypoints=request.include_keypoints,
include_segmentation=request.include_segmentation
)
# Schedule background processing if needed
if request.include_segmentation:
background_tasks.add_task(
pose_service.process_segmentation_data,
result["frame_id"]
)
return PoseEstimationResponse(**result)
except Exception as e:
logger.error(f"Error in pose analysis: {e}")
raise HTTPException(
status_code=500,
detail=f"Pose analysis failed: {str(e)}"
)
@router.get("/zones/{zone_id}/occupancy")
async def get_zone_occupancy(
zone_id: str,
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get current occupancy for a specific zone."""
try:
occupancy = await pose_service.get_zone_occupancy(zone_id)
if occupancy is None:
raise HTTPException(
status_code=404,
detail=f"Zone '{zone_id}' not found"
)
return {
"zone_id": zone_id,
"current_occupancy": occupancy["count"],
"max_occupancy": occupancy.get("max_occupancy"),
"persons": occupancy["persons"],
"timestamp": occupancy["timestamp"]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting zone occupancy: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get zone occupancy: {str(e)}"
)
@router.get("/zones/summary")
async def get_zones_summary(
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get occupancy summary for all zones."""
try:
summary = await pose_service.get_zones_summary()
return {
"timestamp": datetime.utcnow(),
"total_persons": summary["total_persons"],
"zones": summary["zones"],
"active_zones": summary["active_zones"]
}
except Exception as e:
logger.error(f"Error getting zones summary: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get zones summary: {str(e)}"
)
@router.post("/historical")
async def get_historical_data(
request: HistoricalDataRequest,
pose_service: PoseService = Depends(get_pose_service),
current_user: Dict = Depends(require_auth)
):
"""Get historical pose estimation data."""
try:
# Validate time range
if request.end_time <= request.start_time:
raise HTTPException(
status_code=400,
detail="End time must be after start time"
)
# Limit query range to prevent excessive data
max_range = timedelta(days=7)
if request.end_time - request.start_time > max_range:
raise HTTPException(
status_code=400,
detail="Query range cannot exceed 7 days"
)
data = await pose_service.get_historical_data(
start_time=request.start_time,
end_time=request.end_time,
zone_ids=request.zone_ids,
aggregation_interval=request.aggregation_interval,
include_raw_data=request.include_raw_data
)
return {
"query": {
"start_time": request.start_time,
"end_time": request.end_time,
"zone_ids": request.zone_ids,
"aggregation_interval": request.aggregation_interval
},
"data": data["aggregated_data"],
"raw_data": data.get("raw_data") if request.include_raw_data else None,
"total_records": data["total_records"]
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting historical data: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get historical data: {str(e)}"
)
@router.get("/activities")
async def get_detected_activities(
zone_id: Optional[str] = Query(None, description="Filter by zone ID"),
limit: int = Query(10, ge=1, le=100, description="Maximum number of activities"),
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get recently detected activities."""
try:
activities = await pose_service.get_recent_activities(
zone_id=zone_id,
limit=limit
)
return {
"activities": activities,
"total_count": len(activities),
"zone_id": zone_id
}
except Exception as e:
logger.error(f"Error getting activities: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get activities: {str(e)}"
)
@router.post("/calibrate")
async def calibrate_pose_system(
background_tasks: BackgroundTasks,
pose_service: PoseService = Depends(get_pose_service),
hardware_service: HardwareService = Depends(get_hardware_service),
current_user: Dict = Depends(require_auth)
):
"""Calibrate the pose estimation system."""
try:
logger.info(f"Pose system calibration initiated by user: {current_user['id']}")
# Check if calibration is already in progress
if await pose_service.is_calibrating():
raise HTTPException(
status_code=409,
detail="Calibration already in progress"
)
# Start calibration process
calibration_id = await pose_service.start_calibration()
# Schedule background calibration task
background_tasks.add_task(
pose_service.run_calibration,
calibration_id
)
return {
"calibration_id": calibration_id,
"status": "started",
"estimated_duration_minutes": 5,
"message": "Calibration process started"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error starting calibration: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to start calibration: {str(e)}"
)
@router.get("/calibration/status")
async def get_calibration_status(
pose_service: PoseService = Depends(get_pose_service),
current_user: Dict = Depends(require_auth)
):
"""Get current calibration status."""
try:
status = await pose_service.get_calibration_status()
return {
"is_calibrating": status["is_calibrating"],
"calibration_id": status.get("calibration_id"),
"progress_percent": status.get("progress_percent", 0),
"current_step": status.get("current_step"),
"estimated_remaining_minutes": status.get("estimated_remaining_minutes"),
"last_calibration": status.get("last_calibration")
}
except Exception as e:
logger.error(f"Error getting calibration status: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get calibration status: {str(e)}"
)
@router.get("/stats")
async def get_pose_statistics(
hours: int = Query(24, ge=1, le=168, description="Hours of data to analyze"),
pose_service: PoseService = Depends(get_pose_service),
current_user: Optional[Dict] = Depends(get_current_user)
):
"""Get pose estimation statistics."""
try:
end_time = datetime.utcnow()
start_time = end_time - timedelta(hours=hours)
stats = await pose_service.get_statistics(
start_time=start_time,
end_time=end_time
)
return {
"period": {
"start_time": start_time,
"end_time": end_time,
"hours": hours
},
"statistics": stats
}
except Exception as e:
logger.error(f"Error getting statistics: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get statistics: {str(e)}"
)

View File

@@ -0,0 +1,468 @@
"""
WebSocket streaming API endpoints
"""
import json
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from src.api.dependencies import (
get_stream_service,
get_pose_service,
get_current_user_ws,
require_auth
)
from src.api.websocket.connection_manager import ConnectionManager
from src.services.stream_service import StreamService
from src.services.pose_service import PoseService
logger = logging.getLogger(__name__)
router = APIRouter()
# Initialize connection manager
connection_manager = ConnectionManager()
# Request/Response models
class StreamSubscriptionRequest(BaseModel):
"""Request model for stream subscription."""
zone_ids: Optional[List[str]] = Field(
default=None,
description="Zones to subscribe to (all zones if not specified)"
)
stream_types: List[str] = Field(
default=["pose_data"],
description="Types of data to stream"
)
min_confidence: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Minimum confidence threshold for streaming"
)
max_fps: int = Field(
default=30,
ge=1,
le=60,
description="Maximum frames per second"
)
include_metadata: bool = Field(
default=True,
description="Include metadata in stream"
)
class StreamStatus(BaseModel):
"""Stream status model."""
is_active: bool = Field(..., description="Whether streaming is active")
connected_clients: int = Field(..., description="Number of connected clients")
streams: List[Dict[str, Any]] = Field(..., description="Active streams")
uptime_seconds: float = Field(..., description="Stream uptime in seconds")
# WebSocket endpoints
@router.websocket("/pose")
async def websocket_pose_stream(
websocket: WebSocket,
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
min_confidence: float = Query(0.5, ge=0.0, le=1.0),
max_fps: int = Query(30, ge=1, le=60),
token: Optional[str] = Query(None, description="Authentication token")
):
"""WebSocket endpoint for real-time pose data streaming."""
client_id = None
try:
# Accept WebSocket connection
await websocket.accept()
# Check authentication if enabled
from src.config.settings import get_settings
settings = get_settings()
if settings.enable_authentication and not token:
await websocket.send_json({
"type": "error",
"message": "Authentication token required"
})
await websocket.close(code=1008)
return
# Parse zone IDs
zone_list = None
if zone_ids:
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
# Register client with connection manager
client_id = await connection_manager.connect(
websocket=websocket,
stream_type="pose",
zone_ids=zone_list,
min_confidence=min_confidence,
max_fps=max_fps
)
logger.info(f"WebSocket client {client_id} connected for pose streaming")
# Send initial connection confirmation
await websocket.send_json({
"type": "connection_established",
"client_id": client_id,
"timestamp": datetime.utcnow().isoformat(),
"config": {
"zone_ids": zone_list,
"min_confidence": min_confidence,
"max_fps": max_fps
}
})
# Keep connection alive and handle incoming messages
while True:
try:
# Wait for client messages (ping, config updates, etc.)
message = await websocket.receive_text()
data = json.loads(message)
await handle_websocket_message(client_id, data, websocket)
except WebSocketDisconnect:
break
except json.JSONDecodeError:
await websocket.send_json({
"type": "error",
"message": "Invalid JSON format"
})
except Exception as e:
logger.error(f"Error handling WebSocket message: {e}")
await websocket.send_json({
"type": "error",
"message": "Internal server error"
})
except WebSocketDisconnect:
logger.info(f"WebSocket client {client_id} disconnected")
except Exception as e:
logger.error(f"WebSocket error: {e}")
finally:
if client_id:
await connection_manager.disconnect(client_id)
@router.websocket("/events")
async def websocket_events_stream(
websocket: WebSocket,
event_types: Optional[str] = Query(None, description="Comma-separated event types"),
zone_ids: Optional[str] = Query(None, description="Comma-separated zone IDs"),
token: Optional[str] = Query(None, description="Authentication token")
):
"""WebSocket endpoint for real-time event streaming."""
client_id = None
try:
await websocket.accept()
# Check authentication if enabled
from src.config.settings import get_settings
settings = get_settings()
if settings.enable_authentication and not token:
await websocket.send_json({
"type": "error",
"message": "Authentication token required"
})
await websocket.close(code=1008)
return
# Parse parameters
event_list = None
if event_types:
event_list = [event.strip() for event in event_types.split(",") if event.strip()]
zone_list = None
if zone_ids:
zone_list = [zone.strip() for zone in zone_ids.split(",") if zone.strip()]
# Register client
client_id = await connection_manager.connect(
websocket=websocket,
stream_type="events",
zone_ids=zone_list,
event_types=event_list
)
logger.info(f"WebSocket client {client_id} connected for event streaming")
# Send confirmation
await websocket.send_json({
"type": "connection_established",
"client_id": client_id,
"timestamp": datetime.utcnow().isoformat(),
"config": {
"event_types": event_list,
"zone_ids": zone_list
}
})
# Handle messages
while True:
try:
message = await websocket.receive_text()
data = json.loads(message)
await handle_websocket_message(client_id, data, websocket)
except WebSocketDisconnect:
break
except Exception as e:
logger.error(f"Error in events WebSocket: {e}")
except WebSocketDisconnect:
logger.info(f"Events WebSocket client {client_id} disconnected")
except Exception as e:
logger.error(f"Events WebSocket error: {e}")
finally:
if client_id:
await connection_manager.disconnect(client_id)
async def handle_websocket_message(client_id: str, data: Dict[str, Any], websocket: WebSocket):
"""Handle incoming WebSocket messages."""
message_type = data.get("type")
if message_type == "ping":
await websocket.send_json({
"type": "pong",
"timestamp": datetime.utcnow().isoformat()
})
elif message_type == "update_config":
# Update client configuration
config = data.get("config", {})
await connection_manager.update_client_config(client_id, config)
await websocket.send_json({
"type": "config_updated",
"timestamp": datetime.utcnow().isoformat(),
"config": config
})
elif message_type == "get_status":
# Send current status
status = await connection_manager.get_client_status(client_id)
await websocket.send_json({
"type": "status",
"timestamp": datetime.utcnow().isoformat(),
"status": status
})
else:
await websocket.send_json({
"type": "error",
"message": f"Unknown message type: {message_type}"
})
# HTTP endpoints for stream management
@router.get("/status", response_model=StreamStatus)
async def get_stream_status(
stream_service: StreamService = Depends(get_stream_service)
):
"""Get current streaming status."""
try:
status = await stream_service.get_status()
connections = await connection_manager.get_connection_stats()
# Calculate uptime (simplified for now)
uptime_seconds = 0.0
if status.get("running", False):
uptime_seconds = 3600.0 # Default 1 hour for demo
return StreamStatus(
is_active=status.get("running", False),
connected_clients=connections.get("total_clients", status["connections"]["active"]),
streams=[{
"type": "pose_stream",
"active": status.get("running", False),
"buffer_size": status["buffers"]["pose_buffer_size"]
}],
uptime_seconds=uptime_seconds
)
except Exception as e:
logger.error(f"Error getting stream status: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get stream status: {str(e)}"
)
@router.post("/start")
async def start_streaming(
stream_service: StreamService = Depends(get_stream_service),
current_user: Dict = Depends(require_auth)
):
"""Start the streaming service."""
try:
logger.info(f"Starting streaming service by user: {current_user['id']}")
if await stream_service.is_active():
return JSONResponse(
status_code=200,
content={"message": "Streaming service is already active"}
)
await stream_service.start()
return {
"message": "Streaming service started successfully",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error starting streaming: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to start streaming: {str(e)}"
)
@router.post("/stop")
async def stop_streaming(
stream_service: StreamService = Depends(get_stream_service),
current_user: Dict = Depends(require_auth)
):
"""Stop the streaming service."""
try:
logger.info(f"Stopping streaming service by user: {current_user['id']}")
await stream_service.stop()
await connection_manager.disconnect_all()
return {
"message": "Streaming service stopped successfully",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error stopping streaming: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to stop streaming: {str(e)}"
)
@router.get("/clients")
async def get_connected_clients(
current_user: Dict = Depends(require_auth)
):
"""Get list of connected WebSocket clients."""
try:
clients = await connection_manager.get_connected_clients()
return {
"total_clients": len(clients),
"clients": clients,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting connected clients: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get connected clients: {str(e)}"
)
@router.delete("/clients/{client_id}")
async def disconnect_client(
client_id: str,
current_user: Dict = Depends(require_auth)
):
"""Disconnect a specific WebSocket client."""
try:
logger.info(f"Disconnecting client {client_id} by user: {current_user['id']}")
success = await connection_manager.disconnect(client_id)
if not success:
raise HTTPException(
status_code=404,
detail=f"Client {client_id} not found"
)
return {
"message": f"Client {client_id} disconnected successfully",
"timestamp": datetime.utcnow().isoformat()
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error disconnecting client: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to disconnect client: {str(e)}"
)
@router.post("/broadcast")
async def broadcast_message(
message: Dict[str, Any],
stream_type: Optional[str] = Query(None, description="Target stream type"),
zone_ids: Optional[List[str]] = Query(None, description="Target zone IDs"),
current_user: Dict = Depends(require_auth)
):
"""Broadcast a message to connected WebSocket clients."""
try:
logger.info(f"Broadcasting message by user: {current_user['id']}")
# Add metadata to message
broadcast_data = {
**message,
"broadcast_timestamp": datetime.utcnow().isoformat(),
"sender": current_user["id"]
}
# Broadcast to matching clients
sent_count = await connection_manager.broadcast(
data=broadcast_data,
stream_type=stream_type,
zone_ids=zone_ids
)
return {
"message": "Broadcast sent successfully",
"recipients": sent_count,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error broadcasting message: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to broadcast message: {str(e)}"
)
@router.get("/metrics")
async def get_streaming_metrics():
"""Get streaming performance metrics."""
try:
metrics = await connection_manager.get_metrics()
return {
"metrics": metrics,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting streaming metrics: {e}")
raise HTTPException(
status_code=500,
detail=f"Failed to get streaming metrics: {str(e)}"
)

View File

@@ -0,0 +1,8 @@
"""
WebSocket handlers package
"""
from .connection_manager import ConnectionManager
from .pose_stream import PoseStreamHandler
__all__ = ["ConnectionManager", "PoseStreamHandler"]

View File

@@ -0,0 +1,461 @@
"""
WebSocket connection manager for WiFi-DensePose API
"""
import asyncio
import json
import logging
import uuid
from typing import Dict, List, Optional, Any, Set
from datetime import datetime, timedelta
from collections import defaultdict
from fastapi import WebSocket, WebSocketDisconnect
logger = logging.getLogger(__name__)
class WebSocketConnection:
"""Represents a WebSocket connection with metadata."""
def __init__(
self,
websocket: WebSocket,
client_id: str,
stream_type: str,
zone_ids: Optional[List[str]] = None,
**config
):
self.websocket = websocket
self.client_id = client_id
self.stream_type = stream_type
self.zone_ids = zone_ids or []
self.config = config
self.connected_at = datetime.utcnow()
self.last_ping = datetime.utcnow()
self.message_count = 0
self.is_active = True
async def send_json(self, data: Dict[str, Any]):
"""Send JSON data to client."""
try:
await self.websocket.send_json(data)
self.message_count += 1
except Exception as e:
logger.error(f"Error sending to client {self.client_id}: {e}")
self.is_active = False
raise
async def send_text(self, message: str):
"""Send text message to client."""
try:
await self.websocket.send_text(message)
self.message_count += 1
except Exception as e:
logger.error(f"Error sending text to client {self.client_id}: {e}")
self.is_active = False
raise
def update_config(self, config: Dict[str, Any]):
"""Update connection configuration."""
self.config.update(config)
# Update zone IDs if provided
if "zone_ids" in config:
self.zone_ids = config["zone_ids"] or []
def matches_filter(
self,
stream_type: Optional[str] = None,
zone_ids: Optional[List[str]] = None,
**filters
) -> bool:
"""Check if connection matches given filters."""
# Check stream type
if stream_type and self.stream_type != stream_type:
return False
# Check zone IDs
if zone_ids:
if not self.zone_ids: # Connection listens to all zones
return True
# Check if any requested zone is in connection's zones
if not any(zone in self.zone_ids for zone in zone_ids):
return False
# Check additional filters
for key, value in filters.items():
if key in self.config and self.config[key] != value:
return False
return True
def get_info(self) -> Dict[str, Any]:
"""Get connection information."""
return {
"client_id": self.client_id,
"stream_type": self.stream_type,
"zone_ids": self.zone_ids,
"config": self.config,
"connected_at": self.connected_at.isoformat(),
"last_ping": self.last_ping.isoformat(),
"message_count": self.message_count,
"is_active": self.is_active,
"uptime_seconds": (datetime.utcnow() - self.connected_at).total_seconds()
}
class ConnectionManager:
"""Manages WebSocket connections for real-time streaming."""
def __init__(self):
self.connections: Dict[str, WebSocketConnection] = {}
self.connections_by_type: Dict[str, Set[str]] = defaultdict(set)
self.connections_by_zone: Dict[str, Set[str]] = defaultdict(set)
self.metrics = {
"total_connections": 0,
"active_connections": 0,
"messages_sent": 0,
"errors": 0,
"start_time": datetime.utcnow()
}
self._cleanup_task = None
self._started = False
async def connect(
self,
websocket: WebSocket,
stream_type: str,
zone_ids: Optional[List[str]] = None,
**config
) -> str:
"""Register a new WebSocket connection."""
client_id = str(uuid.uuid4())
try:
# Create connection object
connection = WebSocketConnection(
websocket=websocket,
client_id=client_id,
stream_type=stream_type,
zone_ids=zone_ids,
**config
)
# Store connection
self.connections[client_id] = connection
self.connections_by_type[stream_type].add(client_id)
# Index by zones
if zone_ids:
for zone_id in zone_ids:
self.connections_by_zone[zone_id].add(client_id)
# Update metrics
self.metrics["total_connections"] += 1
self.metrics["active_connections"] = len(self.connections)
logger.info(f"WebSocket client {client_id} connected for {stream_type}")
return client_id
except Exception as e:
logger.error(f"Error connecting WebSocket client: {e}")
raise
async def disconnect(self, client_id: str) -> bool:
"""Disconnect a WebSocket client."""
if client_id not in self.connections:
return False
try:
connection = self.connections[client_id]
# Remove from indexes
self.connections_by_type[connection.stream_type].discard(client_id)
for zone_id in connection.zone_ids:
self.connections_by_zone[zone_id].discard(client_id)
# Close WebSocket if still active
if connection.is_active:
try:
await connection.websocket.close()
except:
pass # Connection might already be closed
# Remove connection
del self.connections[client_id]
# Update metrics
self.metrics["active_connections"] = len(self.connections)
logger.info(f"WebSocket client {client_id} disconnected")
return True
except Exception as e:
logger.error(f"Error disconnecting client {client_id}: {e}")
return False
async def disconnect_all(self):
"""Disconnect all WebSocket clients."""
client_ids = list(self.connections.keys())
for client_id in client_ids:
await self.disconnect(client_id)
logger.info("All WebSocket clients disconnected")
async def send_to_client(self, client_id: str, data: Dict[str, Any]) -> bool:
"""Send data to a specific client."""
if client_id not in self.connections:
return False
connection = self.connections[client_id]
try:
await connection.send_json(data)
self.metrics["messages_sent"] += 1
return True
except Exception as e:
logger.error(f"Error sending to client {client_id}: {e}")
self.metrics["errors"] += 1
# Mark connection as inactive and schedule for cleanup
connection.is_active = False
return False
async def broadcast(
self,
data: Dict[str, Any],
stream_type: Optional[str] = None,
zone_ids: Optional[List[str]] = None,
**filters
) -> int:
"""Broadcast data to matching clients."""
sent_count = 0
failed_clients = []
# Get matching connections
matching_clients = self._get_matching_clients(
stream_type=stream_type,
zone_ids=zone_ids,
**filters
)
# Send to all matching clients
for client_id in matching_clients:
try:
success = await self.send_to_client(client_id, data)
if success:
sent_count += 1
else:
failed_clients.append(client_id)
except Exception as e:
logger.error(f"Error broadcasting to client {client_id}: {e}")
failed_clients.append(client_id)
# Clean up failed connections
for client_id in failed_clients:
await self.disconnect(client_id)
return sent_count
async def update_client_config(self, client_id: str, config: Dict[str, Any]) -> bool:
"""Update client configuration."""
if client_id not in self.connections:
return False
connection = self.connections[client_id]
old_zones = set(connection.zone_ids)
# Update configuration
connection.update_config(config)
# Update zone indexes if zones changed
new_zones = set(connection.zone_ids)
# Remove from old zones
for zone_id in old_zones - new_zones:
self.connections_by_zone[zone_id].discard(client_id)
# Add to new zones
for zone_id in new_zones - old_zones:
self.connections_by_zone[zone_id].add(client_id)
return True
async def get_client_status(self, client_id: str) -> Optional[Dict[str, Any]]:
"""Get status of a specific client."""
if client_id not in self.connections:
return None
return self.connections[client_id].get_info()
async def get_connected_clients(self) -> List[Dict[str, Any]]:
"""Get list of all connected clients."""
return [conn.get_info() for conn in self.connections.values()]
async def get_connection_stats(self) -> Dict[str, Any]:
"""Get connection statistics."""
stats = {
"total_clients": len(self.connections),
"clients_by_type": {
stream_type: len(clients)
for stream_type, clients in self.connections_by_type.items()
},
"clients_by_zone": {
zone_id: len(clients)
for zone_id, clients in self.connections_by_zone.items()
if clients # Only include zones with active clients
},
"active_clients": sum(1 for conn in self.connections.values() if conn.is_active),
"inactive_clients": sum(1 for conn in self.connections.values() if not conn.is_active)
}
return stats
async def get_metrics(self) -> Dict[str, Any]:
"""Get detailed metrics."""
uptime = (datetime.utcnow() - self.metrics["start_time"]).total_seconds()
return {
**self.metrics,
"active_connections": len(self.connections),
"uptime_seconds": uptime,
"messages_per_second": self.metrics["messages_sent"] / max(uptime, 1),
"error_rate": self.metrics["errors"] / max(self.metrics["messages_sent"], 1)
}
def _get_matching_clients(
self,
stream_type: Optional[str] = None,
zone_ids: Optional[List[str]] = None,
**filters
) -> List[str]:
"""Get client IDs that match the given filters."""
candidates = set(self.connections.keys())
# Filter by stream type
if stream_type:
type_clients = self.connections_by_type.get(stream_type, set())
candidates &= type_clients
# Filter by zones
if zone_ids:
zone_clients = set()
for zone_id in zone_ids:
zone_clients.update(self.connections_by_zone.get(zone_id, set()))
# Also include clients listening to all zones (empty zone list)
all_zone_clients = {
client_id for client_id, conn in self.connections.items()
if not conn.zone_ids
}
zone_clients.update(all_zone_clients)
candidates &= zone_clients
# Apply additional filters
matching_clients = []
for client_id in candidates:
connection = self.connections[client_id]
if connection.is_active and connection.matches_filter(**filters):
matching_clients.append(client_id)
return matching_clients
async def ping_clients(self):
"""Send ping to all connected clients."""
ping_data = {
"type": "ping",
"timestamp": datetime.utcnow().isoformat()
}
failed_clients = []
for client_id, connection in self.connections.items():
try:
await connection.send_json(ping_data)
connection.last_ping = datetime.utcnow()
except Exception as e:
logger.warning(f"Ping failed for client {client_id}: {e}")
failed_clients.append(client_id)
# Clean up failed connections
for client_id in failed_clients:
await self.disconnect(client_id)
async def cleanup_inactive_connections(self):
"""Clean up inactive or stale connections."""
now = datetime.utcnow()
stale_threshold = timedelta(minutes=5) # 5 minutes without ping
stale_clients = []
for client_id, connection in self.connections.items():
# Check if connection is inactive
if not connection.is_active:
stale_clients.append(client_id)
continue
# Check if connection is stale (no ping response)
if now - connection.last_ping > stale_threshold:
logger.warning(f"Client {client_id} appears stale, disconnecting")
stale_clients.append(client_id)
# Clean up stale connections
for client_id in stale_clients:
await self.disconnect(client_id)
if stale_clients:
logger.info(f"Cleaned up {len(stale_clients)} stale connections")
async def start(self):
"""Start the connection manager."""
if not self._started:
self._start_cleanup_task()
self._started = True
logger.info("Connection manager started")
def _start_cleanup_task(self):
"""Start background cleanup task."""
async def cleanup_loop():
while True:
try:
await asyncio.sleep(60) # Run every minute
await self.cleanup_inactive_connections()
# Send periodic ping every 2 minutes
if datetime.utcnow().minute % 2 == 0:
await self.ping_clients()
except Exception as e:
logger.error(f"Error in cleanup task: {e}")
try:
self._cleanup_task = asyncio.create_task(cleanup_loop())
except RuntimeError:
# No event loop running, will start later
logger.debug("No event loop running, cleanup task will start later")
async def shutdown(self):
"""Shutdown connection manager."""
# Cancel cleanup task
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
# Disconnect all clients
await self.disconnect_all()
logger.info("Connection manager shutdown complete")
# Global connection manager instance
connection_manager = ConnectionManager()

View File

@@ -0,0 +1,384 @@
"""
Pose streaming WebSocket handler
"""
import asyncio
import json
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from fastapi import WebSocket
from pydantic import BaseModel, Field
from src.api.websocket.connection_manager import ConnectionManager
from src.services.pose_service import PoseService
from src.services.stream_service import StreamService
logger = logging.getLogger(__name__)
class PoseStreamData(BaseModel):
"""Pose stream data model."""
timestamp: datetime = Field(..., description="Data timestamp")
zone_id: str = Field(..., description="Zone identifier")
pose_data: Dict[str, Any] = Field(..., description="Pose estimation data")
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
activity: Optional[str] = Field(default=None, description="Detected activity")
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional metadata")
class PoseStreamHandler:
"""Handles pose data streaming to WebSocket clients."""
def __init__(
self,
connection_manager: ConnectionManager,
pose_service: PoseService,
stream_service: StreamService
):
self.connection_manager = connection_manager
self.pose_service = pose_service
self.stream_service = stream_service
self.is_streaming = False
self.stream_task = None
self.subscribers = {}
self.stream_config = {
"fps": 30,
"min_confidence": 0.5,
"include_metadata": True,
"buffer_size": 100
}
async def start_streaming(self):
"""Start pose data streaming."""
if self.is_streaming:
logger.warning("Pose streaming already active")
return
self.is_streaming = True
self.stream_task = asyncio.create_task(self._stream_loop())
logger.info("Pose streaming started")
async def stop_streaming(self):
"""Stop pose data streaming."""
if not self.is_streaming:
return
self.is_streaming = False
if self.stream_task:
self.stream_task.cancel()
try:
await self.stream_task
except asyncio.CancelledError:
pass
logger.info("Pose streaming stopped")
async def _stream_loop(self):
"""Main streaming loop."""
try:
logger.info("🚀 Starting pose streaming loop")
while self.is_streaming:
try:
# Get current pose data from all zones
logger.debug("📡 Getting current pose data...")
pose_data = await self.pose_service.get_current_pose_data()
logger.debug(f"📊 Received pose data: {pose_data}")
if pose_data:
logger.debug("📤 Broadcasting pose data...")
await self._process_and_broadcast_pose_data(pose_data)
else:
logger.debug("⚠️ No pose data received")
# Control streaming rate
await asyncio.sleep(1.0 / self.stream_config["fps"])
except Exception as e:
logger.error(f"Error in pose streaming loop: {e}")
await asyncio.sleep(1.0) # Brief pause on error
except asyncio.CancelledError:
logger.info("Pose streaming loop cancelled")
except Exception as e:
logger.error(f"Fatal error in pose streaming loop: {e}")
finally:
logger.info("🛑 Pose streaming loop stopped")
self.is_streaming = False
async def _process_and_broadcast_pose_data(self, raw_pose_data: Dict[str, Any]):
"""Process and broadcast pose data to subscribers."""
try:
# Process data for each zone
for zone_id, zone_data in raw_pose_data.items():
if not zone_data:
continue
# Create structured pose data
pose_stream_data = PoseStreamData(
timestamp=datetime.utcnow(),
zone_id=zone_id,
pose_data=zone_data.get("pose", {}),
confidence=zone_data.get("confidence", 0.0),
activity=zone_data.get("activity"),
metadata=zone_data.get("metadata") if self.stream_config["include_metadata"] else None
)
# Filter by minimum confidence
if pose_stream_data.confidence < self.stream_config["min_confidence"]:
continue
# Broadcast to subscribers
await self._broadcast_pose_data(pose_stream_data)
except Exception as e:
logger.error(f"Error processing pose data: {e}")
async def _broadcast_pose_data(self, pose_data: PoseStreamData):
"""Broadcast pose data to matching WebSocket clients."""
try:
logger.debug(f"📡 Preparing to broadcast pose data for zone {pose_data.zone_id}")
# Prepare broadcast data
broadcast_data = {
"type": "pose_data",
"timestamp": pose_data.timestamp.isoformat(),
"zone_id": pose_data.zone_id,
"data": {
"pose": pose_data.pose_data,
"confidence": pose_data.confidence,
"activity": pose_data.activity
}
}
# Add metadata if enabled
if pose_data.metadata and self.stream_config["include_metadata"]:
broadcast_data["metadata"] = pose_data.metadata
logger.debug(f"📤 Broadcasting data: {broadcast_data}")
# Broadcast to pose stream subscribers
sent_count = await self.connection_manager.broadcast(
data=broadcast_data,
stream_type="pose",
zone_ids=[pose_data.zone_id]
)
logger.info(f"✅ Broadcasted pose data for zone {pose_data.zone_id} to {sent_count} clients")
except Exception as e:
logger.error(f"Error broadcasting pose data: {e}")
async def handle_client_subscription(
self,
client_id: str,
subscription_config: Dict[str, Any]
):
"""Handle client subscription configuration."""
try:
# Store client subscription config
self.subscribers[client_id] = {
"zone_ids": subscription_config.get("zone_ids", []),
"min_confidence": subscription_config.get("min_confidence", 0.5),
"max_fps": subscription_config.get("max_fps", 30),
"include_metadata": subscription_config.get("include_metadata", True),
"stream_types": subscription_config.get("stream_types", ["pose_data"]),
"subscribed_at": datetime.utcnow()
}
logger.info(f"Updated subscription for client {client_id}")
# Send confirmation
confirmation = {
"type": "subscription_updated",
"client_id": client_id,
"config": self.subscribers[client_id],
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, confirmation)
except Exception as e:
logger.error(f"Error handling client subscription: {e}")
async def handle_client_disconnect(self, client_id: str):
"""Handle client disconnection."""
if client_id in self.subscribers:
del self.subscribers[client_id]
logger.info(f"Removed subscription for disconnected client {client_id}")
async def send_historical_data(
self,
client_id: str,
zone_id: str,
start_time: datetime,
end_time: datetime,
limit: int = 100
):
"""Send historical pose data to client."""
try:
# Get historical data from pose service
historical_data = await self.pose_service.get_historical_data(
zone_id=zone_id,
start_time=start_time,
end_time=end_time,
limit=limit
)
# Send data in chunks to avoid overwhelming the client
chunk_size = 10
for i in range(0, len(historical_data), chunk_size):
chunk = historical_data[i:i + chunk_size]
message = {
"type": "historical_data",
"zone_id": zone_id,
"chunk_index": i // chunk_size,
"total_chunks": (len(historical_data) + chunk_size - 1) // chunk_size,
"data": chunk,
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, message)
# Small delay between chunks
await asyncio.sleep(0.1)
# Send completion message
completion_message = {
"type": "historical_data_complete",
"zone_id": zone_id,
"total_records": len(historical_data),
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, completion_message)
except Exception as e:
logger.error(f"Error sending historical data: {e}")
# Send error message to client
error_message = {
"type": "error",
"message": f"Failed to retrieve historical data: {str(e)}",
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, error_message)
async def send_zone_statistics(self, client_id: str, zone_id: str):
"""Send zone statistics to client."""
try:
# Get zone statistics
stats = await self.pose_service.get_zone_statistics(zone_id)
message = {
"type": "zone_statistics",
"zone_id": zone_id,
"statistics": stats,
"timestamp": datetime.utcnow().isoformat()
}
await self.connection_manager.send_to_client(client_id, message)
except Exception as e:
logger.error(f"Error sending zone statistics: {e}")
async def broadcast_system_event(self, event_type: str, event_data: Dict[str, Any]):
"""Broadcast system events to all connected clients."""
try:
message = {
"type": "system_event",
"event_type": event_type,
"data": event_data,
"timestamp": datetime.utcnow().isoformat()
}
# Broadcast to all pose stream clients
sent_count = await self.connection_manager.broadcast(
data=message,
stream_type="pose"
)
logger.info(f"Broadcasted system event '{event_type}' to {sent_count} clients")
except Exception as e:
logger.error(f"Error broadcasting system event: {e}")
async def update_stream_config(self, config: Dict[str, Any]):
"""Update streaming configuration."""
try:
# Validate and update configuration
if "fps" in config:
fps = max(1, min(60, config["fps"]))
self.stream_config["fps"] = fps
if "min_confidence" in config:
confidence = max(0.0, min(1.0, config["min_confidence"]))
self.stream_config["min_confidence"] = confidence
if "include_metadata" in config:
self.stream_config["include_metadata"] = bool(config["include_metadata"])
if "buffer_size" in config:
buffer_size = max(10, min(1000, config["buffer_size"]))
self.stream_config["buffer_size"] = buffer_size
logger.info(f"Updated stream configuration: {self.stream_config}")
# Broadcast configuration update to clients
await self.broadcast_system_event("stream_config_updated", {
"new_config": self.stream_config
})
except Exception as e:
logger.error(f"Error updating stream configuration: {e}")
def get_stream_status(self) -> Dict[str, Any]:
"""Get current streaming status."""
return {
"is_streaming": self.is_streaming,
"config": self.stream_config,
"subscriber_count": len(self.subscribers),
"subscribers": {
client_id: {
"zone_ids": sub["zone_ids"],
"min_confidence": sub["min_confidence"],
"subscribed_at": sub["subscribed_at"].isoformat()
}
for client_id, sub in self.subscribers.items()
}
}
async def get_performance_metrics(self) -> Dict[str, Any]:
"""Get streaming performance metrics."""
try:
# Get connection manager metrics
conn_metrics = await self.connection_manager.get_metrics()
# Get pose service metrics
pose_metrics = await self.pose_service.get_performance_metrics()
return {
"streaming": {
"is_active": self.is_streaming,
"fps": self.stream_config["fps"],
"subscriber_count": len(self.subscribers)
},
"connections": conn_metrics,
"pose_service": pose_metrics,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting performance metrics: {e}")
return {}
async def shutdown(self):
"""Shutdown pose stream handler."""
await self.stop_streaming()
self.subscribers.clear()
logger.info("Pose stream handler shutdown complete")