feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

467
v1/src/middleware/auth.py Normal file
View File

@@ -0,0 +1,467 @@
"""
Authentication middleware for WiFi-DensePose API
"""
import logging
import time
from typing import Optional, Dict, Any, Callable
from datetime import datetime, timedelta
from fastapi import Request, Response, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from passlib.context import CryptContext
from src.config.settings import Settings
from src.logger import set_request_context
logger = logging.getLogger(__name__)
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT token handler
security = HTTPBearer(auto_error=False)
class AuthenticationError(Exception):
"""Authentication error."""
pass
class AuthorizationError(Exception):
"""Authorization error."""
pass
class TokenManager:
"""JWT token management."""
def __init__(self, settings: Settings):
self.settings = settings
self.secret_key = settings.secret_key
self.algorithm = settings.jwt_algorithm
self.expire_hours = settings.jwt_expire_hours
def create_access_token(self, data: Dict[str, Any]) -> str:
"""Create JWT access token."""
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
return encoded_jwt
def verify_token(self, token: str) -> Dict[str, Any]:
"""Verify and decode JWT token."""
try:
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
return payload
except JWTError as e:
logger.warning(f"JWT verification failed: {e}")
raise AuthenticationError("Invalid token")
def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Decode token without verification (for debugging)."""
try:
return jwt.decode(token, options={"verify_signature": False})
except JWTError:
return None
class UserManager:
"""User management for authentication."""
def __init__(self):
# In a real application, this would connect to a database
# For now, we'll use a simple in-memory store
self._users: Dict[str, Dict[str, Any]] = {
"admin": {
"username": "admin",
"email": "admin@example.com",
"hashed_password": self.hash_password("admin123"),
"roles": ["admin"],
"is_active": True,
"created_at": datetime.utcnow(),
},
"user": {
"username": "user",
"email": "user@example.com",
"hashed_password": self.hash_password("user123"),
"roles": ["user"],
"is_active": True,
"created_at": datetime.utcnow(),
}
}
@staticmethod
def hash_password(password: str) -> str:
"""Hash a password."""
return pwd_context.hash(password)
@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return pwd_context.verify(plain_password, hashed_password)
def get_user(self, username: str) -> Optional[Dict[str, Any]]:
"""Get user by username."""
return self._users.get(username)
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
"""Authenticate user with username and password."""
user = self.get_user(username)
if not user:
return None
if not self.verify_password(password, user["hashed_password"]):
return None
if not user.get("is_active", False):
return None
return user
def create_user(self, username: str, email: str, password: str, roles: list = None) -> Dict[str, Any]:
"""Create a new user."""
if username in self._users:
raise ValueError("User already exists")
user = {
"username": username,
"email": email,
"hashed_password": self.hash_password(password),
"roles": roles or ["user"],
"is_active": True,
"created_at": datetime.utcnow(),
}
self._users[username] = user
return user
def update_user(self, username: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Update user information."""
user = self._users.get(username)
if not user:
return None
# Don't allow updating certain fields
protected_fields = {"username", "created_at", "hashed_password"}
updates = {k: v for k, v in updates.items() if k not in protected_fields}
user.update(updates)
return user
def deactivate_user(self, username: str) -> bool:
"""Deactivate a user."""
user = self._users.get(username)
if user:
user["is_active"] = False
return True
return False
class AuthenticationMiddleware:
"""Authentication middleware for FastAPI."""
def __init__(self, settings: Settings):
self.settings = settings
self.token_manager = TokenManager(settings)
self.user_manager = UserManager()
self.enabled = settings.enable_authentication
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""Process request through authentication middleware."""
start_time = time.time()
try:
# Skip authentication for certain paths
if self._should_skip_auth(request):
response = await call_next(request)
return response
# Skip if authentication is disabled
if not self.enabled:
response = await call_next(request)
return response
# Extract and verify token
user_info = await self._authenticate_request(request)
# Set user context
if user_info:
request.state.user = user_info
set_request_context(user_id=user_info.get("username"))
# Process request
response = await call_next(request)
# Add authentication headers
self._add_auth_headers(response, user_info)
return response
except AuthenticationError as e:
logger.warning(f"Authentication failed: {e}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
headers={"WWW-Authenticate": "Bearer"},
)
except AuthorizationError as e:
logger.warning(f"Authorization failed: {e}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except Exception as e:
logger.error(f"Authentication middleware error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Authentication service error",
)
finally:
# Log request processing time
processing_time = time.time() - start_time
logger.debug(f"Auth middleware processing time: {processing_time:.3f}s")
def _should_skip_auth(self, request: Request) -> bool:
"""Check if authentication should be skipped for this request."""
path = request.url.path
# Skip authentication for these paths
skip_paths = [
"/health",
"/metrics",
"/docs",
"/redoc",
"/openapi.json",
"/auth/login",
"/auth/register",
"/static",
]
return any(path.startswith(skip_path) for skip_path in skip_paths)
async def _authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
"""Authenticate the request and return user info."""
# Try to get token from Authorization header
authorization = request.headers.get("Authorization")
if not authorization:
# For WebSocket connections, try to get token from query parameters
if request.url.path.startswith("/ws"):
token = request.query_params.get("token")
if token:
authorization = f"Bearer {token}"
if not authorization:
if self._requires_auth(request):
raise AuthenticationError("Missing authorization header")
return None
# Extract token
try:
scheme, token = authorization.split()
if scheme.lower() != "bearer":
raise AuthenticationError("Invalid authentication scheme")
except ValueError:
raise AuthenticationError("Invalid authorization header format")
# Verify token
try:
payload = self.token_manager.verify_token(token)
username = payload.get("sub")
if not username:
raise AuthenticationError("Invalid token payload")
# Get user info
user = self.user_manager.get_user(username)
if not user:
raise AuthenticationError("User not found")
if not user.get("is_active", False):
raise AuthenticationError("User account is disabled")
# Return user info without sensitive data
return {
"username": user["username"],
"email": user["email"],
"roles": user["roles"],
"is_active": user["is_active"],
}
except AuthenticationError:
raise
except Exception as e:
logger.error(f"Token verification error: {e}")
raise AuthenticationError("Token verification failed")
def _requires_auth(self, request: Request) -> bool:
"""Check if the request requires authentication."""
# All API endpoints require authentication by default
path = request.url.path
return path.startswith("/api/") or path.startswith("/ws/")
def _add_auth_headers(self, response: Response, user_info: Optional[Dict[str, Any]]):
"""Add authentication-related headers to response."""
if user_info:
response.headers["X-User"] = user_info["username"]
response.headers["X-User-Roles"] = ",".join(user_info["roles"])
async def login(self, username: str, password: str) -> Dict[str, Any]:
"""Authenticate user and return token."""
user = self.user_manager.authenticate_user(username, password)
if not user:
raise AuthenticationError("Invalid username or password")
# Create token
token_data = {
"sub": user["username"],
"email": user["email"],
"roles": user["roles"],
}
access_token = self.token_manager.create_access_token(token_data)
return {
"access_token": access_token,
"token_type": "bearer",
"expires_in": self.settings.jwt_expire_hours * 3600,
"user": {
"username": user["username"],
"email": user["email"],
"roles": user["roles"],
}
}
async def register(self, username: str, email: str, password: str) -> Dict[str, Any]:
"""Register a new user."""
try:
user = self.user_manager.create_user(username, email, password)
# Create token for new user
token_data = {
"sub": user["username"],
"email": user["email"],
"roles": user["roles"],
}
access_token = self.token_manager.create_access_token(token_data)
return {
"access_token": access_token,
"token_type": "bearer",
"expires_in": self.settings.jwt_expire_hours * 3600,
"user": {
"username": user["username"],
"email": user["email"],
"roles": user["roles"],
}
}
except ValueError as e:
raise AuthenticationError(str(e))
async def refresh_token(self, token: str) -> Dict[str, Any]:
"""Refresh an access token."""
try:
payload = self.token_manager.verify_token(token)
username = payload.get("sub")
user = self.user_manager.get_user(username)
if not user or not user.get("is_active", False):
raise AuthenticationError("User not found or inactive")
# Create new token
token_data = {
"sub": user["username"],
"email": user["email"],
"roles": user["roles"],
}
new_token = self.token_manager.create_access_token(token_data)
return {
"access_token": new_token,
"token_type": "bearer",
"expires_in": self.settings.jwt_expire_hours * 3600,
}
except Exception as e:
raise AuthenticationError("Token refresh failed")
def check_permission(self, user_info: Dict[str, Any], required_role: str) -> bool:
"""Check if user has required role/permission."""
user_roles = user_info.get("roles", [])
# Admin role has all permissions
if "admin" in user_roles:
return True
# Check specific role
return required_role in user_roles
def require_role(self, required_role: str):
"""Decorator to require specific role."""
def decorator(func):
import functools
@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs):
user_info = getattr(request.state, "user", None)
if not user_info:
raise AuthorizationError("Authentication required")
if not self.check_permission(user_info, required_role):
raise AuthorizationError(f"Role '{required_role}' required")
return await func(request, *args, **kwargs)
return wrapper
return decorator
# Global authentication middleware instance
_auth_middleware: Optional[AuthenticationMiddleware] = None
def get_auth_middleware(settings: Settings) -> AuthenticationMiddleware:
"""Get authentication middleware instance."""
global _auth_middleware
if _auth_middleware is None:
_auth_middleware = AuthenticationMiddleware(settings)
return _auth_middleware
def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
"""Get current authenticated user from request."""
return getattr(request.state, "user", None)
def require_authentication(request: Request) -> Dict[str, Any]:
"""Require authentication and return user info."""
user = get_current_user(request)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
headers={"WWW-Authenticate": "Bearer"},
)
return user
def require_role(role: str):
"""Dependency to require specific role."""
def dependency(request: Request) -> Dict[str, Any]:
user = require_authentication(request)
auth_middleware = get_auth_middleware(request.app.state.settings)
if not auth_middleware.check_permission(user, role):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Role '{role}' required",
)
return user
return dependency

375
v1/src/middleware/cors.py Normal file
View File

@@ -0,0 +1,375 @@
"""
CORS middleware for WiFi-DensePose API
"""
import logging
from typing import List, Optional, Union, Callable
from urllib.parse import urlparse
from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware
from starlette.types import ASGIApp
from src.config.settings import Settings
logger = logging.getLogger(__name__)
class CORSMiddleware:
"""Enhanced CORS middleware with additional security features."""
def __init__(
self,
app: ASGIApp,
settings: Settings,
allow_origins: Optional[List[str]] = None,
allow_methods: Optional[List[str]] = None,
allow_headers: Optional[List[str]] = None,
allow_credentials: bool = False,
expose_headers: Optional[List[str]] = None,
max_age: int = 600,
):
self.app = app
self.settings = settings
self.allow_origins = allow_origins or settings.cors_origins
self.allow_methods = allow_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]
self.allow_headers = allow_headers or [
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-Request-ID",
"X-User-Agent",
]
self.allow_credentials = allow_credentials or settings.cors_allow_credentials
self.expose_headers = expose_headers or [
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
]
self.max_age = max_age
# Security settings
self.strict_origin_check = settings.is_production
self.log_cors_violations = True
async def __call__(self, scope, receive, send):
"""ASGI middleware implementation."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope, receive)
# Check if this is a CORS preflight request
if request.method == "OPTIONS" and "access-control-request-method" in request.headers:
response = await self._handle_preflight(request)
await response(scope, receive, send)
return
# Handle actual request
async def send_wrapper(message):
if message["type"] == "http.response.start":
# Add CORS headers to response
headers = dict(message.get("headers", []))
cors_headers = self._get_cors_headers(request)
for key, value in cors_headers.items():
headers[key.encode()] = value.encode()
message["headers"] = list(headers.items())
await send(message)
await self.app(scope, receive, send_wrapper)
async def _handle_preflight(self, request: Request) -> Response:
"""Handle CORS preflight request."""
origin = request.headers.get("origin")
requested_method = request.headers.get("access-control-request-method")
requested_headers = request.headers.get("access-control-request-headers", "")
# Validate origin
if not self._is_origin_allowed(origin):
if self.log_cors_violations:
logger.warning(f"CORS preflight rejected for origin: {origin}")
return Response(
status_code=403,
content="CORS preflight request rejected",
headers={"Content-Type": "text/plain"}
)
# Validate method
if requested_method not in self.allow_methods:
if self.log_cors_violations:
logger.warning(f"CORS preflight rejected for method: {requested_method}")
return Response(
status_code=405,
content="Method not allowed",
headers={"Content-Type": "text/plain"}
)
# Validate headers
if requested_headers:
requested_header_list = [h.strip().lower() for h in requested_headers.split(",")]
allowed_headers_lower = [h.lower() for h in self.allow_headers]
for header in requested_header_list:
if header not in allowed_headers_lower:
if self.log_cors_violations:
logger.warning(f"CORS preflight rejected for header: {header}")
return Response(
status_code=400,
content="Header not allowed",
headers={"Content-Type": "text/plain"}
)
# Build preflight response headers
headers = {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Methods": ", ".join(self.allow_methods),
"Access-Control-Allow-Headers": ", ".join(self.allow_headers),
"Access-Control-Max-Age": str(self.max_age),
}
if self.allow_credentials:
headers["Access-Control-Allow-Credentials"] = "true"
if self.expose_headers:
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
logger.debug(f"CORS preflight approved for origin: {origin}")
return Response(
status_code=200,
headers=headers
)
def _get_cors_headers(self, request: Request) -> dict:
"""Get CORS headers for actual request."""
origin = request.headers.get("origin")
headers = {}
if self._is_origin_allowed(origin):
headers["Access-Control-Allow-Origin"] = origin
if self.allow_credentials:
headers["Access-Control-Allow-Credentials"] = "true"
if self.expose_headers:
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
return headers
def _is_origin_allowed(self, origin: Optional[str]) -> bool:
"""Check if origin is allowed."""
if not origin:
return not self.strict_origin_check
# Allow all origins in development
if not self.settings.is_production and "*" in self.allow_origins:
return True
# Check exact matches
if origin in self.allow_origins:
return True
# Check wildcard patterns
for allowed_origin in self.allow_origins:
if allowed_origin == "*":
return not self.strict_origin_check
if self._match_origin_pattern(origin, allowed_origin):
return True
return False
def _match_origin_pattern(self, origin: str, pattern: str) -> bool:
"""Match origin against pattern with wildcard support."""
if "*" not in pattern:
return origin == pattern
# Simple wildcard matching
if pattern.startswith("*."):
domain = pattern[2:]
parsed_origin = urlparse(origin)
origin_host = parsed_origin.netloc
# Check if origin ends with the domain
return origin_host.endswith(domain) or origin_host == domain[1:] if domain.startswith('.') else origin_host == domain
return False
def setup_cors_middleware(app: ASGIApp, settings: Settings) -> ASGIApp:
"""Setup CORS middleware for the application."""
if settings.cors_enabled:
logger.info("Setting up CORS middleware")
# Use FastAPI's built-in CORS middleware for basic functionality
app = FastAPICORSMiddleware(
app,
allow_origins=settings.cors_origins,
allow_credentials=settings.cors_allow_credentials,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
allow_headers=[
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-Request-ID",
"X-User-Agent",
],
expose_headers=[
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
max_age=600,
)
logger.info(f"CORS enabled for origins: {settings.cors_origins}")
else:
logger.info("CORS middleware disabled")
return app
class CORSConfig:
"""CORS configuration helper."""
@staticmethod
def development_config() -> dict:
"""Get CORS configuration for development."""
return {
"allow_origins": ["*"],
"allow_credentials": True,
"allow_methods": ["*"],
"allow_headers": ["*"],
"expose_headers": [
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
"max_age": 600,
}
@staticmethod
def production_config(allowed_origins: List[str]) -> dict:
"""Get CORS configuration for production."""
return {
"allow_origins": allowed_origins,
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
"allow_headers": [
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-Request-ID",
"X-User-Agent",
],
"expose_headers": [
"X-Request-ID",
"X-Response-Time",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
"max_age": 3600, # 1 hour for production
}
@staticmethod
def api_only_config(allowed_origins: List[str]) -> dict:
"""Get CORS configuration for API-only access."""
return {
"allow_origins": allowed_origins,
"allow_credentials": False,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
"allow_headers": [
"Accept",
"Content-Type",
"Authorization",
"X-Request-ID",
],
"expose_headers": [
"X-Request-ID",
"X-Rate-Limit-Remaining",
"X-Rate-Limit-Reset",
],
"max_age": 3600,
}
@staticmethod
def websocket_config(allowed_origins: List[str]) -> dict:
"""Get CORS configuration for WebSocket connections."""
return {
"allow_origins": allowed_origins,
"allow_credentials": True,
"allow_methods": ["GET", "OPTIONS"],
"allow_headers": [
"Accept",
"Authorization",
"Sec-WebSocket-Protocol",
"Sec-WebSocket-Extensions",
],
"expose_headers": [],
"max_age": 86400, # 24 hours for WebSocket
}
def validate_cors_config(settings: Settings) -> List[str]:
"""Validate CORS configuration and return issues."""
issues = []
if not settings.cors_enabled:
return issues
# Check origins
if not settings.cors_origins:
issues.append("CORS is enabled but no origins are configured")
# Check for wildcard in production
if settings.is_production and "*" in settings.cors_origins:
issues.append("Wildcard origin (*) should not be used in production")
# Validate origin formats
for origin in settings.cors_origins:
if origin != "*" and not origin.startswith(("http://", "https://")):
issues.append(f"Invalid origin format: {origin}")
# Check credentials with wildcard
if settings.cors_allow_credentials and "*" in settings.cors_origins:
issues.append("Cannot use credentials with wildcard origin")
return issues
def get_cors_headers_for_origin(origin: str, settings: Settings) -> dict:
"""Get appropriate CORS headers for a specific origin."""
headers = {}
if not settings.cors_enabled:
return headers
# Check if origin is allowed
cors_middleware = CORSMiddleware(None, settings)
if cors_middleware._is_origin_allowed(origin):
headers["Access-Control-Allow-Origin"] = origin
if settings.cors_allow_credentials:
headers["Access-Control-Allow-Credentials"] = "true"
return headers

View File

@@ -0,0 +1,505 @@
"""
Global error handling middleware for WiFi-DensePose API
"""
import logging
import traceback
import time
from typing import Dict, Any, Optional, Callable, Union
from datetime import datetime
from fastapi import Request, Response, HTTPException, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import ValidationError
from src.config.settings import Settings
from src.logger import get_request_context
logger = logging.getLogger(__name__)
class ErrorResponse:
"""Standardized error response format."""
def __init__(
self,
error_code: str,
message: str,
details: Optional[Dict[str, Any]] = None,
status_code: int = 500,
request_id: Optional[str] = None,
):
self.error_code = error_code
self.message = message
self.details = details or {}
self.status_code = status_code
self.request_id = request_id
self.timestamp = datetime.utcnow().isoformat()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON response."""
response = {
"error": {
"code": self.error_code,
"message": self.message,
"timestamp": self.timestamp,
}
}
if self.details:
response["error"]["details"] = self.details
if self.request_id:
response["error"]["request_id"] = self.request_id
return response
def to_response(self) -> JSONResponse:
"""Convert to FastAPI JSONResponse."""
headers = {}
if self.request_id:
headers["X-Request-ID"] = self.request_id
return JSONResponse(
status_code=self.status_code,
content=self.to_dict(),
headers=headers
)
class ErrorHandler:
"""Central error handler for the application."""
def __init__(self, settings: Settings):
self.settings = settings
self.include_traceback = settings.debug and settings.is_development
self.log_errors = True
def handle_http_exception(self, request: Request, exc: HTTPException) -> ErrorResponse:
"""Handle HTTP exceptions."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.warning(
f"HTTP {exc.status_code}: {exc.detail} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}"
)
# Determine error code
error_code = self._get_error_code_for_status(exc.status_code)
# Build error details
details = {}
if hasattr(exc, "headers") and exc.headers:
details["headers"] = exc.headers
if self.include_traceback and hasattr(exc, "__traceback__"):
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
return ErrorResponse(
error_code=error_code,
message=str(exc.detail),
details=details,
status_code=exc.status_code,
request_id=request_id
)
def handle_validation_error(self, request: Request, exc: RequestValidationError) -> ErrorResponse:
"""Handle request validation errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.warning(
f"Validation error: {exc.errors()} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}"
)
# Format validation errors
validation_details = []
for error in exc.errors():
validation_details.append({
"field": ".".join(str(loc) for loc in error["loc"]),
"message": error["msg"],
"type": error["type"],
"input": error.get("input"),
})
details = {
"validation_errors": validation_details,
"error_count": len(validation_details)
}
if self.include_traceback:
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
return ErrorResponse(
error_code="VALIDATION_ERROR",
message="Request validation failed",
details=details,
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
request_id=request_id
)
def handle_pydantic_error(self, request: Request, exc: ValidationError) -> ErrorResponse:
"""Handle Pydantic validation errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.warning(
f"Pydantic validation error: {exc.errors()} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}"
)
# Format validation errors
validation_details = []
for error in exc.errors():
validation_details.append({
"field": ".".join(str(loc) for loc in error["loc"]),
"message": error["msg"],
"type": error["type"],
})
details = {
"validation_errors": validation_details,
"error_count": len(validation_details)
}
return ErrorResponse(
error_code="DATA_VALIDATION_ERROR",
message="Data validation failed",
details=details,
status_code=status.HTTP_400_BAD_REQUEST,
request_id=request_id
)
def handle_generic_exception(self, request: Request, exc: Exception) -> ErrorResponse:
"""Handle generic exceptions."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.error(
f"Unhandled exception: {type(exc).__name__}: {exc} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}",
exc_info=True
)
# Determine error details
details = {
"exception_type": type(exc).__name__,
}
if self.include_traceback:
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
# Don't expose internal error details in production
if self.settings.is_production:
message = "An internal server error occurred"
else:
message = str(exc) or "An unexpected error occurred"
return ErrorResponse(
error_code="INTERNAL_SERVER_ERROR",
message=message,
details=details,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
request_id=request_id
)
def handle_database_error(self, request: Request, exc: Exception) -> ErrorResponse:
"""Handle database-related errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.error(
f"Database error: {type(exc).__name__}: {exc} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}",
exc_info=True
)
details = {
"exception_type": type(exc).__name__,
"category": "database"
}
if self.include_traceback:
details["traceback"] = traceback.format_exception(
type(exc), exc, exc.__traceback__
)
return ErrorResponse(
error_code="DATABASE_ERROR",
message="Database operation failed" if self.settings.is_production else str(exc),
details=details,
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
request_id=request_id
)
def handle_external_service_error(self, request: Request, exc: Exception) -> ErrorResponse:
"""Handle external service errors."""
request_context = get_request_context()
request_id = request_context.get("request_id")
# Log the error
if self.log_errors:
logger.error(
f"External service error: {type(exc).__name__}: {exc} - "
f"{request.method} {request.url.path} - "
f"Request ID: {request_id}",
exc_info=True
)
details = {
"exception_type": type(exc).__name__,
"category": "external_service"
}
return ErrorResponse(
error_code="EXTERNAL_SERVICE_ERROR",
message="External service unavailable" if self.settings.is_production else str(exc),
details=details,
status_code=status.HTTP_502_BAD_GATEWAY,
request_id=request_id
)
def _get_error_code_for_status(self, status_code: int) -> str:
"""Get error code for HTTP status code."""
error_codes = {
400: "BAD_REQUEST",
401: "UNAUTHORIZED",
403: "FORBIDDEN",
404: "NOT_FOUND",
405: "METHOD_NOT_ALLOWED",
409: "CONFLICT",
422: "UNPROCESSABLE_ENTITY",
429: "TOO_MANY_REQUESTS",
500: "INTERNAL_SERVER_ERROR",
502: "BAD_GATEWAY",
503: "SERVICE_UNAVAILABLE",
504: "GATEWAY_TIMEOUT",
}
return error_codes.get(status_code, "HTTP_ERROR")
class ErrorHandlingMiddleware:
"""Error handling middleware for FastAPI."""
def __init__(self, app, settings: Settings):
self.app = app
self.settings = settings
self.error_handler = ErrorHandler(settings)
async def __call__(self, scope, receive, send):
"""Process request through error handling middleware."""
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start_time = time.time()
try:
await self.app(scope, receive, send)
except Exception as exc:
# Create a mock request for error handling
from starlette.requests import Request
request = Request(scope, receive)
# Handle different exception types
if isinstance(exc, HTTPException):
error_response = self.error_handler.handle_http_exception(request, exc)
elif isinstance(exc, RequestValidationError):
error_response = self.error_handler.handle_validation_error(request, exc)
elif isinstance(exc, ValidationError):
error_response = self.error_handler.handle_pydantic_error(request, exc)
else:
# Check for specific error types
if self._is_database_error(exc):
error_response = self.error_handler.handle_database_error(request, exc)
elif self._is_external_service_error(exc):
error_response = self.error_handler.handle_external_service_error(request, exc)
else:
error_response = self.error_handler.handle_generic_exception(request, exc)
# Send the error response
response = error_response.to_response()
await response(scope, receive, send)
finally:
# Log request processing time
processing_time = time.time() - start_time
logger.debug(f"Error handling middleware processing time: {processing_time:.3f}s")
def _is_database_error(self, exc: Exception) -> bool:
"""Check if exception is database-related."""
database_exceptions = [
"sqlalchemy",
"psycopg2",
"pymongo",
"redis",
"ConnectionError",
"OperationalError",
"IntegrityError",
]
exc_module = getattr(type(exc), "__module__", "")
exc_name = type(exc).__name__
return any(
db_exc in exc_module or db_exc in exc_name
for db_exc in database_exceptions
)
def _is_external_service_error(self, exc: Exception) -> bool:
"""Check if exception is external service-related."""
external_exceptions = [
"requests",
"httpx",
"aiohttp",
"urllib",
"ConnectionError",
"TimeoutError",
"ConnectTimeout",
"ReadTimeout",
]
exc_module = getattr(type(exc), "__module__", "")
exc_name = type(exc).__name__
return any(
ext_exc in exc_module or ext_exc in exc_name
for ext_exc in external_exceptions
)
def setup_error_handling(app, settings: Settings):
"""Setup error handling for the application."""
logger.info("Setting up error handling middleware")
error_handler = ErrorHandler(settings)
# Add exception handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
error_response = error_handler.handle_http_exception(request, exc)
return error_response.to_response()
@app.exception_handler(StarletteHTTPException)
async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
# Convert Starlette HTTPException to FastAPI HTTPException
fastapi_exc = HTTPException(status_code=exc.status_code, detail=exc.detail)
error_response = error_handler.handle_http_exception(request, fastapi_exc)
return error_response.to_response()
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
error_response = error_handler.handle_validation_error(request, exc)
return error_response.to_response()
@app.exception_handler(ValidationError)
async def pydantic_exception_handler(request: Request, exc: ValidationError):
error_response = error_handler.handle_pydantic_error(request, exc)
return error_response.to_response()
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
error_response = error_handler.handle_generic_exception(request, exc)
return error_response.to_response()
# Add middleware for additional error handling
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
# The middleware approach is commented out but kept for reference
# middleware = ErrorHandlingMiddleware(app, settings)
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
logger.info("Error handling configured")
class CustomHTTPException(HTTPException):
"""Custom HTTP exception with additional context."""
def __init__(
self,
status_code: int,
detail: str,
error_code: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
):
super().__init__(status_code=status_code, detail=detail, headers=headers)
self.error_code = error_code
self.context = context or {}
class BusinessLogicError(CustomHTTPException):
"""Exception for business logic errors."""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail=message,
error_code="BUSINESS_LOGIC_ERROR",
context=context
)
class ResourceNotFoundError(CustomHTTPException):
"""Exception for resource not found errors."""
def __init__(self, resource: str, identifier: str):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{resource} not found",
error_code="RESOURCE_NOT_FOUND",
context={"resource": resource, "identifier": identifier}
)
class ConflictError(CustomHTTPException):
"""Exception for conflict errors."""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=message,
error_code="CONFLICT_ERROR",
context=context
)
class ServiceUnavailableError(CustomHTTPException):
"""Exception for service unavailable errors."""
def __init__(self, service: str, reason: Optional[str] = None):
detail = f"{service} service is unavailable"
if reason:
detail += f": {reason}"
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=detail,
error_code="SERVICE_UNAVAILABLE",
context={"service": service, "reason": reason}
)

