Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
429 lines
15 KiB
Python
429 lines
15 KiB
Python
"""
Rate limiting middleware for the WiFi-DensePose API.
"""
|
|
|
|
import hashlib
import logging
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from src.config.settings import get_settings
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware with sliding window algorithm.

    Request timestamps are kept per client in in-process deques, so the
    limits apply to a single worker process only. Every non-exempt request
    is checked against a general limit selected by user type and, for the
    paths listed in ``path_limits``, an additional path-specific limit.
    Clients that keep getting rejected are blocked temporarily.
    """

    # A client is blocked once it has accumulated more than MAX_VIOLATIONS
    # rejected requests within the last VIOLATION_WINDOW seconds.
    MAX_VIOLATIONS = 5
    VIOLATION_WINDOW = 300  # seconds; chosen to match BLOCK_DURATION
    BLOCK_DURATION = 300  # seconds (5 minutes)

    def __init__(self, app):
        """Set up in-memory storage and the limit tables.

        Args:
            app: The ASGI application wrapped by this middleware.
        """
        super().__init__(app)
        self.settings = get_settings()

        # Rate limit storage (in production, use Redis)
        self.request_counts = defaultdict(deque)
        self.blocked_clients = {}
        # Timestamps of rejected requests per client; drives temporary blocks.
        self.violation_counts = defaultdict(deque)

        # Rate limit configurations
        self.rate_limits = {
            "anonymous": {
                "requests": self.settings.rate_limit_requests,
                "window": self.settings.rate_limit_window,
                "burst": 10,  # Allow burst of 10 requests
            },
            "authenticated": {
                "requests": self.settings.rate_limit_authenticated_requests,
                "window": self.settings.rate_limit_window,
                "burst": 50,
            },
            "admin": {
                "requests": 10000,  # Very high limit for admins
                "window": self.settings.rate_limit_window,
                "burst": 100,
            },
        }

        # Path-specific rate limits
        self.path_limits = {
            "/api/v1/pose/current": {"requests": 60, "window": 60},  # 1 per second
            "/api/v1/pose/analyze": {"requests": 10, "window": 60},  # 10 per minute
            "/api/v1/pose/calibrate": {"requests": 1, "window": 300},  # 1 per 5 minutes
            "/api/v1/stream/start": {"requests": 5, "window": 60},  # 5 per minute
            "/api/v1/stream/stop": {"requests": 5, "window": 60},  # 5 per minute
        }

        # Exempt paths from rate limiting
        self.exempt_paths = {
            "/health",
            "/ready",
            "/live",
            "/version",
            "/metrics",
        }

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting middleware.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response with rate limit headers added, or a
            429 JSON response when the client is blocked or over its limit.
        """
        path = request.url.path

        # Skip rate limiting for exempt paths
        if self._is_exempt_path(path):
            return await call_next(request)

        client_id = self._get_client_id(request)

        # Check if client is temporarily blocked
        if self._is_client_blocked(client_id):
            # Tell the client how long the block actually lasts instead of
            # always answering with the default Retry-After.
            remaining = self.blocked_clients[client_id] - time.time()
            return self._create_rate_limit_response(
                "Client temporarily blocked due to excessive requests",
                retry_after=max(1, int(remaining) + 1),
            )

        user_type = self._get_user_type(request)
        rate_limit_result = self._check_rate_limits(client_id, path, user_type)

        if not rate_limit_result["allowed"]:
            # _check_limit does not count violations, so track them here;
            # without this the block threshold below could never be reached.
            rate_limit_result["violations"] = self._record_violation(client_id)

            self._log_rate_limit_violation(request, client_id, rate_limit_result)

            # Check if client should be temporarily blocked
            if rate_limit_result["violations"] > self.MAX_VIOLATIONS:
                self._block_client(client_id, duration=self.BLOCK_DURATION)

            return self._create_rate_limit_response(
                rate_limit_result["message"],
                retry_after=rate_limit_result.get("retry_after", 60),
            )

        # Record the request before processing so it counts even if the
        # downstream handler fails.
        self._record_request(client_id, path)

        response = await call_next(request)

        self._add_rate_limit_headers(response, client_id, user_type)
        return response

    @staticmethod
    def _prune(timestamps: deque, cutoff: float) -> None:
        """Drop timestamps at or before ``cutoff`` from the left of the deque."""
        while timestamps and timestamps[0] <= cutoff:
            timestamps.popleft()

    def _is_exempt_path(self, path: str) -> bool:
        """Check if path is exempt from rate limiting (exact match only)."""
        return path in self.exempt_paths

    def _get_client_id(self, request: Request) -> str:
        """Get unique client identifier for rate limiting.

        Returns ``user:<id>`` when ``request.state.user`` is set (by auth
        middleware), otherwise ``ip:<host>:<user-agent digest>``.
        """
        user = getattr(request.state, "user", None)
        if user:
            return f"user:{user['id']}"

        # Fall back to IP address
        client_ip = request.client.host if request.client else "unknown"

        # Include user agent for better identification. A SHA-256 prefix is
        # used instead of the builtin hash(), which is salted per process and
        # so yields a different ID in every worker and after every restart.
        # NOTE(review): the user agent is client-controlled, so rotating it
        # yields a fresh bucket per value - confirm this is acceptable.
        user_agent = request.headers.get("user-agent", "")
        user_agent_hash = hashlib.sha256(
            user_agent.encode("utf-8", "replace")
        ).hexdigest()[:8]

        return f"ip:{client_ip}:{user_agent_hash}"

    def _get_user_type(self, request: Request) -> str:
        """Determine user type: ``admin``, ``authenticated`` or ``anonymous``."""
        user = getattr(request.state, "user", None)
        if not user:
            return "anonymous"
        if user.get("is_admin", False):
            return "admin"
        return "authenticated"

    def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
        """Check if request is within rate limits.

        The general limit for ``user_type`` is checked first, then the
        path-specific limit if ``path`` has one.

        Returns:
            ``{"allowed": True}`` when every applicable limit passes,
            otherwise the failing result from ``_check_limit``.
        """
        now = time.time()

        general_limit = self.rate_limits[user_type]
        general_result = self._check_limit(
            client_id,
            "general",
            general_limit["requests"],
            general_limit["window"],
            now,
        )
        if not general_result["allowed"]:
            return general_result

        # Check path-specific rate limit if exists
        path_limit = self.path_limits.get(path)
        if path_limit:
            path_result = self._check_limit(
                client_id,
                f"path:{path}",
                path_limit["requests"],
                path_limit["window"],
                now,
            )
            if not path_result["allowed"]:
                return path_result

        return {"allowed": True}

    def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
        """Check specific rate limit using sliding window.

        Prunes timestamps older than ``window`` seconds, then compares the
        number left with ``max_requests``. The request itself is not
        recorded here.

        Returns:
            A dict with ``allowed``, ``current_count``, ``limit`` and
            ``window``; rejected results also carry ``message`` and
            ``retry_after`` (seconds).
        """
        key = f"{client_id}:{limit_type}"
        requests = self.request_counts[key]

        # Remove old requests outside the window
        self._prune(requests, now - window)

        if len(requests) < max_requests:
            return {
                "allowed": True,
                "current_count": len(requests),
                "limit": max_requests,
                "window": window,
            }

        # A slot frees up when the oldest recorded request leaves the window;
        # round up so clients never retry a fraction of a second too early.
        oldest_request = requests[0] if requests else now
        retry_after = int(oldest_request + window - now) + 1

        return {
            "allowed": False,
            "message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
            "retry_after": retry_after,
            "current_count": len(requests),
            "limit": max_requests,
            "window": window,
        }

    def _record_request(self, client_id: str, path: str):
        """Record a request under the general key and, if limited, the path key."""
        now = time.time()

        self.request_counts[f"{client_id}:general"].append(now)

        # Record path-specific request if path has specific limits
        if path in self.path_limits:
            self.request_counts[f"{client_id}:path:{path}"].append(now)

    def _record_violation(self, client_id: str) -> int:
        """Record a rejected request and return the client's recent total.

        Only violations within the last ``VIOLATION_WINDOW`` seconds count.
        """
        now = time.time()
        violations = self.violation_counts[client_id]
        self._prune(violations, now - self.VIOLATION_WINDOW)
        violations.append(now)
        return len(violations)

    def _is_client_blocked(self, client_id: str) -> bool:
        """Check if client is temporarily blocked, dropping an expired block."""
        block_until = self.blocked_clients.get(client_id)
        if block_until is None:
            return False
        if time.time() < block_until:
            return True

        # Block expired, remove it
        del self.blocked_clients[client_id]
        return False

    def _block_client(self, client_id: str, duration: int):
        """Temporarily block a client for ``duration`` seconds."""
        self.blocked_clients[client_id] = time.time() + duration
        # Start from a clean slate once the block expires.
        self.violation_counts.pop(client_id, None)
        logger.warning(
            "Client %s blocked for %s seconds due to rate limit violations",
            client_id,
            duration,
        )

    def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
        """Create rate limit exceeded response.

        Args:
            message: Human-readable reason placed in the JSON error body.
            retry_after: Seconds reported in the ``Retry-After`` header.

        Returns:
            A 429 ``JSONResponse``.
        """
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": 429,
                    "message": message,
                    "type": "rate_limit_exceeded",
                }
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": "Exceeded",
                "X-RateLimit-Remaining": "0",
            },
        )

    def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
        """Add rate limit headers (general limit only) to response.

        Best effort: any failure is logged and the response is left as is.
        """
        try:
            general_limit = self.rate_limits[user_type]
            recorded = self.request_counts[f"{client_id}:general"]

            remaining = max(0, general_limit["requests"] - len(recorded))

            response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Window"] = str(general_limit["window"])

            # Add reset time: when the oldest recorded request leaves the window
            if recorded:
                reset_time = int(recorded[0] + general_limit["window"])
                response.headers["X-RateLimit-Reset"] = str(reset_time)

        except Exception as e:
            logger.error("Error adding rate limit headers: %s", e)

    def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
        """Log rate limit violations for monitoring."""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "unknown")

        log_data = {
            "event_type": "rate_limit_violation",
            "timestamp": datetime.utcnow().isoformat(),
            "client_id": client_id,
            "client_ip": client_ip,
            "user_agent": user_agent,
            "path": request.url.path,
            "method": request.method,
            "current_count": result.get("current_count"),
            "limit": result.get("limit"),
            "window": result.get("window"),
            "violations": result.get("violations"),
        }

        logger.warning("Rate limit violation: %s", log_data)

    def cleanup_old_data(self):
        """Clean up old rate limiting data (call periodically).

        Drops request timestamps older than one hour, violation timestamps
        older than ``VIOLATION_WINDOW``, emptied deques and expired blocks.
        """
        now = time.time()
        cutoff = now - 3600  # Keep data for 1 hour

        # Clean up request counts; iterate over a copy because keys are deleted.
        for key in list(self.request_counts):
            requests = self.request_counts[key]
            self._prune(requests, cutoff)

            # Remove empty deques
            if not requests:
                del self.request_counts[key]

        # Clean up stale violation history
        for client_id in list(self.violation_counts):
            violations = self.violation_counts[client_id]
            self._prune(violations, now - self.VIOLATION_WINDOW)
            if not violations:
                del self.violation_counts[client_id]

        # Clean up expired blocks
        expired_blocks = [
            client_id
            for client_id, block_until in self.blocked_clients.items()
            if now >= block_until
        ]
        for client_id in expired_blocks:
            del self.blocked_clients[client_id]
