feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
8
v1/src/api/middleware/__init__.py
Normal file
8
v1/src/api/middleware/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
FastAPI middleware package
|
||||
"""
|
||||
|
||||
from .auth import AuthMiddleware
|
||||
from .rate_limit import RateLimitMiddleware
|
||||
|
||||
__all__ = ["AuthMiddleware", "RateLimitMiddleware"]
|
||||
322
v1/src/api/middleware/auth.py
Normal file
322
v1/src/api/middleware/auth.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
JWT Authentication middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    Validates Bearer tokens on incoming requests and attaches the decoded
    user claims to ``request.state.user``.  Public paths pass through
    untouched; protected paths get a 401 when no valid token is presented;
    every other path proceeds anonymously.
    """

    def __init__(self, app):
        super().__init__(app)
        self.settings = get_settings()

        # Paths that don't require authentication (exact matches).
        self.public_paths = {
            "/",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/health",
            "/ready",
            "/live",
            "/version",
            "/metrics"
        }

        # Paths that require authentication.  Also used as the prefix list
        # by _is_protected_path, so the two can never drift apart (the
        # original kept a duplicate hard-coded copy there).
        self.protected_paths = {
            "/api/v1/pose/analyze",
            "/api/v1/pose/calibrate",
            "/api/v1/pose/historical",
            "/api/v1/stream/start",
            "/api/v1/stream/stop",
            "/api/v1/stream/clients",
            "/api/v1/stream/broadcast"
        }

    async def dispatch(self, request: Request, call_next):
        """Process request through authentication middleware."""

        # Skip authentication for public paths.
        if self._is_public_path(request.url.path):
            return await call_next(request)

        # Extract and validate token.
        token = self._extract_token(request)

        if token:
            try:
                # Verify token and add user info to request state.
                user_data = await self._verify_token(token)
                request.state.user = user_data
                request.state.authenticated = True
                logger.debug(f"Authenticated user: {user_data.get('id')}")
            except Exception as e:
                logger.warning(f"Token validation failed: {e}")
                # For protected paths, reject; for others, fall through
                # to anonymous handling.
                if self._is_protected_path(request.url.path):
                    return self._unauthorized_response("Invalid or expired token")
                request.state.user = None
                request.state.authenticated = False
        else:
            # No token provided.
            if self._is_protected_path(request.url.path):
                return self._unauthorized_response("Authentication required")
            request.state.user = None
            request.state.authenticated = False

        # Continue with request processing.
        response = await call_next(request)

        # Surface the authentication outcome to downstream consumers.
        if hasattr(request.state, 'user') and request.state.user:
            response.headers["X-User-ID"] = request.state.user.get("id", "")
            response.headers["X-Authenticated"] = "true"
        else:
            response.headers["X-Authenticated"] = "false"

        return response

    @staticmethod
    def _unauthorized_response(message: str) -> JSONResponse:
        """Build a 401 response with the RFC 6750 Bearer challenge.

        Fix: the original duplicated this payload in two places and omitted
        the WWW-Authenticate header on the invalid-token branch; RFC 6750
        requires the challenge on every 401.
        """
        return JSONResponse(
            status_code=401,
            content={
                "error": {
                    "code": 401,
                    "message": message,
                    "type": "authentication_error"
                }
            },
            headers={"WWW-Authenticate": "Bearer"}
        )

    def _is_public_path(self, path: str) -> bool:
        """Check if path is public (doesn't require authentication)."""
        # Exact match first.
        if path in self.public_paths:
            return True

        # Prefix patterns for anonymous, read-only endpoints.
        public_patterns = [
            "/health",
            "/metrics",
            "/api/v1/pose/current",  # Allow anonymous access to current pose data
            "/api/v1/pose/zones/",  # Allow anonymous access to zone data
            "/api/v1/pose/activities",  # Allow anonymous access to activities
            "/api/v1/pose/stats",  # Allow anonymous access to stats
            "/api/v1/stream/status"  # Allow anonymous access to stream status
        ]
        return any(path.startswith(pattern) for pattern in public_patterns)

    def _is_protected_path(self, path: str) -> bool:
        """Check if path requires authentication.

        A prefix match against self.protected_paths covers the exact-match
        case too, so a single source of truth suffices.
        """
        return any(path.startswith(pattern) for pattern in self.protected_paths)

    def _extract_token(self, request: Request) -> Optional[str]:
        """Extract JWT token from the request.

        Checks, in order: Authorization header, ``token`` query parameter
        (used by WebSocket connections), then the ``access_token`` cookie.
        Returns None when no token is found.
        """
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.startswith("Bearer "):
            # Fix: split with maxsplit=1 so the token is everything after
            # the scheme rather than just the first space-delimited chunk.
            return auth_header.split(" ", 1)[1]

        # Query parameter (for WebSocket connections).
        token = request.query_params.get("token")
        if token:
            return token

        # Cookie fallback; empty string is treated as absent.
        return request.cookies.get("access_token") or None

    async def _verify_token(self, token: str) -> Dict[str, Any]:
        """Verify JWT token and return user data.

        Raises:
            ValueError: if the token is invalid, expired, or missing the
                required subject claim.
        """
        try:
            # Decode and cryptographically verify the JWT.
            payload = jwt.decode(
                token,
                self.settings.secret_key,
                algorithms=[self.settings.jwt_algorithm]
            )

            # Extract user information.
            user_id = payload.get("sub")
            if not user_id:
                raise ValueError("Token missing user ID")

            # Defence in depth: jose already rejects expired tokens, but we
            # double-check the claim.  Fix: the original compared against
            # datetime.fromtimestamp(), which converts to *local* time and
            # mis-judges expiry in any non-UTC timezone.
            exp = payload.get("exp")
            if exp and datetime.utcnow() > datetime.utcfromtimestamp(exp):
                raise ValueError("Token expired")

            # Build user object from claims.
            user_data = {
                "id": user_id,
                "username": payload.get("username"),
                "email": payload.get("email"),
                "is_admin": payload.get("is_admin", False),
                "permissions": payload.get("permissions", []),
                "accessible_zones": payload.get("accessible_zones", []),
                "token_issued_at": payload.get("iat"),
                "token_expires_at": payload.get("exp"),
                "session_id": payload.get("session_id")
            }

            return user_data

        except JWTError as e:
            raise ValueError(f"JWT validation failed: {e}")
        except Exception as e:
            # Re-wraps the ValueErrors raised above as well, preserving the
            # single ValueError contract for callers.
            raise ValueError(f"Token verification error: {e}")

    def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
        """Log authentication events for security monitoring.

        Args:
            request: the request the event relates to.
            event_type: short event identifier, e.g. "authentication_failed".
            details: optional extra fields merged into the log record.
        """
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "unknown")

        log_data = {
            "event_type": event_type,
            "timestamp": datetime.utcnow().isoformat(),
            "client_ip": client_ip,
            "user_agent": user_agent,
            "path": request.url.path,
            "method": request.method
        }

        if details:
            log_data.update(details)

        # Failures log at WARNING so they stand out in security review.
        if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
            logger.warning(f"Auth event: {log_data}")
        else:
            logger.info(f"Auth event: {log_data}")
