feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
467
v1/src/middleware/auth.py
Normal file
467
v1/src/middleware/auth.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Authentication middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Dict, Any, Callable
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import set_request_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# JWT token handler
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Authentication error."""
|
||||
pass
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Authorization error."""
|
||||
pass
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""JWT token management."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.secret_key = settings.secret_key
|
||||
self.algorithm = settings.jwt_algorithm
|
||||
self.expire_hours = settings.jwt_expire_hours
|
||||
|
||||
def create_access_token(self, data: Dict[str, Any]) -> str:
|
||||
"""Create JWT access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(hours=self.expire_hours)
|
||||
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Verify and decode JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return payload
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
raise AuthenticationError("Invalid token")
|
||||
|
||||
def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Decode token without verification (for debugging)."""
|
||||
try:
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""User management for authentication."""
|
||||
|
||||
def __init__(self):
|
||||
# In a real application, this would connect to a database
|
||||
# For now, we'll use a simple in-memory store
|
||||
self._users: Dict[str, Dict[str, Any]] = {
|
||||
"admin": {
|
||||
"username": "admin",
|
||||
"email": "admin@example.com",
|
||||
"hashed_password": self.hash_password("admin123"),
|
||||
"roles": ["admin"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
},
|
||||
"user": {
|
||||
"username": "user",
|
||||
"email": "user@example.com",
|
||||
"hashed_password": self.hash_password("user123"),
|
||||
"roles": ["user"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password."""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_user(self, username: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user by username."""
|
||||
return self._users.get(username)
|
||||
|
||||
def authenticate_user(self, username: str, password: str) -> Optional[Dict[str, Any]]:
|
||||
"""Authenticate user with username and password."""
|
||||
user = self.get_user(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not self.verify_password(password, user["hashed_password"]):
|
||||
return None
|
||||
|
||||
if not user.get("is_active", False):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
def create_user(self, username: str, email: str, password: str, roles: list = None) -> Dict[str, Any]:
|
||||
"""Create a new user."""
|
||||
if username in self._users:
|
||||
raise ValueError("User already exists")
|
||||
|
||||
user = {
|
||||
"username": username,
|
||||
"email": email,
|
||||
"hashed_password": self.hash_password(password),
|
||||
"roles": roles or ["user"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
|
||||
self._users[username] = user
|
||||
return user
|
||||
|
||||
def update_user(self, username: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update user information."""
|
||||
user = self._users.get(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Don't allow updating certain fields
|
||||
protected_fields = {"username", "created_at", "hashed_password"}
|
||||
updates = {k: v for k, v in updates.items() if k not in protected_fields}
|
||||
|
||||
user.update(updates)
|
||||
return user
|
||||
|
||||
def deactivate_user(self, username: str) -> bool:
|
||||
"""Deactivate a user."""
|
||||
user = self._users.get(username)
|
||||
if user:
|
||||
user["is_active"] = False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Authentication middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.token_manager = TokenManager(settings)
|
||||
self.user_manager = UserManager()
|
||||
self.enabled = settings.enable_authentication
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request through authentication middleware."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Skip authentication for certain paths
|
||||
if self._should_skip_auth(request):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# Skip if authentication is disabled
|
||||
if not self.enabled:
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# Extract and verify token
|
||||
user_info = await self._authenticate_request(request)
|
||||
|
||||
# Set user context
|
||||
if user_info:
|
||||
request.state.user = user_info
|
||||
set_request_context(user_id=user_info.get("username"))
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add authentication headers
|
||||
self._add_auth_headers(response, user_info)
|
||||
|
||||
return response
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Authentication failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except AuthorizationError as e:
|
||||
logger.warning(f"Authorization failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication middleware error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Authentication service error",
|
||||
)
|
||||
finally:
|
||||
# Log request processing time
|
||||
processing_time = time.time() - start_time
|
||||
logger.debug(f"Auth middleware processing time: {processing_time:.3f}s")
|
||||
|
||||
def _should_skip_auth(self, request: Request) -> bool:
|
||||
"""Check if authentication should be skipped for this request."""
|
||||
path = request.url.path
|
||||
|
||||
# Skip authentication for these paths
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/auth/login",
|
||||
"/auth/register",
|
||||
"/static",
|
||||
]
|
||||
|
||||
return any(path.startswith(skip_path) for skip_path in skip_paths)
|
||||
|
||||
async def _authenticate_request(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Authenticate the request and return user info."""
|
||||
# Try to get token from Authorization header
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
# For WebSocket connections, try to get token from query parameters
|
||||
if request.url.path.startswith("/ws"):
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
authorization = f"Bearer {token}"
|
||||
|
||||
if not authorization:
|
||||
if self._requires_auth(request):
|
||||
raise AuthenticationError("Missing authorization header")
|
||||
return None
|
||||
|
||||
# Extract token
|
||||
try:
|
||||
scheme, token = authorization.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise AuthenticationError("Invalid authentication scheme")
|
||||
except ValueError:
|
||||
raise AuthenticationError("Invalid authorization header format")
|
||||
|
||||
# Verify token
|
||||
try:
|
||||
payload = self.token_manager.verify_token(token)
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
raise AuthenticationError("Invalid token payload")
|
||||
|
||||
# Get user info
|
||||
user = self.user_manager.get_user(username)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
if not user.get("is_active", False):
|
||||
raise AuthenticationError("User account is disabled")
|
||||
|
||||
# Return user info without sensitive data
|
||||
return {
|
||||
"username": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
"is_active": user["is_active"],
|
||||
}
|
||||
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification error: {e}")
|
||||
raise AuthenticationError("Token verification failed")
|
||||
|
||||
def _requires_auth(self, request: Request) -> bool:
|
||||
"""Check if the request requires authentication."""
|
||||
# All API endpoints require authentication by default
|
||||
path = request.url.path
|
||||
return path.startswith("/api/") or path.startswith("/ws/")
|
||||
|
||||
def _add_auth_headers(self, response: Response, user_info: Optional[Dict[str, Any]]):
|
||||
"""Add authentication-related headers to response."""
|
||||
if user_info:
|
||||
response.headers["X-User"] = user_info["username"]
|
||||
response.headers["X-User-Roles"] = ",".join(user_info["roles"])
|
||||
|
||||
async def login(self, username: str, password: str) -> Dict[str, Any]:
|
||||
"""Authenticate user and return token."""
|
||||
user = self.user_manager.authenticate_user(username, password)
|
||||
if not user:
|
||||
raise AuthenticationError("Invalid username or password")
|
||||
|
||||
# Create token
|
||||
token_data = {
|
||||
"sub": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
|
||||
access_token = self.token_manager.create_access_token(token_data)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.jwt_expire_hours * 3600,
|
||||
"user": {
|
||||
"username": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
}
|
||||
|
||||
async def register(self, username: str, email: str, password: str) -> Dict[str, Any]:
|
||||
"""Register a new user."""
|
||||
try:
|
||||
user = self.user_manager.create_user(username, email, password)
|
||||
|
||||
# Create token for new user
|
||||
token_data = {
|
||||
"sub": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
|
||||
access_token = self.token_manager.create_access_token(token_data)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.jwt_expire_hours * 3600,
|
||||
"user": {
|
||||
"username": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise AuthenticationError(str(e))
|
||||
|
||||
async def refresh_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Refresh an access token."""
|
||||
try:
|
||||
payload = self.token_manager.verify_token(token)
|
||||
username = payload.get("sub")
|
||||
|
||||
user = self.user_manager.get_user(username)
|
||||
if not user or not user.get("is_active", False):
|
||||
raise AuthenticationError("User not found or inactive")
|
||||
|
||||
# Create new token
|
||||
token_data = {
|
||||
"sub": user["username"],
|
||||
"email": user["email"],
|
||||
"roles": user["roles"],
|
||||
}
|
||||
|
||||
new_token = self.token_manager.create_access_token(token_data)
|
||||
|
||||
return {
|
||||
"access_token": new_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": self.settings.jwt_expire_hours * 3600,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise AuthenticationError("Token refresh failed")
|
||||
|
||||
def check_permission(self, user_info: Dict[str, Any], required_role: str) -> bool:
|
||||
"""Check if user has required role/permission."""
|
||||
user_roles = user_info.get("roles", [])
|
||||
|
||||
# Admin role has all permissions
|
||||
if "admin" in user_roles:
|
||||
return True
|
||||
|
||||
# Check specific role
|
||||
return required_role in user_roles
|
||||
|
||||
def require_role(self, required_role: str):
|
||||
"""Decorator to require specific role."""
|
||||
def decorator(func):
|
||||
import functools
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(request: Request, *args, **kwargs):
|
||||
user_info = getattr(request.state, "user", None)
|
||||
if not user_info:
|
||||
raise AuthorizationError("Authentication required")
|
||||
|
||||
if not self.check_permission(user_info, required_role):
|
||||
raise AuthorizationError(f"Role '{required_role}' required")
|
||||
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# Global authentication middleware instance
|
||||
_auth_middleware: Optional[AuthenticationMiddleware] = None
|
||||
|
||||
|
||||
def get_auth_middleware(settings: Settings) -> AuthenticationMiddleware:
|
||||
"""Get authentication middleware instance."""
|
||||
global _auth_middleware
|
||||
if _auth_middleware is None:
|
||||
_auth_middleware = AuthenticationMiddleware(settings)
|
||||
return _auth_middleware
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Get current authenticated user from request."""
|
||||
return getattr(request.state, "user", None)
|
||||
|
||||
|
||||
def require_authentication(request: Request) -> Dict[str, Any]:
|
||||
"""Require authentication and return user info."""
|
||||
user = get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_role(role: str):
|
||||
"""Dependency to require specific role."""
|
||||
def dependency(request: Request) -> Dict[str, Any]:
|
||||
user = require_authentication(request)
|
||||
|
||||
auth_middleware = get_auth_middleware(request.app.state.settings)
|
||||
if not auth_middleware.check_permission(user, role):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role '{role}' required",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
375
v1/src/middleware/cors.py
Normal file
375
v1/src/middleware/cors.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
CORS middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Union, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CORSMiddleware:
|
||||
"""Enhanced CORS middleware with additional security features."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
settings: Settings,
|
||||
allow_origins: Optional[List[str]] = None,
|
||||
allow_methods: Optional[List[str]] = None,
|
||||
allow_headers: Optional[List[str]] = None,
|
||||
allow_credentials: bool = False,
|
||||
expose_headers: Optional[List[str]] = None,
|
||||
max_age: int = 600,
|
||||
):
|
||||
self.app = app
|
||||
self.settings = settings
|
||||
self.allow_origins = allow_origins or settings.cors_origins
|
||||
self.allow_methods = allow_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]
|
||||
self.allow_headers = allow_headers or [
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"X-User-Agent",
|
||||
]
|
||||
self.allow_credentials = allow_credentials or settings.cors_allow_credentials
|
||||
self.expose_headers = expose_headers or [
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
]
|
||||
self.max_age = max_age
|
||||
|
||||
# Security settings
|
||||
self.strict_origin_check = settings.is_production
|
||||
self.log_cors_violations = True
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""ASGI middleware implementation."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Check if this is a CORS preflight request
|
||||
if request.method == "OPTIONS" and "access-control-request-method" in request.headers:
|
||||
response = await self._handle_preflight(request)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle actual request
|
||||
async def send_wrapper(message):
|
||||
if message["type"] == "http.response.start":
|
||||
# Add CORS headers to response
|
||||
headers = dict(message.get("headers", []))
|
||||
cors_headers = self._get_cors_headers(request)
|
||||
|
||||
for key, value in cors_headers.items():
|
||||
headers[key.encode()] = value.encode()
|
||||
|
||||
message["headers"] = list(headers.items())
|
||||
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
|
||||
async def _handle_preflight(self, request: Request) -> Response:
|
||||
"""Handle CORS preflight request."""
|
||||
origin = request.headers.get("origin")
|
||||
requested_method = request.headers.get("access-control-request-method")
|
||||
requested_headers = request.headers.get("access-control-request-headers", "")
|
||||
|
||||
# Validate origin
|
||||
if not self._is_origin_allowed(origin):
|
||||
if self.log_cors_violations:
|
||||
logger.warning(f"CORS preflight rejected for origin: {origin}")
|
||||
|
||||
return Response(
|
||||
status_code=403,
|
||||
content="CORS preflight request rejected",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
# Validate method
|
||||
if requested_method not in self.allow_methods:
|
||||
if self.log_cors_violations:
|
||||
logger.warning(f"CORS preflight rejected for method: {requested_method}")
|
||||
|
||||
return Response(
|
||||
status_code=405,
|
||||
content="Method not allowed",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
# Validate headers
|
||||
if requested_headers:
|
||||
requested_header_list = [h.strip().lower() for h in requested_headers.split(",")]
|
||||
allowed_headers_lower = [h.lower() for h in self.allow_headers]
|
||||
|
||||
for header in requested_header_list:
|
||||
if header not in allowed_headers_lower:
|
||||
if self.log_cors_violations:
|
||||
logger.warning(f"CORS preflight rejected for header: {header}")
|
||||
|
||||
return Response(
|
||||
status_code=400,
|
||||
content="Header not allowed",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
# Build preflight response headers
|
||||
headers = {
|
||||
"Access-Control-Allow-Origin": origin,
|
||||
"Access-Control-Allow-Methods": ", ".join(self.allow_methods),
|
||||
"Access-Control-Allow-Headers": ", ".join(self.allow_headers),
|
||||
"Access-Control-Max-Age": str(self.max_age),
|
||||
}
|
||||
|
||||
if self.allow_credentials:
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
if self.expose_headers:
|
||||
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
|
||||
|
||||
logger.debug(f"CORS preflight approved for origin: {origin}")
|
||||
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
def _get_cors_headers(self, request: Request) -> dict:
|
||||
"""Get CORS headers for actual request."""
|
||||
origin = request.headers.get("origin")
|
||||
headers = {}
|
||||
|
||||
if self._is_origin_allowed(origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
|
||||
if self.allow_credentials:
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
if self.expose_headers:
|
||||
headers["Access-Control-Expose-Headers"] = ", ".join(self.expose_headers)
|
||||
|
||||
return headers
|
||||
|
||||
def _is_origin_allowed(self, origin: Optional[str]) -> bool:
|
||||
"""Check if origin is allowed."""
|
||||
if not origin:
|
||||
return not self.strict_origin_check
|
||||
|
||||
# Allow all origins in development
|
||||
if not self.settings.is_production and "*" in self.allow_origins:
|
||||
return True
|
||||
|
||||
# Check exact matches
|
||||
if origin in self.allow_origins:
|
||||
return True
|
||||
|
||||
# Check wildcard patterns
|
||||
for allowed_origin in self.allow_origins:
|
||||
if allowed_origin == "*":
|
||||
return not self.strict_origin_check
|
||||
|
||||
if self._match_origin_pattern(origin, allowed_origin):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _match_origin_pattern(self, origin: str, pattern: str) -> bool:
|
||||
"""Match origin against pattern with wildcard support."""
|
||||
if "*" not in pattern:
|
||||
return origin == pattern
|
||||
|
||||
# Simple wildcard matching
|
||||
if pattern.startswith("*."):
|
||||
domain = pattern[2:]
|
||||
parsed_origin = urlparse(origin)
|
||||
origin_host = parsed_origin.netloc
|
||||
|
||||
# Check if origin ends with the domain
|
||||
return origin_host.endswith(domain) or origin_host == domain[1:] if domain.startswith('.') else origin_host == domain
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def setup_cors_middleware(app: ASGIApp, settings: Settings) -> ASGIApp:
|
||||
"""Setup CORS middleware for the application."""
|
||||
|
||||
if settings.cors_enabled:
|
||||
logger.info("Setting up CORS middleware")
|
||||
|
||||
# Use FastAPI's built-in CORS middleware for basic functionality
|
||||
app = FastAPICORSMiddleware(
|
||||
app,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=settings.cors_allow_credentials,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
allow_headers=[
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"X-User-Agent",
|
||||
],
|
||||
expose_headers=[
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
max_age=600,
|
||||
)
|
||||
|
||||
logger.info(f"CORS enabled for origins: {settings.cors_origins}")
|
||||
else:
|
||||
logger.info("CORS middleware disabled")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class CORSConfig:
|
||||
"""CORS configuration helper."""
|
||||
|
||||
@staticmethod
|
||||
def development_config() -> dict:
|
||||
"""Get CORS configuration for development."""
|
||||
return {
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
"expose_headers": [
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
"max_age": 600,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def production_config(allowed_origins: List[str]) -> dict:
|
||||
"""Get CORS configuration for production."""
|
||||
return {
|
||||
"allow_origins": allowed_origins,
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
|
||||
"allow_headers": [
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"X-User-Agent",
|
||||
],
|
||||
"expose_headers": [
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
"max_age": 3600, # 1 hour for production
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def api_only_config(allowed_origins: List[str]) -> dict:
|
||||
"""Get CORS configuration for API-only access."""
|
||||
return {
|
||||
"allow_origins": allowed_origins,
|
||||
"allow_credentials": False,
|
||||
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": [
|
||||
"Accept",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Request-ID",
|
||||
],
|
||||
"expose_headers": [
|
||||
"X-Request-ID",
|
||||
"X-Rate-Limit-Remaining",
|
||||
"X-Rate-Limit-Reset",
|
||||
],
|
||||
"max_age": 3600,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def websocket_config(allowed_origins: List[str]) -> dict:
|
||||
"""Get CORS configuration for WebSocket connections."""
|
||||
return {
|
||||
"allow_origins": allowed_origins,
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["GET", "OPTIONS"],
|
||||
"allow_headers": [
|
||||
"Accept",
|
||||
"Authorization",
|
||||
"Sec-WebSocket-Protocol",
|
||||
"Sec-WebSocket-Extensions",
|
||||
],
|
||||
"expose_headers": [],
|
||||
"max_age": 86400, # 24 hours for WebSocket
|
||||
}
|
||||
|
||||
|
||||
def validate_cors_config(settings: Settings) -> List[str]:
|
||||
"""Validate CORS configuration and return issues."""
|
||||
issues = []
|
||||
|
||||
if not settings.cors_enabled:
|
||||
return issues
|
||||
|
||||
# Check origins
|
||||
if not settings.cors_origins:
|
||||
issues.append("CORS is enabled but no origins are configured")
|
||||
|
||||
# Check for wildcard in production
|
||||
if settings.is_production and "*" in settings.cors_origins:
|
||||
issues.append("Wildcard origin (*) should not be used in production")
|
||||
|
||||
# Validate origin formats
|
||||
for origin in settings.cors_origins:
|
||||
if origin != "*" and not origin.startswith(("http://", "https://")):
|
||||
issues.append(f"Invalid origin format: {origin}")
|
||||
|
||||
# Check credentials with wildcard
|
||||
if settings.cors_allow_credentials and "*" in settings.cors_origins:
|
||||
issues.append("Cannot use credentials with wildcard origin")
|
||||
|
||||
return issues
|
||||
|
||||
|
||||
def get_cors_headers_for_origin(origin: str, settings: Settings) -> dict:
|
||||
"""Get appropriate CORS headers for a specific origin."""
|
||||
headers = {}
|
||||
|
||||
if not settings.cors_enabled:
|
||||
return headers
|
||||
|
||||
# Check if origin is allowed
|
||||
cors_middleware = CORSMiddleware(None, settings)
|
||||
if cors_middleware._is_origin_allowed(origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
|
||||
if settings.cors_allow_credentials:
|
||||
headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
return headers
|
||||
505
v1/src/middleware/error_handler.py
Normal file
505
v1/src/middleware/error_handler.py
Normal file
@@ -0,0 +1,505 @@
|
||||
"""
|
||||
Global error handling middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, Union
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.logger import get_request_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErrorResponse:
|
||||
"""Standardized error response format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_code: str,
|
||||
message: str,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
status_code: int = 500,
|
||||
request_id: Optional[str] = None,
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
self.status_code = status_code
|
||||
self.request_id = request_id
|
||||
self.timestamp = datetime.utcnow().isoformat()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON response."""
|
||||
response = {
|
||||
"error": {
|
||||
"code": self.error_code,
|
||||
"message": self.message,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
if self.details:
|
||||
response["error"]["details"] = self.details
|
||||
|
||||
if self.request_id:
|
||||
response["error"]["request_id"] = self.request_id
|
||||
|
||||
return response
|
||||
|
||||
def to_response(self) -> JSONResponse:
|
||||
"""Convert to FastAPI JSONResponse."""
|
||||
headers = {}
|
||||
if self.request_id:
|
||||
headers["X-Request-ID"] = self.request_id
|
||||
|
||||
return JSONResponse(
|
||||
status_code=self.status_code,
|
||||
content=self.to_dict(),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
class ErrorHandler:
|
||||
"""Central error handler for the application."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.include_traceback = settings.debug and settings.is_development
|
||||
self.log_errors = True
|
||||
|
||||
def handle_http_exception(self, request: Request, exc: HTTPException) -> ErrorResponse:
|
||||
"""Handle HTTP exceptions."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.warning(
|
||||
f"HTTP {exc.status_code}: {exc.detail} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}"
|
||||
)
|
||||
|
||||
# Determine error code
|
||||
error_code = self._get_error_code_for_status(exc.status_code)
|
||||
|
||||
# Build error details
|
||||
details = {}
|
||||
if hasattr(exc, "headers") and exc.headers:
|
||||
details["headers"] = exc.headers
|
||||
|
||||
if self.include_traceback and hasattr(exc, "__traceback__"):
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
error_code=error_code,
|
||||
message=str(exc.detail),
|
||||
details=details,
|
||||
status_code=exc.status_code,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_validation_error(self, request: Request, exc: RequestValidationError) -> ErrorResponse:
|
||||
"""Handle request validation errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.warning(
|
||||
f"Validation error: {exc.errors()} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}"
|
||||
)
|
||||
|
||||
# Format validation errors
|
||||
validation_details = []
|
||||
for error in exc.errors():
|
||||
validation_details.append({
|
||||
"field": ".".join(str(loc) for loc in error["loc"]),
|
||||
"message": error["msg"],
|
||||
"type": error["type"],
|
||||
"input": error.get("input"),
|
||||
})
|
||||
|
||||
details = {
|
||||
"validation_errors": validation_details,
|
||||
"error_count": len(validation_details)
|
||||
}
|
||||
|
||||
if self.include_traceback:
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="VALIDATION_ERROR",
|
||||
message="Request validation failed",
|
||||
details=details,
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_pydantic_error(self, request: Request, exc: ValidationError) -> ErrorResponse:
|
||||
"""Handle Pydantic validation errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.warning(
|
||||
f"Pydantic validation error: {exc.errors()} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}"
|
||||
)
|
||||
|
||||
# Format validation errors
|
||||
validation_details = []
|
||||
for error in exc.errors():
|
||||
validation_details.append({
|
||||
"field": ".".join(str(loc) for loc in error["loc"]),
|
||||
"message": error["msg"],
|
||||
"type": error["type"],
|
||||
})
|
||||
|
||||
details = {
|
||||
"validation_errors": validation_details,
|
||||
"error_count": len(validation_details)
|
||||
}
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="DATA_VALIDATION_ERROR",
|
||||
message="Data validation failed",
|
||||
details=details,
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_generic_exception(self, request: Request, exc: Exception) -> ErrorResponse:
|
||||
"""Handle generic exceptions."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__}: {exc} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Determine error details
|
||||
details = {
|
||||
"exception_type": type(exc).__name__,
|
||||
}
|
||||
|
||||
if self.include_traceback:
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
# Don't expose internal error details in production
|
||||
if self.settings.is_production:
|
||||
message = "An internal server error occurred"
|
||||
else:
|
||||
message = str(exc) or "An unexpected error occurred"
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="INTERNAL_SERVER_ERROR",
|
||||
message=message,
|
||||
details=details,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_database_error(self, request: Request, exc: Exception) -> ErrorResponse:
|
||||
"""Handle database-related errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.error(
|
||||
f"Database error: {type(exc).__name__}: {exc} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
details = {
|
||||
"exception_type": type(exc).__name__,
|
||||
"category": "database"
|
||||
}
|
||||
|
||||
if self.include_traceback:
|
||||
details["traceback"] = traceback.format_exception(
|
||||
type(exc), exc, exc.__traceback__
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="DATABASE_ERROR",
|
||||
message="Database operation failed" if self.settings.is_production else str(exc),
|
||||
details=details,
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def handle_external_service_error(self, request: Request, exc: Exception) -> ErrorResponse:
|
||||
"""Handle external service errors."""
|
||||
request_context = get_request_context()
|
||||
request_id = request_context.get("request_id")
|
||||
|
||||
# Log the error
|
||||
if self.log_errors:
|
||||
logger.error(
|
||||
f"External service error: {type(exc).__name__}: {exc} - "
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"Request ID: {request_id}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
details = {
|
||||
"exception_type": type(exc).__name__,
|
||||
"category": "external_service"
|
||||
}
|
||||
|
||||
return ErrorResponse(
|
||||
error_code="EXTERNAL_SERVICE_ERROR",
|
||||
message="External service unavailable" if self.settings.is_production else str(exc),
|
||||
details=details,
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
request_id=request_id
|
||||
)
|
||||
|
||||
def _get_error_code_for_status(self, status_code: int) -> str:
|
||||
"""Get error code for HTTP status code."""
|
||||
error_codes = {
|
||||
400: "BAD_REQUEST",
|
||||
401: "UNAUTHORIZED",
|
||||
403: "FORBIDDEN",
|
||||
404: "NOT_FOUND",
|
||||
405: "METHOD_NOT_ALLOWED",
|
||||
409: "CONFLICT",
|
||||
422: "UNPROCESSABLE_ENTITY",
|
||||
429: "TOO_MANY_REQUESTS",
|
||||
500: "INTERNAL_SERVER_ERROR",
|
||||
502: "BAD_GATEWAY",
|
||||
503: "SERVICE_UNAVAILABLE",
|
||||
504: "GATEWAY_TIMEOUT",
|
||||
}
|
||||
|
||||
return error_codes.get(status_code, "HTTP_ERROR")
|
||||
|
||||
|
||||
class ErrorHandlingMiddleware:
|
||||
"""Error handling middleware for FastAPI."""
|
||||
|
||||
def __init__(self, app, settings: Settings):
|
||||
self.app = app
|
||||
self.settings = settings
|
||||
self.error_handler = ErrorHandler(settings)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""Process request through error handling middleware."""
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as exc:
|
||||
# Create a mock request for error handling
|
||||
from starlette.requests import Request
|
||||
request = Request(scope, receive)
|
||||
|
||||
# Handle different exception types
|
||||
if isinstance(exc, HTTPException):
|
||||
error_response = self.error_handler.handle_http_exception(request, exc)
|
||||
elif isinstance(exc, RequestValidationError):
|
||||
error_response = self.error_handler.handle_validation_error(request, exc)
|
||||
elif isinstance(exc, ValidationError):
|
||||
error_response = self.error_handler.handle_pydantic_error(request, exc)
|
||||
else:
|
||||
# Check for specific error types
|
||||
if self._is_database_error(exc):
|
||||
error_response = self.error_handler.handle_database_error(request, exc)
|
||||
elif self._is_external_service_error(exc):
|
||||
error_response = self.error_handler.handle_external_service_error(request, exc)
|
||||
else:
|
||||
error_response = self.error_handler.handle_generic_exception(request, exc)
|
||||
|
||||
# Send the error response
|
||||
response = error_response.to_response()
|
||||
await response(scope, receive, send)
|
||||
|
||||
finally:
|
||||
# Log request processing time
|
||||
processing_time = time.time() - start_time
|
||||
logger.debug(f"Error handling middleware processing time: {processing_time:.3f}s")
|
||||
|
||||
def _is_database_error(self, exc: Exception) -> bool:
|
||||
"""Check if exception is database-related."""
|
||||
database_exceptions = [
|
||||
"sqlalchemy",
|
||||
"psycopg2",
|
||||
"pymongo",
|
||||
"redis",
|
||||
"ConnectionError",
|
||||
"OperationalError",
|
||||
"IntegrityError",
|
||||
]
|
||||
|
||||
exc_module = getattr(type(exc), "__module__", "")
|
||||
exc_name = type(exc).__name__
|
||||
|
||||
return any(
|
||||
db_exc in exc_module or db_exc in exc_name
|
||||
for db_exc in database_exceptions
|
||||
)
|
||||
|
||||
def _is_external_service_error(self, exc: Exception) -> bool:
|
||||
"""Check if exception is external service-related."""
|
||||
external_exceptions = [
|
||||
"requests",
|
||||
"httpx",
|
||||
"aiohttp",
|
||||
"urllib",
|
||||
"ConnectionError",
|
||||
"TimeoutError",
|
||||
"ConnectTimeout",
|
||||
"ReadTimeout",
|
||||
]
|
||||
|
||||
exc_module = getattr(type(exc), "__module__", "")
|
||||
exc_name = type(exc).__name__
|
||||
|
||||
return any(
|
||||
ext_exc in exc_module or ext_exc in exc_name
|
||||
for ext_exc in external_exceptions
|
||||
)
|
||||
|
||||
|
||||
def setup_error_handling(app, settings: Settings):
|
||||
"""Setup error handling for the application."""
|
||||
logger.info("Setting up error handling middleware")
|
||||
|
||||
error_handler = ErrorHandler(settings)
|
||||
|
||||
# Add exception handlers
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
error_response = error_handler.handle_http_exception(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
# Convert Starlette HTTPException to FastAPI HTTPException
|
||||
fastapi_exc = HTTPException(status_code=exc.status_code, detail=exc.detail)
|
||||
error_response = error_handler.handle_http_exception(request, fastapi_exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
error_response = error_handler.handle_validation_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(ValidationError)
|
||||
async def pydantic_exception_handler(request: Request, exc: ValidationError):
|
||||
error_response = error_handler.handle_pydantic_error(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_exception_handler(request: Request, exc: Exception):
|
||||
error_response = error_handler.handle_generic_exception(request, exc)
|
||||
return error_response.to_response()
|
||||
|
||||
# Add middleware for additional error handling
|
||||
# Note: We use exception handlers instead of custom middleware to avoid ASGI conflicts
|
||||
# The middleware approach is commented out but kept for reference
|
||||
# middleware = ErrorHandlingMiddleware(app, settings)
|
||||
# app.add_middleware(ErrorHandlingMiddleware, settings=settings)
|
||||
|
||||
logger.info("Error handling configured")
|
||||
|
||||
|
||||
class CustomHTTPException(HTTPException):
|
||||
"""Custom HTTP exception with additional context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
detail: str,
|
||||
error_code: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
super().__init__(status_code=status_code, detail=detail, headers=headers)
|
||||
self.error_code = error_code
|
||||
self.context = context or {}
|
||||
|
||||
|
||||
class BusinessLogicError(CustomHTTPException):
|
||||
"""Exception for business logic errors."""
|
||||
|
||||
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=message,
|
||||
error_code="BUSINESS_LOGIC_ERROR",
|
||||
context=context
|
||||
)
|
||||
|
||||
|
||||
class ResourceNotFoundError(CustomHTTPException):
|
||||
"""Exception for resource not found errors."""
|
||||
|
||||
def __init__(self, resource: str, identifier: str):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"{resource} not found",
|
||||
error_code="RESOURCE_NOT_FOUND",
|
||||
context={"resource": resource, "identifier": identifier}
|
||||
)
|
||||
|
||||
|
||||
class ConflictError(CustomHTTPException):
|
||||
"""Exception for conflict errors."""
|
||||
|
||||
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=message,
|
||||
error_code="CONFLICT_ERROR",
|
||||
context=context
|
||||
)
|
||||
|
||||
|
||||
class ServiceUnavailableError(CustomHTTPException):
|
||||
"""Exception for service unavailable errors."""
|
||||
|
||||
def __init__(self, service: str, reason: Optional[str] = None):
|
||||
detail = f"{service} service is unavailable"
|
||||
if reason:
|
||||
detail += f": {reason}"
|
||||
|
||||
super().__init__(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=detail,
|
||||
error_code="SERVICE_UNAVAILABLE",
|
||||
context={"service": service, "reason": reason}
|
||||
)
|
||||
465
v1/src/middleware/rate_limit.py
Normal file
465
v1/src/middleware/rate_limit.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Rate limiting middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""Rate limit information."""
|
||||
requests: int
|
||||
window_start: float
|
||||
window_size: int
|
||||
limit: int
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
return max(0, self.limit - self.requests)
|
||||
|
||||
@property
|
||||
def reset_time(self) -> float:
|
||||
"""Get time when window resets."""
|
||||
return self.window_start + self.window_size
|
||||
|
||||
@property
|
||||
def is_exceeded(self) -> bool:
|
||||
"""Check if rate limit is exceeded."""
|
||||
return self.requests >= self.limit
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
"""Token bucket algorithm for rate limiting."""
|
||||
|
||||
def __init__(self, capacity: int, refill_rate: float):
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.refill_rate = refill_rate
|
||||
self.last_refill = time.time()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def consume(self, tokens: int = 1) -> bool:
|
||||
"""Try to consume tokens from bucket."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
|
||||
# Refill tokens based on time elapsed
|
||||
time_passed = now - self.last_refill
|
||||
tokens_to_add = time_passed * self.refill_rate
|
||||
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
|
||||
self.last_refill = now
|
||||
|
||||
# Check if we have enough tokens
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""Get bucket information."""
|
||||
return {
|
||||
"capacity": self.capacity,
|
||||
"tokens": self.tokens,
|
||||
"refill_rate": self.refill_rate,
|
||||
"last_refill": self.last_refill
|
||||
}
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for rate limiting."""
|
||||
|
||||
def __init__(self, window_size: int, limit: int):
|
||||
self.window_size = window_size
|
||||
self.limit = limit
|
||||
self.requests = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self) -> Tuple[bool, RateLimitInfo]:
|
||||
"""Check if request is allowed."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - self.window_size
|
||||
|
||||
# Remove old requests outside the window
|
||||
while self.requests and self.requests[0] < window_start:
|
||||
self.requests.popleft()
|
||||
|
||||
# Check if limit is exceeded
|
||||
current_requests = len(self.requests)
|
||||
allowed = current_requests < self.limit
|
||||
|
||||
if allowed:
|
||||
self.requests.append(now)
|
||||
|
||||
rate_limit_info = RateLimitInfo(
|
||||
requests=current_requests + (1 if allowed else 0),
|
||||
window_start=window_start,
|
||||
window_size=self.window_size,
|
||||
limit=self.limit
|
||||
)
|
||||
|
||||
return allowed, rate_limit_info
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter with multiple algorithms."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.enabled = settings.enable_rate_limiting
|
||||
|
||||
# Rate limit configurations
|
||||
self.default_limit = settings.rate_limit_requests
|
||||
self.authenticated_limit = settings.rate_limit_authenticated_requests
|
||||
self.window_size = settings.rate_limit_window
|
||||
|
||||
# Storage for rate limit data
|
||||
self._sliding_windows: Dict[str, SlidingWindowCounter] = {}
|
||||
self._token_buckets: Dict[str, TokenBucket] = {}
|
||||
|
||||
# Cleanup task
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._cleanup_interval = 300 # 5 minutes
|
||||
|
||||
async def start(self):
|
||||
"""Start rate limiter background tasks."""
|
||||
if self.enabled:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("Rate limiter started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop rate limiter background tasks."""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Rate limiter stopped")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background task to cleanup old rate limit data."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
await self._cleanup_old_data()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in rate limiter cleanup: {e}")
|
||||
|
||||
async def _cleanup_old_data(self):
|
||||
"""Remove old rate limit data."""
|
||||
now = time.time()
|
||||
cutoff = now - (self.window_size * 2) # Keep data for 2 windows
|
||||
|
||||
# Cleanup sliding windows
|
||||
keys_to_remove = []
|
||||
for key, window in self._sliding_windows.items():
|
||||
# Remove old requests
|
||||
while window.requests and window.requests[0] < cutoff:
|
||||
window.requests.popleft()
|
||||
|
||||
# Remove empty windows
|
||||
if not window.requests:
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._sliding_windows[key]
|
||||
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} old rate limit windows")
|
||||
|
||||
def _get_client_identifier(self, request: Request) -> str:
|
||||
"""Get client identifier for rate limiting."""
|
||||
# Try to get user ID from authenticated request
|
||||
user = getattr(request.state, "user", None)
|
||||
if user:
|
||||
return f"user:{user.get('username', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
client_ip = self._get_client_ip(request)
|
||||
return f"ip:{client_ip}"
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get client IP address."""
|
||||
# Check for forwarded headers
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct connection
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _get_rate_limit(self, request: Request) -> int:
|
||||
"""Get rate limit for request."""
|
||||
# Check if user is authenticated
|
||||
user = getattr(request.state, "user", None)
|
||||
if user:
|
||||
return self.authenticated_limit
|
||||
|
||||
return self.default_limit
|
||||
|
||||
def _get_rate_limit_key(self, request: Request) -> str:
|
||||
"""Get rate limit key for request."""
|
||||
client_id = self._get_client_identifier(request)
|
||||
endpoint = f"{request.method}:{request.url.path}"
|
||||
return f"{client_id}:{endpoint}"
|
||||
|
||||
async def check_rate_limit(self, request: Request) -> Tuple[bool, RateLimitInfo]:
|
||||
"""Check if request is within rate limits."""
|
||||
if not self.enabled:
|
||||
# Return dummy info when rate limiting is disabled
|
||||
return True, RateLimitInfo(
|
||||
requests=0,
|
||||
window_start=time.time(),
|
||||
window_size=self.window_size,
|
||||
limit=float('inf')
|
||||
)
|
||||
|
||||
key = self._get_rate_limit_key(request)
|
||||
limit = self._get_rate_limit(request)
|
||||
|
||||
# Get or create sliding window counter
|
||||
if key not in self._sliding_windows:
|
||||
self._sliding_windows[key] = SlidingWindowCounter(self.window_size, limit)
|
||||
|
||||
window = self._sliding_windows[key]
|
||||
|
||||
# Update limit if it changed (e.g., user authenticated)
|
||||
window.limit = limit
|
||||
|
||||
return await window.is_allowed()
|
||||
|
||||
async def check_token_bucket(self, request: Request, tokens: int = 1) -> bool:
|
||||
"""Check rate limit using token bucket algorithm."""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
key = self._get_client_identifier(request)
|
||||
limit = self._get_rate_limit(request)
|
||||
|
||||
# Get or create token bucket
|
||||
if key not in self._token_buckets:
|
||||
# Refill rate: limit per window size
|
||||
refill_rate = limit / self.window_size
|
||||
self._token_buckets[key] = TokenBucket(limit, refill_rate)
|
||||
|
||||
bucket = self._token_buckets[key]
|
||||
return await bucket.consume(tokens)
|
||||
|
||||
def get_rate_limit_headers(self, rate_limit_info: RateLimitInfo) -> Dict[str, str]:
|
||||
"""Get rate limit headers for response."""
|
||||
return {
|
||||
"X-RateLimit-Limit": str(rate_limit_info.limit),
|
||||
"X-RateLimit-Remaining": str(rate_limit_info.remaining),
|
||||
"X-RateLimit-Reset": str(int(rate_limit_info.reset_time)),
|
||||
"X-RateLimit-Window": str(rate_limit_info.window_size),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get rate limiter statistics."""
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"default_limit": self.default_limit,
|
||||
"authenticated_limit": self.authenticated_limit,
|
||||
"window_size": self.window_size,
|
||||
"active_windows": len(self._sliding_windows),
|
||||
"active_buckets": len(self._token_buckets),
|
||||
}
|
||||
|
||||
|
||||
class RateLimitMiddleware:
|
||||
"""Rate limiting middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.rate_limiter = RateLimiter(settings)
|
||||
self.enabled = settings.enable_rate_limiting
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request through rate limiting middleware."""
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip rate limiting for certain paths
|
||||
if self._should_skip_rate_limit(request):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Check rate limit
|
||||
allowed, rate_limit_info = await self.rate_limiter.check_rate_limit(request)
|
||||
|
||||
if not allowed:
|
||||
# Rate limit exceeded
|
||||
logger.warning(
|
||||
f"Rate limit exceeded for {self.rate_limiter._get_client_identifier(request)} "
|
||||
f"on {request.method} {request.url.path}"
|
||||
)
|
||||
|
||||
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
|
||||
headers["Retry-After"] = str(int(rate_limit_info.reset_time - time.time()))
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers to response
|
||||
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
|
||||
for key, value in headers.items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limiting middleware error: {e}")
|
||||
# Continue without rate limiting on error
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_rate_limit(self, request: Request) -> bool:
|
||||
"""Check if rate limiting should be skipped for this request."""
|
||||
path = request.url.path
|
||||
|
||||
# Skip rate limiting for these paths
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/static",
|
||||
]
|
||||
|
||||
return any(path.startswith(skip_path) for skip_path in skip_paths)
|
||||
|
||||
async def start(self):
|
||||
"""Start rate limiting middleware."""
|
||||
await self.rate_limiter.start()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop rate limiting middleware."""
|
||||
await self.rate_limiter.stop()
|
||||
|
||||
|
||||
# Global rate limit middleware instance
|
||||
_rate_limit_middleware: Optional[RateLimitMiddleware] = None
|
||||
|
||||
|
||||
def get_rate_limit_middleware(settings: Settings) -> RateLimitMiddleware:
|
||||
"""Get rate limit middleware instance."""
|
||||
global _rate_limit_middleware
|
||||
if _rate_limit_middleware is None:
|
||||
_rate_limit_middleware = RateLimitMiddleware(settings)
|
||||
return _rate_limit_middleware
|
||||
|
||||
|
||||
def setup_rate_limiting(app: ASGIApp, settings: Settings) -> ASGIApp:
|
||||
"""Setup rate limiting middleware for the application."""
|
||||
if settings.enable_rate_limiting:
|
||||
logger.info("Setting up rate limiting middleware")
|
||||
|
||||
middleware = get_rate_limit_middleware(settings)
|
||||
|
||||
# Add middleware to app
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
return await middleware(request, call_next)
|
||||
|
||||
logger.info(
|
||||
f"Rate limiting enabled - Default: {settings.rate_limit_requests}/"
|
||||
f"{settings.rate_limit_window}s, Authenticated: "
|
||||
f"{settings.rate_limit_authenticated_requests}/{settings.rate_limit_window}s"
|
||||
)
|
||||
else:
|
||||
logger.info("Rate limiting disabled")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class RateLimitConfig:
|
||||
"""Rate limiting configuration helper."""
|
||||
|
||||
@staticmethod
|
||||
def development_config() -> dict:
|
||||
"""Get rate limiting configuration for development."""
|
||||
return {
|
||||
"enable_rate_limiting": False, # Disabled in development
|
||||
"rate_limit_requests": 1000,
|
||||
"rate_limit_authenticated_requests": 5000,
|
||||
"rate_limit_window": 3600, # 1 hour
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def production_config() -> dict:
|
||||
"""Get rate limiting configuration for production."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 100, # 100 requests per hour for unauthenticated
|
||||
"rate_limit_authenticated_requests": 1000, # 1000 requests per hour for authenticated
|
||||
"rate_limit_window": 3600, # 1 hour
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def api_config() -> dict:
|
||||
"""Get rate limiting configuration for API access."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 60, # 60 requests per minute
|
||||
"rate_limit_authenticated_requests": 300, # 300 requests per minute
|
||||
"rate_limit_window": 60, # 1 minute
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def strict_config() -> dict:
|
||||
"""Get strict rate limiting configuration."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 10, # 10 requests per minute
|
||||
"rate_limit_authenticated_requests": 100, # 100 requests per minute
|
||||
"rate_limit_window": 60, # 1 minute
|
||||
}
|
||||
|
||||
|
||||
def validate_rate_limit_config(settings: Settings) -> list:
|
||||
"""Validate rate limiting configuration."""
|
||||
issues = []
|
||||
|
||||
if settings.enable_rate_limiting:
|
||||
if settings.rate_limit_requests <= 0:
|
||||
issues.append("Rate limit requests must be positive")
|
||||
|
||||
if settings.rate_limit_authenticated_requests <= 0:
|
||||
issues.append("Authenticated rate limit requests must be positive")
|
||||
|
||||
if settings.rate_limit_window <= 0:
|
||||
issues.append("Rate limit window must be positive")
|
||||
|
||||
if settings.rate_limit_authenticated_requests < settings.rate_limit_requests:
|
||||
issues.append("Authenticated rate limit should be higher than default rate limit")
|
||||
|
||||
return issues
|
||||
Reference in New Issue
Block a user