|
|
|
|
|
|
class AdaptiveRateLimit:
    """Scale rate limits up or down according to recent system load."""

    def __init__(self):
        """Start with a neutral multiplier and an empty load history."""
        self.base_limits = {}
        self.current_multiplier = 1.0
        # Bounded history: only the 60 most recent load samples are kept.
        self.load_history = deque(maxlen=60)

    def update_system_load(self, cpu_percent: float, memory_percent: float):
        """Record a load sample and recompute the limit multiplier.

        The sample is the mean of ``cpu_percent`` and ``memory_percent``.
        The multiplier is left untouched until at least 10 samples exist.
        """
        self.load_history.append((cpu_percent + memory_percent) / 2)

        sample_count = len(self.load_history)
        if sample_count < 10:
            return

        mean_load = sum(self.load_history) / sample_count

        if mean_load > 80:
            multiplier = 0.5  # heavy load: halve the limits
        elif mean_load > 60:
            multiplier = 0.7  # elevated load: cut limits by 30%
        elif mean_load < 30:
            multiplier = 1.2  # light load: raise limits by 20%
        else:
            multiplier = 1.0  # moderate load: leave limits alone
        self.current_multiplier = multiplier

    def get_adjusted_limit(self, base_limit: int) -> int:
        """Return ``base_limit`` scaled by the current multiplier, at least 1."""
        scaled = int(base_limit * self.current_multiplier)
        return max(1, scaled)