|
||||
class TokenBlacklist:
    """Simple in-memory token blacklist for logout functionality.

    Fix: the original stored tokens in a bare set and wholesale-cleared it
    every cleanup interval, silently re-admitting logged-out tokens after
    at most an hour.  Tokens may now be registered with their expiry
    timestamp so cleanup only drops tokens that can no longer validate
    anyway.  NOTE(review): state is process-local; a shared store (Redis)
    is needed for multi-process deployments.
    """

    def __init__(self):
        # Maps token -> expiry timestamp (epoch seconds), or None when the
        # caller did not supply one.
        self._blacklisted_tokens: Dict[str, Optional[float]] = {}
        self._cleanup_interval = 3600  # 1 hour between cleanup passes
        self._last_cleanup = datetime.utcnow()

    def add_token(self, token: str, expires_at: Optional[float] = None):
        """Add token to blacklist.

        Args:
            token: the raw token string to reject.
            expires_at: optional token expiry as epoch seconds (the JWT
                ``exp`` claim); lets cleanup drop it once it is harmless.
        """
        self._blacklisted_tokens[token] = expires_at
        self._cleanup_if_needed()

    def is_blacklisted(self, token: str) -> bool:
        """Check if token is blacklisted."""
        self._cleanup_if_needed()
        return token in self._blacklisted_tokens

    def _cleanup_if_needed(self):
        """Drop stale entries once per cleanup interval."""
        now = datetime.utcnow()
        if (now - self._last_cleanup).total_seconds() > self._cleanup_interval:
            # Keep only tokens whose known expiry is still in the future.
            # Entries with unknown expiry are dropped, matching the
            # original periodic-clear behaviour and bounding memory growth.
            self._blacklisted_tokens = {
                token: exp
                for token, exp in self._blacklisted_tokens.items()
                if exp is not None and datetime.utcfromtimestamp(exp) > now
            }
            self._last_cleanup = now


