feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
421
v1/src/api/main.py
Normal file
421
v1/src/api/main.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
FastAPI application for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import logging.config
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.config.domains import get_domain_config
|
||||
from src.api.routers import pose, stream, health
|
||||
from src.api.middleware.auth import AuthMiddleware
|
||||
from src.api.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.api.dependencies import get_pose_service, get_stream_service, get_hardware_service
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.api.websocket.pose_stream import PoseStreamHandler
|
||||
|
||||
# Configure logging
|
||||
settings = get_settings()
|
||||
logging.config.dictConfig(settings.get_logging_config())
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting WiFi-DensePose API...")
|
||||
|
||||
try:
|
||||
# Initialize services
|
||||
await initialize_services(app)
|
||||
|
||||
# Start background tasks
|
||||
await start_background_tasks(app)
|
||||
|
||||
logger.info("WiFi-DensePose API started successfully")
|
||||
|
||||
yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start application: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Cleanup on shutdown
|
||||
logger.info("Shutting down WiFi-DensePose API...")
|
||||
await cleanup_services(app)
|
||||
logger.info("WiFi-DensePose API shutdown complete")
|
||||
|
||||
|
||||
async def initialize_services(app: FastAPI):
|
||||
"""Initialize application services."""
|
||||
try:
|
||||
# Initialize hardware service
|
||||
hardware_service = get_hardware_service()
|
||||
await hardware_service.initialize()
|
||||
|
||||
# Initialize pose service
|
||||
pose_service = get_pose_service()
|
||||
await pose_service.initialize()
|
||||
|
||||
# Initialize stream service
|
||||
stream_service = get_stream_service()
|
||||
await stream_service.initialize()
|
||||
|
||||
# Initialize pose stream handler
|
||||
pose_stream_handler = PoseStreamHandler(
|
||||
connection_manager=connection_manager,
|
||||
pose_service=pose_service,
|
||||
stream_service=stream_service
|
||||
)
|
||||
|
||||
# Store in app state for access in routes
|
||||
app.state.hardware_service = hardware_service
|
||||
app.state.pose_service = pose_service
|
||||
app.state.stream_service = stream_service
|
||||
app.state.pose_stream_handler = pose_stream_handler
|
||||
|
||||
logger.info("Services initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize services: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def start_background_tasks(app: FastAPI):
|
||||
"""Start background tasks."""
|
||||
try:
|
||||
# Start pose service
|
||||
pose_service = app.state.pose_service
|
||||
await pose_service.start()
|
||||
logger.info("Pose service started")
|
||||
|
||||
# Start pose streaming if enabled
|
||||
if settings.enable_real_time_processing:
|
||||
pose_stream_handler = app.state.pose_stream_handler
|
||||
await pose_stream_handler.start_streaming()
|
||||
|
||||
logger.info("Background tasks started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start background tasks: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def cleanup_services(app: FastAPI):
|
||||
"""Cleanup services on shutdown."""
|
||||
try:
|
||||
# Stop pose streaming
|
||||
if hasattr(app.state, 'pose_stream_handler'):
|
||||
await app.state.pose_stream_handler.shutdown()
|
||||
|
||||
# Shutdown connection manager
|
||||
await connection_manager.shutdown()
|
||||
|
||||
# Cleanup services
|
||||
if hasattr(app.state, 'stream_service'):
|
||||
await app.state.stream_service.shutdown()
|
||||
|
||||
if hasattr(app.state, 'pose_service'):
|
||||
await app.state.pose_service.stop()
|
||||
|
||||
if hasattr(app.state, 'hardware_service'):
|
||||
await app.state.hardware_service.shutdown()
|
||||
|
||||
logger.info("Services cleaned up successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.version,
|
||||
description="WiFi-based human pose estimation and activity recognition API",
|
||||
docs_url=settings.docs_url if not settings.is_production else None,
|
||||
redoc_url=settings.redoc_url if not settings.is_production else None,
|
||||
openapi_url=settings.openapi_url if not settings.is_production else None,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add middleware
|
||||
if settings.enable_rate_limiting:
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
if settings.enable_authentication:
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
# Add CORS middleware
|
||||
cors_config = settings.get_cors_config()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
**cors_config
|
||||
)
|
||||
|
||||
# Add trusted host middleware for production
|
||||
if settings.is_production:
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware,
|
||||
allowed_hosts=settings.allowed_hosts
|
||||
)
|
||||
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
"""Handle HTTP exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": exc.status_code,
|
||||
"message": exc.detail,
|
||||
"type": "http_error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""Handle request validation errors."""
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"error": {
|
||||
"code": 422,
|
||||
"message": "Validation error",
|
||||
"type": "validation_error",
|
||||
"details": exc.errors()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""Handle general exceptions."""
|
||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": {
|
||||
"code": 500,
|
||||
"message": "Internal server error",
|
||||
"type": "internal_error"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Middleware for request logging
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
"""Log all requests."""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate processing time
|
||||
process_time = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Status: {response.status_code} - "
|
||||
f"Time: {process_time:.3f}s"
|
||||
)
|
||||
|
||||
# Add processing time header
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(
|
||||
health.router,
|
||||
prefix="/health",
|
||||
tags=["Health"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
pose.router,
|
||||
prefix=f"{settings.api_prefix}/pose",
|
||||
tags=["Pose Estimation"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
stream.router,
|
||||
prefix=f"{settings.api_prefix}/stream",
|
||||
tags=["Streaming"]
|
||||
)
|
||||
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"docs_url": settings.docs_url,
|
||||
"api_prefix": settings.api_prefix,
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"real_time_processing": settings.enable_real_time_processing
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# API information endpoint
|
||||
@app.get(f"{settings.api_prefix}/info")
|
||||
async def api_info():
|
||||
"""Get detailed API information."""
|
||||
domain_config = get_domain_config()
|
||||
|
||||
return {
|
||||
"api": {
|
||||
"name": settings.app_name,
|
||||
"version": settings.version,
|
||||
"environment": settings.environment,
|
||||
"prefix": settings.api_prefix
|
||||
},
|
||||
"configuration": {
|
||||
"zones": len(domain_config.zones),
|
||||
"routers": len(domain_config.routers),
|
||||
"pose_models": len(domain_config.pose_models)
|
||||
},
|
||||
"features": {
|
||||
"authentication": settings.enable_authentication,
|
||||
"rate_limiting": settings.enable_rate_limiting,
|
||||
"websockets": settings.enable_websockets,
|
||||
"real_time_processing": settings.enable_real_time_processing,
|
||||
"historical_data": settings.enable_historical_data
|
||||
},
|
||||
"limits": {
|
||||
"rate_limit_requests": settings.rate_limit_requests,
|
||||
"rate_limit_window": settings.rate_limit_window,
|
||||
"max_websocket_connections": domain_config.streaming.max_connections
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Status endpoint
|
||||
@app.get(f"{settings.api_prefix}/status")
|
||||
async def api_status(request: Request):
|
||||
"""Get current API status."""
|
||||
try:
|
||||
# Get services from app state
|
||||
hardware_service = getattr(request.app.state, 'hardware_service', None)
|
||||
pose_service = getattr(request.app.state, 'pose_service', None)
|
||||
stream_service = getattr(request.app.state, 'stream_service', None)
|
||||
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
|
||||
|
||||
# Get service statuses
|
||||
status = {
|
||||
"api": {
|
||||
"status": "healthy",
|
||||
"uptime": "unknown",
|
||||
"version": settings.version
|
||||
},
|
||||
"services": {
|
||||
"hardware": await hardware_service.get_status() if hardware_service else {"status": "unavailable"},
|
||||
"pose": await pose_service.get_status() if pose_service else {"status": "unavailable"},
|
||||
"stream": await stream_service.get_status() if stream_service else {"status": "unavailable"}
|
||||
},
|
||||
"streaming": pose_stream_handler.get_stream_status() if pose_stream_handler else {"is_streaming": False},
|
||||
"connections": await connection_manager.get_connection_stats()
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API status: {e}")
|
||||
return {
|
||||
"api": {
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Metrics endpoint (if enabled)
|
||||
if settings.metrics_enabled:
|
||||
@app.get(f"{settings.api_prefix}/metrics")
|
||||
async def api_metrics(request: Request):
|
||||
"""Get API metrics."""
|
||||
try:
|
||||
# Get services from app state
|
||||
pose_stream_handler = getattr(request.app.state, 'pose_stream_handler', None)
|
||||
|
||||
metrics = {
|
||||
"connections": await connection_manager.get_metrics(),
|
||||
"streaming": await pose_stream_handler.get_performance_metrics() if pose_stream_handler else {}
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting metrics: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
# Development endpoints (only in development)
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
domain_config = get_domain_config()
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"domain_config": domain_config.to_dict()
|
||||
}
|
||||
|
||||
@app.post(f"{settings.api_prefix}/dev/reset")
|
||||
async def dev_reset(request: Request):
|
||||
"""Reset services (development only)."""
|
||||
try:
|
||||
# Reset services
|
||||
hardware_service = getattr(request.app.state, 'hardware_service', None)
|
||||
pose_service = getattr(request.app.state, 'pose_service', None)
|
||||
|
||||
if hardware_service:
|
||||
await hardware_service.reset()
|
||||
|
||||
if pose_service:
|
||||
await pose_service.reset()
|
||||
|
||||
return {"message": "Services reset successfully"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting services: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"src.api.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.reload,
|
||||
workers=settings.workers if not settings.reload else 1,
|
||||
log_level=settings.log_level.lower()
|
||||
)
|
||||
Reference in New Issue
Block a user