View File

@@ -0,0 +1,465 @@
"""
Rate limiting middleware for WiFi-DensePose API
"""
import asyncio
import logging
import time
from typing import Dict, Any, Optional, Callable, Tuple
from datetime import datetime, timedelta
from collections import defaultdict, deque
from dataclasses import dataclass
from fastapi import Request, Response, HTTPException, status
from starlette.types import ASGIApp
from src.config.settings import Settings
logger = logging.getLogger(__name__)
@dataclass
class RateLimitInfo:
"""Rate limit information."""
requests: int
window_start: float
window_size: int
limit: int
@property
def remaining(self) -> int:
"""Get remaining requests in current window."""
return max(0, self.limit - self.requests)
@property
def reset_time(self) -> float:
"""Get time when window resets."""
return self.window_start + self.window_size
@property
def is_exceeded(self) -> bool:
"""Check if rate limit is exceeded."""
return self.requests >= self.limit
class TokenBucket:
"""Token bucket algorithm for rate limiting."""
def __init__(self, capacity: int, refill_rate: float):
self.capacity = capacity
self.tokens = capacity
self.refill_rate = refill_rate
self.last_refill = time.time()
self._lock = asyncio.Lock()
async def consume(self, tokens: int = 1) -> bool:
"""Try to consume tokens from bucket."""
async with self._lock:
now = time.time()
# Refill tokens based on time elapsed
time_passed = now - self.last_refill
tokens_to_add = time_passed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
self.last_refill = now
# Check if we have enough tokens
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
def get_info(self) -> Dict[str, Any]:
"""Get bucket information."""
return {
"capacity": self.capacity,
"tokens": self.tokens,
"refill_rate": self.refill_rate,
"last_refill": self.last_refill
}
class SlidingWindowCounter:
"""Sliding window counter for rate limiting."""
def __init__(self, window_size: int, limit: int):
self.window_size = window_size
self.limit = limit
self.requests = deque()
self._lock = asyncio.Lock()
async def is_allowed(self) -> Tuple[bool, RateLimitInfo]:
"""Check if request is allowed."""
async with self._lock:
now = time.time()
window_start = now - self.window_size
# Remove old requests outside the window
while self.requests and self.requests[0] < window_start:
self.requests.popleft()
# Check if limit is exceeded
current_requests = len(self.requests)
allowed = current_requests < self.limit
if allowed:
self.requests.append(now)
rate_limit_info = RateLimitInfo(
requests=current_requests + (1 if allowed else 0),
window_start=window_start,
window_size=self.window_size,
limit=self.limit
)
return allowed, rate_limit_info
class RateLimiter:
"""Rate limiter with multiple algorithms."""
def __init__(self, settings: Settings):
self.settings = settings
self.enabled = settings.enable_rate_limiting
# Rate limit configurations
self.default_limit = settings.rate_limit_requests
self.authenticated_limit = settings.rate_limit_authenticated_requests
self.window_size = settings.rate_limit_window
# Storage for rate limit data
self._sliding_windows: Dict[str, SlidingWindowCounter] = {}
self._token_buckets: Dict[str, TokenBucket] = {}
# Cleanup task
self._cleanup_task: Optional[asyncio.Task] = None
self._cleanup_interval = 300 # 5 minutes
async def start(self):
"""Start rate limiter background tasks."""
if self.enabled:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("Rate limiter started")
async def stop(self):
"""Stop rate limiter background tasks."""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("Rate limiter stopped")
async def _cleanup_loop(self):
"""Background task to cleanup old rate limit data."""
while True:
try:
await asyncio.sleep(self._cleanup_interval)
await self._cleanup_old_data()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in rate limiter cleanup: {e}")
async def _cleanup_old_data(self):
"""Remove old rate limit data."""
now = time.time()
cutoff = now - (self.window_size * 2) # Keep data for 2 windows
# Cleanup sliding windows
keys_to_remove = []
for key, window in self._sliding_windows.items():
# Remove old requests
while window.requests and window.requests[0] < cutoff:
window.requests.popleft()
# Remove empty windows
if not window.requests:
keys_to_remove.append(key)
for key in keys_to_remove:
del self._sliding_windows[key]
logger.debug(f"Cleaned up {len(keys_to_remove)} old rate limit windows")
def _get_client_identifier(self, request: Request) -> str:
"""Get client identifier for rate limiting."""
# Try to get user ID from authenticated request
user = getattr(request.state, "user", None)
if user:
return f"user:{user.get('username', 'unknown')}"
# Fall back to IP address
client_ip = self._get_client_ip(request)
return f"ip:{client_ip}"
def _get_client_ip(self, request: Request) -> str:
"""Get client IP address."""
# Check for forwarded headers
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fall back to direct connection
return request.client.host if request.client else "unknown"
def _get_rate_limit(self, request: Request) -> int:
"""Get rate limit for request."""
# Check if user is authenticated
user = getattr(request.state, "user", None)
if user:
return self.authenticated_limit
return self.default_limit
def _get_rate_limit_key(self, request: Request) -> str:
"""Get rate limit key for request."""
client_id = self._get_client_identifier(request)
endpoint = f"{request.method}:{request.url.path}"
return f"{client_id}:{endpoint}"
async def check_rate_limit(self, request: Request) -> Tuple[bool, RateLimitInfo]:
"""Check if request is within rate limits."""
if not self.enabled:
# Return dummy info when rate limiting is disabled
return True, RateLimitInfo(
requests=0,
window_start=time.time(),
window_size=self.window_size,
limit=float('inf')
)
key = self._get_rate_limit_key(request)
limit = self._get_rate_limit(request)
# Get or create sliding window counter
if key not in self._sliding_windows:
self._sliding_windows[key] = SlidingWindowCounter(self.window_size, limit)
window = self._sliding_windows[key]
# Update limit if it changed (e.g., user authenticated)
window.limit = limit
return await window.is_allowed()
async def check_token_bucket(self, request: Request, tokens: int = 1) -> bool:
"""Check rate limit using token bucket algorithm."""
if not self.enabled:
return True
key = self._get_client_identifier(request)
limit = self._get_rate_limit(request)
# Get or create token bucket
if key not in self._token_buckets:
# Refill rate: limit per window size
refill_rate = limit / self.window_size
self._token_buckets[key] = TokenBucket(limit, refill_rate)
bucket = self._token_buckets[key]
return await bucket.consume(tokens)
def get_rate_limit_headers(self, rate_limit_info: RateLimitInfo) -> Dict[str, str]:
"""Get rate limit headers for response."""
return {
"X-RateLimit-Limit": str(rate_limit_info.limit),
"X-RateLimit-Remaining": str(rate_limit_info.remaining),
"X-RateLimit-Reset": str(int(rate_limit_info.reset_time)),
"X-RateLimit-Window": str(rate_limit_info.window_size),
}
async def get_stats(self) -> Dict[str, Any]:
"""Get rate limiter statistics."""
return {
"enabled": self.enabled,
"default_limit": self.default_limit,
"authenticated_limit": self.authenticated_limit,
"window_size": self.window_size,
"active_windows": len(self._sliding_windows),
"active_buckets": len(self._token_buckets),
}
class RateLimitMiddleware:
"""Rate limiting middleware for FastAPI."""
def __init__(self, settings: Settings):
self.settings = settings
self.rate_limiter = RateLimiter(settings)
self.enabled = settings.enable_rate_limiting
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""Process request through rate limiting middleware."""
if not self.enabled:
return await call_next(request)
# Skip rate limiting for certain paths
if self._should_skip_rate_limit(request):
return await call_next(request)
try:
# Check rate limit
allowed, rate_limit_info = await self.rate_limiter.check_rate_limit(request)
if not allowed:
# Rate limit exceeded
logger.warning(
f"Rate limit exceeded for {self.rate_limiter._get_client_identifier(request)} "
f"on {request.method} {request.url.path}"
)
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
headers["Retry-After"] = str(int(rate_limit_info.reset_time - time.time()))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded",
headers=headers
)
# Process request
response = await call_next(request)
# Add rate limit headers to response
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
for key, value in headers.items():
response.headers[key] = value
return response
except HTTPException:
raise
except Exception as e:
logger.error(f"Rate limiting middleware error: {e}")
# Continue without rate limiting on error
return await call_next(request)
def _should_skip_rate_limit(self, request: Request) -> bool:
"""Check if rate limiting should be skipped for this request."""
path = request.url.path
# Skip rate limiting for these paths
skip_paths = [
"/health",
"/metrics",
"/docs",
"/redoc",
"/openapi.json",
"/static",
]
return any(path.startswith(skip_path) for skip_path in skip_paths)
async def start(self):
"""Start rate limiting middleware."""
await self.rate_limiter.start()
async def stop(self):
"""Stop rate limiting middleware."""
await self.rate_limiter.stop()
# Global rate limit middleware instance
_rate_limit_middleware: Optional[RateLimitMiddleware] = None
def get_rate_limit_middleware(settings: Settings) -> RateLimitMiddleware:
"""Get rate limit middleware instance."""
global _rate_limit_middleware
if _rate_limit_middleware is None:
_rate_limit_middleware = RateLimitMiddleware(settings)
return _rate_limit_middleware
def setup_rate_limiting(app: ASGIApp, settings: Settings) -> ASGIApp:
"""Setup rate limiting middleware for the application."""
if settings.enable_rate_limiting:
logger.info("Setting up rate limiting middleware")
middleware = get_rate_limit_middleware(settings)
# Add middleware to app
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
return await middleware(request, call_next)
logger.info(
f"Rate limiting enabled - Default: {settings.rate_limit_requests}/"
f"{settings.rate_limit_window}s, Authenticated: "
f"{settings.rate_limit_authenticated_requests}/{settings.rate_limit_window}s"
)
else:
logger.info("Rate limiting disabled")
return app
class RateLimitConfig:
"""Rate limiting configuration helper."""
@staticmethod
def development_config() -> dict:
"""Get rate limiting configuration for development."""
return {
"enable_rate_limiting": False, # Disabled in development
"rate_limit_requests": 1000,
"rate_limit_authenticated_requests": 5000,
"rate_limit_window": 3600, # 1 hour
}
@staticmethod
def production_config() -> dict:
"""Get rate limiting configuration for production."""
return {
"enable_rate_limiting": True,
"rate_limit_requests": 100, # 100 requests per hour for unauthenticated
"rate_limit_authenticated_requests": 1000, # 1000 requests per hour for authenticated
"rate_limit_window": 3600, # 1 hour
}
@staticmethod
def api_config() -> dict:
"""Get rate limiting configuration for API access."""
return {
"enable_rate_limiting": True,
"rate_limit_requests": 60, # 60 requests per minute
"rate_limit_authenticated_requests": 300, # 300 requests per minute
"rate_limit_window": 60, # 1 minute
}
@staticmethod
def strict_config() -> dict:
"""Get strict rate limiting configuration."""
return {
"enable_rate_limiting": True,
"rate_limit_requests": 10, # 10 requests per minute
"rate_limit_authenticated_requests": 100, # 100 requests per minute
"rate_limit_window": 60, # 1 minute
}
def validate_rate_limit_config(settings: Settings) -> list:
"""Validate rate limiting configuration."""
issues = []
if settings.enable_rate_limiting:
if settings.rate_limit_requests <= 0:
issues.append("Rate limit requests must be positive")
if settings.rate_limit_authenticated_requests <= 0:
issues.append("Authenticated rate limit requests must be positive")
if settings.rate_limit_window <= 0:
issues.append("Rate limit window must be positive")
if settings.rate_limit_authenticated_requests < settings.rate_limit_requests:
issues.append("Authenticated rate limit should be higher than default rate limit")
return issues