# Global token blacklist instance
token_blacklist = TokenBlacklist()
||||
class SecurityHeaders:
    """Security headers for API responses."""

    @staticmethod
    def add_security_headers(response: Response) -> Response:
        """Attach the standard security header set to *response*.

        Returns the same response object for call chaining.
        """
        # Content-Security-Policy permits inline scripts/styles and
        # WebSocket connections back to this origin.
        csp_policy = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline'; "
            "style-src 'self' 'unsafe-inline'; "
            "img-src 'self' data:; "
            "connect-src 'self' ws: wss:;"
        )
        header_values = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Content-Security-Policy": csp_policy,
        }
        for header_name, header_value in header_values.items():
            response.headers[header_name] = header_value

        return response
||||
class APIKeyAuth:
    """Alternative API key authentication for service-to-service communication."""

    def __init__(self, api_keys: Dict[str, Dict[str, Any]] = None):
        # Mapping of api_key -> metadata describing the owning service.
        self.api_keys = api_keys or {}

    def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Return the service info registered for *api_key*, or None."""
        return self.api_keys.get(api_key)

    def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
        """Register *api_key* with its associated service info."""
        self.api_keys[api_key] = service_info

    def revoke_api_key(self, api_key: str):
        """Remove *api_key* from the registry; no-op when absent."""
        self.api_keys.pop(api_key, None)


# Global API key auth instance
api_key_auth = APIKeyAuth()
429
v1/src/api/middleware/rate_limit.py
Normal file
429
v1/src/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Rate limiting middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware with sliding window algorithm.

    Enforces a general per-user-type limit plus stricter per-path limits,
    both tracked as in-memory deques of request timestamps.  Clients that
    keep violating limits are temporarily blocked.  NOTE(review): state is
    process-local; use the Redis-backed storage in production.
    """

    def __init__(self, app):
        super().__init__(app)
        self.settings = get_settings()

        # Rate limit storage (in production, use Redis).
        # request_counts maps "<client>:<limit-type>" -> deque of timestamps.
        self.request_counts = defaultdict(lambda: deque())
        # blocked_clients maps client_id -> unblock time (epoch seconds).
        self.blocked_clients = {}
        # Fix: running violation tally per client.  The original read a
        # "violations" key from the limit-check result that was never
        # produced, so auto-blocking could never trigger.
        self.violation_counts = defaultdict(int)

        # Rate limit configurations keyed by user type.
        self.rate_limits = {
            "anonymous": {
                "requests": self.settings.rate_limit_requests,
                "window": self.settings.rate_limit_window,
                "burst": 10  # Allow burst of 10 requests
            },
            "authenticated": {
                "requests": self.settings.rate_limit_authenticated_requests,
                "window": self.settings.rate_limit_window,
                "burst": 50
            },
            "admin": {
                "requests": 10000,  # Very high limit for admins
                "window": self.settings.rate_limit_window,
                "burst": 100
            }
        }

        # Path-specific rate limits (stricter than the general ones).
        self.path_limits = {
            "/api/v1/pose/current": {"requests": 60, "window": 60},  # 1 per second
            "/api/v1/pose/analyze": {"requests": 10, "window": 60},  # 10 per minute
            "/api/v1/pose/calibrate": {"requests": 1, "window": 300},  # 1 per 5 minutes
            "/api/v1/stream/start": {"requests": 5, "window": 60},  # 5 per minute
            "/api/v1/stream/stop": {"requests": 5, "window": 60},  # 5 per minute
        }

        # Operational endpoints exempt from rate limiting.
        self.exempt_paths = {
            "/health",
            "/ready",
            "/live",
            "/version",
            "/metrics"
        }

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting middleware."""

        # Skip rate limiting for exempt paths.
        if self._is_exempt_path(request.url.path):
            return await call_next(request)

        client_id = self._get_client_id(request)

        # Reject outright while a temporary block is active.
        if self._is_client_blocked(client_id):
            return self._create_rate_limit_response(
                "Client temporarily blocked due to excessive requests"
            )

        user_type = self._get_user_type(request)

        rate_limit_result = self._check_rate_limits(
            client_id,
            request.url.path,
            user_type
        )

        if not rate_limit_result["allowed"]:
            self._log_rate_limit_violation(request, client_id, rate_limit_result)

            # Count the violation; block after more than 5 in a row.
            self.violation_counts[client_id] += 1
            if self.violation_counts[client_id] > 5:
                self._block_client(client_id, duration=300)  # 5 minutes
                self.violation_counts[client_id] = 0

            return self._create_rate_limit_response(
                rate_limit_result["message"],
                retry_after=rate_limit_result.get("retry_after", 60)
            )

        # Record the request, process it, and annotate the response.
        self._record_request(client_id, request.url.path)

        response = await call_next(request)

        self._add_rate_limit_headers(response, client_id, user_type)

        return response

    def _is_exempt_path(self, path: str) -> bool:
        """Check if path is exempt from rate limiting."""
        return path in self.exempt_paths

    def _get_client_id(self, request: Request) -> str:
        """Get unique client identifier for rate limiting.

        Prefers the authenticated user id (set by the auth middleware);
        falls back to IP address plus a user-agent hash.
        """
        if hasattr(request.state, 'user') and request.state.user:
            return f"user:{request.state.user['id']}"

        client_ip = request.client.host if request.client else "unknown"

        # NOTE(review): hash() is salted per process, so IP-based client
        # ids are not stable across restarts or workers — confirm this is
        # acceptable, or switch to hashlib for a stable digest.
        user_agent = request.headers.get("user-agent", "")
        user_agent_hash = str(hash(user_agent))[:8]

        return f"ip:{client_ip}:{user_agent_hash}"

    def _get_user_type(self, request: Request) -> str:
        """Determine user type ("admin"/"authenticated"/"anonymous")."""
        if hasattr(request.state, 'user') and request.state.user:
            if request.state.user.get("is_admin", False):
                return "admin"
            return "authenticated"
        return "anonymous"

    def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
        """Check the general limit, then any path-specific limit.

        Returns the first failing result, or {"allowed": True}.
        """
        now = time.time()

        general_limit = self.rate_limits[user_type]
        path_limit = self.path_limits.get(path)

        general_result = self._check_limit(
            client_id,
            "general",
            general_limit["requests"],
            general_limit["window"],
            now
        )

        if not general_result["allowed"]:
            return general_result

        if path_limit:
            path_result = self._check_limit(
                client_id,
                f"path:{path}",
                path_limit["requests"],
                path_limit["window"],
                now
            )

            if not path_result["allowed"]:
                return path_result

        return {"allowed": True}

    def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
        """Check a specific limit using a sliding window of timestamps.

        Read-only apart from expiring stale timestamps; recording happens
        separately in _record_request so rejected requests don't count.
        """
        key = f"{client_id}:{limit_type}"
        requests = self.request_counts[key]

        # Expire timestamps that fell out of the window.
        cutoff = now - window
        while requests and requests[0] <= cutoff:
            requests.popleft()

        if len(requests) >= max_requests:
            # Retry is possible once the oldest request leaves the window.
            oldest_request = requests[0] if requests else now
            retry_after = int(oldest_request + window - now) + 1

            return {
                "allowed": False,
                "message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
                "retry_after": retry_after,
                "current_count": len(requests),
                "limit": max_requests,
                "window": window
            }

        return {
            "allowed": True,
            "current_count": len(requests),
            "limit": max_requests,
            "window": window
        }

    def _record_request(self, client_id: str, path: str):
        """Record an accepted request against all applicable limits."""
        now = time.time()

        general_key = f"{client_id}:general"
        self.request_counts[general_key].append(now)

        # Only paths with specific limits need their own deque.
        if path in self.path_limits:
            path_key = f"{client_id}:path:{path}"
            self.request_counts[path_key].append(now)

    def _is_client_blocked(self, client_id: str) -> bool:
        """Check if client is temporarily blocked; expire stale blocks."""
        if client_id in self.blocked_clients:
            block_until = self.blocked_clients[client_id]
            if time.time() < block_until:
                return True
            else:
                del self.blocked_clients[client_id]
        return False

    def _block_client(self, client_id: str, duration: int):
        """Temporarily block a client for *duration* seconds."""
        self.blocked_clients[client_id] = time.time() + duration
        logger.warning(f"Client {client_id} blocked for {duration} seconds due to rate limit violations")

    def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
        """Create a 429 response with Retry-After guidance."""
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": 429,
                    "message": message,
                    "type": "rate_limit_exceeded"
                }
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": "Exceeded",
                "X-RateLimit-Remaining": "0"
            }
        )

    def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
        """Add X-RateLimit-* headers to a successful response.

        Best-effort: header failures are logged but never fail the request.
        """
        try:
            general_limit = self.rate_limits[user_type]
            general_key = f"{client_id}:general"
            current_requests = len(self.request_counts[general_key])

            remaining = max(0, general_limit["requests"] - current_requests)

            response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Window"] = str(general_limit["window"])

            # Reset time = when the oldest recorded request ages out.
            if self.request_counts[general_key]:
                oldest_request = self.request_counts[general_key][0]
                reset_time = int(oldest_request + general_limit["window"])
                response.headers["X-RateLimit-Reset"] = str(reset_time)

        except Exception as e:
            logger.error(f"Error adding rate limit headers: {e}")

    def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
        """Log rate limit violations for monitoring."""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "unknown")

        log_data = {
            "event_type": "rate_limit_violation",
            "timestamp": datetime.utcnow().isoformat(),
            "client_id": client_id,
            "client_ip": client_ip,
            "user_agent": user_agent,
            "path": request.url.path,
            "method": request.method,
            "current_count": result.get("current_count"),
            "limit": result.get("limit"),
            "window": result.get("window")
        }

        logger.warning(f"Rate limit violation: {log_data}")

    def cleanup_old_data(self):
        """Clean up old rate limiting data (call periodically)."""
        now = time.time()
        cutoff = now - 3600  # Keep data for 1 hour

        # Expire old timestamps and drop empty deques.
        for key in list(self.request_counts.keys()):
            requests = self.request_counts[key]
            while requests and requests[0] <= cutoff:
                requests.popleft()

            if not requests:
                del self.request_counts[key]

        # Remove expired blocks and give those clients a clean slate.
        expired_blocks = [
            client_id for client_id, block_until in self.blocked_clients.items()
            if now >= block_until
        ]

        for client_id in expired_blocks:
            del self.blocked_clients[client_id]
            self.violation_counts.pop(client_id, None)
|
||||
|
||||
class AdaptiveRateLimit:
    """Adaptive rate limiting based on system load."""

    def __init__(self):
        self.base_limits = {}
        self.current_multiplier = 1.0
        # Rolling window of recent load scores (at most 1 minute's worth).
        self.load_history = deque(maxlen=60)

    def update_system_load(self, cpu_percent: float, memory_percent: float):
        """Record a load sample and recompute the limit multiplier."""
        self.load_history.append((cpu_percent + memory_percent) / 2)

        # Only adapt once enough history exists to smooth out spikes.
        if len(self.load_history) < 10:
            return

        avg_load = sum(self.load_history) / len(self.load_history)

        if avg_load > 80:
            multiplier = 0.5   # heavy load: halve limits
        elif avg_load > 60:
            multiplier = 0.7   # elevated load: reduce limits by 30%
        elif avg_load < 30:
            multiplier = 1.2   # light load: allow 20% more
        else:
            multiplier = 1.0   # normal load
        self.current_multiplier = multiplier

    def get_adjusted_limit(self, base_limit: int) -> int:
        """Return *base_limit* scaled by the current multiplier, floored at 1."""
        return max(1, int(base_limit * self.current_multiplier))
|
||||
|
||||
class RateLimitStorage:
    """Abstract interface for rate limit storage backends (e.g. Redis)."""

    async def get_count(self, key: str, window: int) -> int:
        """Get current request count for key within window."""
        raise NotImplementedError

    async def increment(self, key: str, window: int) -> int:
        """Increment request count and return new count."""
        raise NotImplementedError

    async def is_blocked(self, client_id: str) -> bool:
        """Check if client is blocked."""
        raise NotImplementedError

    async def block_client(self, client_id: str, duration: int):
        """Block client for duration seconds."""
        raise NotImplementedError


class RedisRateLimitStorage(RateLimitStorage):
    """Redis-based rate limit storage for production use.

    Each key is a sorted set whose members are request timestamps scored by
    the same timestamp, so expiring old entries is a single range delete.
    Expects an async Redis client with pipeline support.
    """

    def __init__(self, redis_client):
        self.redis = redis_client

    async def get_count(self, key: str, window: int) -> int:
        """Get current request count using Redis sliding window."""
        now = time.time()
        pipeline = self.redis.pipeline()

        # Drop entries older than the window, then count what remains.
        pipeline.zremrangebyscore(key, 0, now - window)
        pipeline.zcard(key)

        results = await pipeline.execute()
        return results[1]  # zcard result

    async def increment(self, key: str, window: int) -> int:
        """Record a request for *key* and return the updated count."""
        now = time.time()
        pipeline = self.redis.pipeline()

        # Add current request.
        pipeline.zadd(key, {str(now): now})

        # Remove entries that fell out of the window.
        pipeline.zremrangebyscore(key, 0, now - window)

        # Expire the whole key shortly after the window so idle keys vanish.
        pipeline.expire(key, window + 1)

        pipeline.zcard(key)

        results = await pipeline.execute()
        return results[3]  # zcard result

    async def is_blocked(self, client_id: str) -> bool:
        """Check if client is blocked."""
        block_key = f"blocked:{client_id}"
        # Fix: Redis EXISTS returns an integer count, not a bool; coerce so
        # the declared return type actually holds for callers using "is".
        return bool(await self.redis.exists(block_key))

    async def block_client(self, client_id: str, duration: int):
        """Block client for duration seconds via an expiring marker key."""
        block_key = f"blocked:{client_id}"
        await self.redis.setex(block_key, duration, "1")
|
||||
|
||||
# Global adaptive rate limiter instance (process-wide; callers feed it load
# samples via update_system_load and scale limits with get_adjusted_limit).
adaptive_rate_limit = AdaptiveRateLimit()
Reference in New Issue
Block a user