|
|
|
|
|
|
class RateLimitStorage:
    """Async storage interface for rate-limit counters and client blocks.

    Every method here raises ``NotImplementedError``; concrete backends
    are expected to override all of them.
    """

    async def get_count(self, key: str, window: int) -> int:
        """Return how many requests are recorded for ``key`` within ``window``."""
        raise NotImplementedError()

    async def increment(self, key: str, window: int) -> int:
        """Record one more request for ``key`` and return the updated count."""
        raise NotImplementedError()

    async def is_blocked(self, client_id: str) -> bool:
        """Report whether ``client_id`` is currently blocked."""
        raise NotImplementedError()

    async def block_client(self, client_id: str, duration: int):
        """Block ``client_id`` for ``duration`` seconds."""
        raise NotImplementedError()
|
|
|
|
|
|
class RedisRateLimitStorage(RateLimitStorage):
    """Redis-based rate limit storage for production use.

    Each rate-limit key is a sorted set holding one member per request,
    scored by the request's timestamp, so a sliding window is a score-range
    trim followed by a cardinality query. Blocks are stored as separate
    ``blocked:<client_id>`` keys that expire on their own.
    """

    def __init__(self, redis_client):
        """Store the client used for all operations.

        Args:
            redis_client: NOTE(review): assumed to be an asyncio Redis client
                whose ``pipeline()`` buffers commands and whose ``execute()``
                is awaitable - confirm against the caller.
        """
        self.redis = redis_client

    async def get_count(self, key: str, window: int) -> int:
        """Get current request count using Redis sliding window.

        Args:
            key: Sorted-set key holding the request entries.
            window: Window length in seconds.

        Returns:
            Number of entries left after trimming those older than ``window``.
        """
        now = time.time()
        pipeline = self.redis.pipeline()

        # Remove old entries
        pipeline.zremrangebyscore(key, 0, now - window)

        # Count current entries
        pipeline.zcard(key)

        results = await pipeline.execute()
        return results[1]

    async def increment(self, key: str, window: int) -> int:
        """Record one request and return the count within the window.

        Args:
            key: Sorted-set key holding the request entries.
            window: Window length in seconds.

        Returns:
            Number of entries in the window, including the one just added.
        """
        now = time.time()
        # Sorted-set members are unique, so a member made of the timestamp
        # alone would let two requests with the same timestamp overwrite each
        # other and be counted once. A random suffix keeps every request.
        member = f"{now}:{uuid.uuid4().hex}"
        pipeline = self.redis.pipeline()

        # Add current request
        pipeline.zadd(key, {member: now})

        # Remove old entries
        pipeline.zremrangebyscore(key, 0, now - window)

        # Set expiration so idle keys disappear on their own
        pipeline.expire(key, window + 1)

        # Get count
        pipeline.zcard(key)

        results = await pipeline.execute()
        return results[3]

    async def is_blocked(self, client_id: str) -> bool:
        """Check if client is blocked."""
        block_key = f"blocked:{client_id}"
        # Coerce to bool so the return value matches the annotation.
        return bool(await self.redis.exists(block_key))

    async def block_client(self, client_id: str, duration: int):
        """Block client for duration seconds."""
        block_key = f"blocked:{client_id}"
        await self.redis.setex(block_key, duration, "1")
|
|
|
|
|
|
# Global adaptive rate limiter instance, created once at import time.
adaptive_rate_limit = AdaptiveRateLimit()