# File-viewer header: 465 lines, 16 KiB, Python
"""
Rate limiting middleware for WiFi-DensePose API
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Dict, Any, Optional, Callable, Tuple
|
|
from datetime import datetime, timedelta
|
|
from collections import defaultdict, deque
|
|
from dataclasses import dataclass
|
|
|
|
from fastapi import Request, Response, HTTPException, status
|
|
from starlette.types import ASGIApp
|
|
|
|
from src.config.settings import Settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RateLimitInfo:
    """Snapshot of a client's usage within one rate-limit window.

    Instances are built per request and turned into the ``X-RateLimit-*``
    response headers by ``RateLimiter.get_rate_limit_headers``.
    """

    requests: int  # requests counted in the window (including this one when admitted)
    window_start: float  # epoch seconds (time.time() scale) the window is anchored at
    window_size: int  # window length in seconds
    limit: int  # maximum requests per window (float('inf') when limiting is disabled)

    @property
    def remaining(self) -> int:
        """How many more requests fit in the window; never negative."""
        headroom = self.limit - self.requests
        return headroom if headroom > 0 else 0

    @property
    def reset_time(self) -> float:
        """Epoch timestamp at which the current window ends."""
        return self.window_size + self.window_start

    @property
    def is_exceeded(self) -> bool:
        """True once the request count has reached or passed the limit."""
        return not self.requests < self.limit
|
|
|
|
|
|
class TokenBucket:
    """Token bucket algorithm for rate limiting.

    The bucket holds at most ``capacity`` tokens and regains ``refill_rate``
    tokens per second of elapsed ``time.time()``. Each admitted operation
    removes tokens; when too few remain, the operation is rejected.
    """

    def __init__(self, capacity: int, refill_rate: float):
        """Create a bucket that starts full.

        Args:
            capacity: Maximum number of tokens the bucket can hold.
            refill_rate: Tokens added per second of elapsed time.
        """
        self.capacity = capacity
        self.tokens = capacity
        self.refill_rate = refill_rate
        self.last_refill = time.time()
        self._lock = asyncio.Lock()

    async def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` from the bucket.

        Args:
            tokens: Number of tokens to take; must be non-negative.

        Returns:
            True if enough tokens were available (they are removed);
            False otherwise (only the time-based refill is applied).

        Raises:
            ValueError: If ``tokens`` is negative. Subtracting a negative
                amount would add tokens and let the bucket exceed capacity.
        """
        if tokens < 0:
            raise ValueError("tokens must be non-negative")

        async with self._lock:
            now = time.time()

            # time.time() is a wall clock and can step backwards (e.g. NTP
            # corrections). Clamp elapsed time so a backwards jump never
            # drains the bucket or drives it negative.
            elapsed = max(0.0, now - self.last_refill)
            self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
            self.last_refill = now

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True

            return False

    def get_info(self) -> Dict[str, Any]:
        """Return the bucket's current state as a plain dict."""
        return {
            "capacity": self.capacity,
            "tokens": self.tokens,
            "refill_rate": self.refill_rate,
            "last_refill": self.last_refill,
        }
|
|
|
|
|
|
class SlidingWindowCounter:
    """Sliding-window-log rate limiter.

    Keeps the timestamps of admitted requests from the last ``window_size``
    seconds and admits a new request only while fewer than ``limit`` of
    them remain in the window.
    """

    def __init__(self, window_size: int, limit: int):
        """
        Args:
            window_size: Window length in seconds.
            limit: Maximum number of requests admitted per window.
        """
        self.window_size = window_size
        self.limit = limit
        # time.time() stamps of admitted requests, oldest at the left end.
        self.requests = deque()
        self._lock = asyncio.Lock()

    async def is_allowed(self) -> Tuple[bool, RateLimitInfo]:
        """Decide whether a request may proceed, recording it if admitted.

        Returns:
            ``(allowed, info)``. ``info.window_start`` is the timestamp of the
            oldest request still counted (or now, if none), so
            ``info.reset_time`` is the moment the next slot frees up.
        """
        async with self._lock:
            now = time.time()
            horizon = now - self.window_size

            # Drop requests that have slid out of the window.
            while self.requests and self.requests[0] < horizon:
                self.requests.popleft()

            current_requests = len(self.requests)
            allowed = current_requests < self.limit
            if allowed:
                self.requests.append(now)

            # Anchor the window at the oldest counted request so that
            # reset_time = oldest + window_size, i.e. when a slot frees up.
            # Anchoring at ``now - window_size`` would make reset_time equal
            # to now, so X-RateLimit-Reset / Retry-After would be useless.
            oldest = self.requests[0] if self.requests else now

            return allowed, RateLimitInfo(
                requests=current_requests + (1 if allowed else 0),
                window_start=oldest,
                window_size=self.window_size,
                limit=self.limit,
            )
|
|
|
|
|
|
class RateLimiter:
    """In-process rate limiter offering sliding-window and token-bucket checks.

    Limits come from ``Settings``: ``rate_limit_requests`` for anonymous
    clients and ``rate_limit_authenticated_requests`` for requests with
    ``request.state.user`` set, both per ``rate_limit_window`` seconds.
    State is kept in memory; :meth:`start` launches a background task that
    periodically evicts idle entries.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        self.enabled = settings.enable_rate_limiting

        # Rate limit configurations
        self.default_limit = settings.rate_limit_requests
        self.authenticated_limit = settings.rate_limit_authenticated_requests
        self.window_size = settings.rate_limit_window

        # Limiter state. Sliding windows are keyed by client + method + path;
        # token buckets by client only.
        self._sliding_windows: Dict[str, SlidingWindowCounter] = {}
        self._token_buckets: Dict[str, TokenBucket] = {}

        # Background cleanup task and its period in seconds (5 minutes).
        self._cleanup_task: Optional[asyncio.Task] = None
        self._cleanup_interval = 300

    async def start(self):
        """Start the background cleanup task.

        Does nothing when rate limiting is disabled or the task is already
        running, so repeated calls do not spawn duplicate tasks.
        """
        if not self.enabled:
            return
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("Rate limiter started")

    async def stop(self):
        """Cancel the background cleanup task, if any, and wait for it to exit."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            # Drop the reference so start() can launch a fresh task later.
            self._cleanup_task = None
            logger.info("Rate limiter stopped")

    async def _cleanup_loop(self):
        """Call _cleanup_old_data every ``_cleanup_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_old_data()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; a failed pass is retried next interval.
                logger.error(f"Error in rate limiter cleanup: {e}")

    async def _cleanup_old_data(self):
        """Evict limiter state that has been idle for more than two windows."""
        now = time.time()
        cutoff = now - (self.window_size * 2)  # Keep data for 2 windows

        # Sliding windows: trim expired timestamps, then drop empty windows.
        stale_windows = []
        for key, window in self._sliding_windows.items():
            while window.requests and window.requests[0] < cutoff:
                window.requests.popleft()
            if not window.requests:
                stale_windows.append(key)
        for key in stale_windows:
            del self._sliding_windows[key]

        # Token buckets must be evicted too, or _token_buckets grows by one
        # entry per distinct client forever. check_token_bucket creates buckets
        # with capacity = limit and refill_rate = limit / window_size, so a
        # bucket untouched for 2 windows would refill to capacity anyway.
        # Dropping it is equivalent to keeping it: it is recreated full on next use.
        stale_buckets = [
            key
            for key, bucket in self._token_buckets.items()
            if bucket.last_refill < cutoff
        ]
        for key in stale_buckets:
            del self._token_buckets[key]

        logger.debug(
            "Cleaned up %d old rate limit windows and %d idle token buckets",
            len(stale_windows),
            len(stale_buckets),
        )

    def _get_client_identifier(self, request: Request) -> str:
        """Return ``user:<username>`` for authenticated requests, else ``ip:<address>``."""
        # NOTE(review): assumes request.state.user is dict-like (``.get`` is
        # called on it); confirm against the auth middleware that sets it.
        user = getattr(request.state, "user", None)
        if user:
            return f"user:{user.get('username', 'unknown')}"

        # Fall back to IP address
        client_ip = self._get_client_ip(request)
        return f"ip:{client_ip}"

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP: X-Forwarded-For (first entry), then X-Real-IP, then the peer address."""
        # NOTE(security): these headers are client-controlled unless a trusted
        # reverse proxy overwrites them. Without such a proxy, a client can send
        # a different X-Forwarded-For on each request and get a fresh rate-limit
        # key every time. Consider honouring them only from trusted proxy IPs.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        # Fall back to direct connection
        return request.client.host if request.client else "unknown"

    def _get_rate_limit(self, request: Request) -> int:
        """Return the authenticated limit if ``request.state.user`` is set, else the default."""
        user = getattr(request.state, "user", None)
        if user:
            return self.authenticated_limit

        return self.default_limit

    def _get_rate_limit_key(self, request: Request) -> str:
        """Return the sliding-window key: ``<client id>:<METHOD>:<path>``."""
        client_id = self._get_client_identifier(request)
        endpoint = f"{request.method}:{request.url.path}"
        return f"{client_id}:{endpoint}"

    async def check_rate_limit(self, request: Request) -> Tuple[bool, RateLimitInfo]:
        """Check the request against its sliding window, recording it if admitted.

        Returns:
            ``(allowed, info)``. When rate limiting is disabled, always allowed,
            with an ``info`` whose limit is ``float('inf')``.
        """
        if not self.enabled:
            # Return dummy info when rate limiting is disabled
            return True, RateLimitInfo(
                requests=0,
                window_start=time.time(),
                window_size=self.window_size,
                limit=float('inf'),
            )

        key = self._get_rate_limit_key(request)
        limit = self._get_rate_limit(request)

        # Get or create sliding window counter
        if key not in self._sliding_windows:
            self._sliding_windows[key] = SlidingWindowCounter(self.window_size, limit)

        window = self._sliding_windows[key]

        # Update limit if it changed (e.g., user authenticated)
        window.limit = limit

        return await window.is_allowed()

    async def check_token_bucket(self, request: Request, tokens: int = 1) -> bool:
        """Try to take ``tokens`` from this client's token bucket (always True when disabled)."""
        if not self.enabled:
            return True

        key = self._get_client_identifier(request)
        limit = self._get_rate_limit(request)

        # Get or create token bucket
        if key not in self._token_buckets:
            # Refill rate: limit per window size
            refill_rate = limit / self.window_size
            self._token_buckets[key] = TokenBucket(limit, refill_rate)

        bucket = self._token_buckets[key]
        return await bucket.consume(tokens)

    def get_rate_limit_headers(self, rate_limit_info: RateLimitInfo) -> Dict[str, str]:
        """Build the ``X-RateLimit-*`` response headers from ``rate_limit_info``."""
        return {
            "X-RateLimit-Limit": str(rate_limit_info.limit),
            "X-RateLimit-Remaining": str(rate_limit_info.remaining),
            "X-RateLimit-Reset": str(int(rate_limit_info.reset_time)),
            "X-RateLimit-Window": str(rate_limit_info.window_size),
        }

    async def get_stats(self) -> Dict[str, Any]:
        """Return configuration and current state-table sizes."""
        return {
            "enabled": self.enabled,
            "default_limit": self.default_limit,
            "authenticated_limit": self.authenticated_limit,
            "window_size": self.window_size,
            "active_windows": len(self._sliding_windows),
            "active_buckets": len(self._token_buckets),
        }
|
|
|
|
|
|
class RateLimitMiddleware:
    """Rate limiting middleware for FastAPI, used as an ``http`` middleware callable."""

    # Requests whose path starts with any of these prefixes bypass rate limiting.
    _SKIP_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/static",
    )

    def __init__(self, settings: Settings):
        self.settings = settings
        self.rate_limiter = RateLimiter(settings)
        self.enabled = settings.enable_rate_limiting

    async def __call__(self, request: Request, call_next: Callable) -> Response:
        """Check the request against the limiter, then dispatch it.

        Returns a 429 JSON response (``{"detail": "Rate limit exceeded"}``)
        carrying ``X-RateLimit-*`` and ``Retry-After`` headers when the client
        is over its limit. Otherwise returns the downstream response with the
        ``X-RateLimit-*`` headers added. If the limiter itself errors, the
        request is passed through without rate limiting.
        """
        if not self.enabled or self._should_skip_rate_limit(request):
            return await call_next(request)

        # Only the limiter work is guarded. call_next stays outside the try;
        # otherwise an exception raised by the endpoint would be mistaken for
        # a limiter error and the request would be dispatched a second time.
        try:
            allowed, rate_limit_info = await self.rate_limiter.check_rate_limit(request)
            headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
            if not allowed:
                return self._rate_limited_response(request, rate_limit_info, headers)
        except Exception as e:
            logger.error(f"Rate limiting middleware error: {e}")
            # Continue without rate limiting on error
            return await call_next(request)

        response = await call_next(request)

        # Add rate limit headers to response
        for key, value in headers.items():
            response.headers[key] = value

        return response

    def _rate_limited_response(
        self, request: Request, rate_limit_info, headers: Dict[str, str]
    ) -> Response:
        """Log the violation and build the 429 response.

        Args:
            request: The rejected request (used for logging only).
            rate_limit_info: The ``RateLimitInfo`` returned by the limiter.
            headers: The ``X-RateLimit-*`` headers; ``Retry-After`` is added in place.
        """
        import json
        import math

        logger.warning(
            f"Rate limit exceeded for {self.rate_limiter._get_client_identifier(request)} "
            f"on {request.method} {request.url.path}"
        )

        # Whole seconds until the window frees a slot. Round up and send at
        # least 1, so a still-blocked client is never told to retry immediately.
        seconds_left = rate_limit_info.reset_time - time.time()
        headers["Retry-After"] = str(max(1, math.ceil(seconds_left)))

        # Respond directly instead of raising HTTPException. Exceptions raised
        # from an ``@app.middleware("http")`` function are not handled by the
        # app's exception handlers and would surface as a 500, not a 429.
        return Response(
            content=json.dumps({"detail": "Rate limit exceeded"}),
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            headers=headers,
            media_type="application/json",
        )

    def _should_skip_rate_limit(self, request: Request) -> bool:
        """Return True for infrastructure/documentation paths that are never limited."""
        return request.url.path.startswith(self._SKIP_PATH_PREFIXES)

    async def start(self):
        """Start the underlying rate limiter's background tasks."""
        await self.rate_limiter.start()

    async def stop(self):
        """Stop the underlying rate limiter's background tasks."""
        await self.rate_limiter.stop()
|
|
|
|
|
|
# Process-wide middleware singleton, created lazily by get_rate_limit_middleware().
_rate_limit_middleware: Optional[RateLimitMiddleware] = None


def get_rate_limit_middleware(settings: Settings) -> RateLimitMiddleware:
    """Return the shared RateLimitMiddleware, creating it on first use.

    ``settings`` only matters on the first call. Later calls return the
    already-built instance regardless of the argument.
    """
    global _rate_limit_middleware
    if _rate_limit_middleware is not None:
        return _rate_limit_middleware
    _rate_limit_middleware = RateLimitMiddleware(settings)
    return _rate_limit_middleware
|
|
|
|
|
|
def setup_rate_limiting(app: ASGIApp, settings: Settings) -> ASGIApp:
    """Register the rate-limiting ``http`` middleware on ``app`` when enabled.

    ``app`` must provide the Starlette/FastAPI ``middleware("http")``
    decorator. The same ``app`` object is returned in every case.
    """
    if not settings.enable_rate_limiting:
        logger.info("Rate limiting disabled")
        return app

    logger.info("Setting up rate limiting middleware")

    middleware = get_rate_limit_middleware(settings)

    async def rate_limit_middleware(request: Request, call_next):
        # Delegate every HTTP request to the shared middleware instance.
        return await middleware(request, call_next)

    app.middleware("http")(rate_limit_middleware)

    # NOTE(review): the limiter's cleanup task only runs after
    # RateLimitMiddleware.start() is awaited; confirm the app does this on startup.
    window = settings.rate_limit_window
    logger.info(
        f"Rate limiting enabled - Default: {settings.rate_limit_requests}/"
        f"{window}s, Authenticated: "
        f"{settings.rate_limit_authenticated_requests}/{window}s"
    )
    return app
|
|
|
|
|
|
class RateLimitConfig:
    """Preset rate-limiting settings for common deployment profiles.

    Every method returns a new dict whose keys match the ``Settings`` fields
    this module reads (``enable_rate_limiting``, ``rate_limit_requests``,
    ``rate_limit_authenticated_requests``, ``rate_limit_window``).
    """

    @staticmethod
    def _preset(enabled: bool, anonymous: int, authenticated: int, window: int) -> dict:
        # Single place defining key names and their order for every preset.
        return dict(
            enable_rate_limiting=enabled,
            rate_limit_requests=anonymous,
            rate_limit_authenticated_requests=authenticated,
            rate_limit_window=window,
        )

    @staticmethod
    def development_config() -> dict:
        """Development: limiting switched off (1000 anonymous / 5000 authenticated per hour)."""
        return RateLimitConfig._preset(False, 1000, 5000, 3600)

    @staticmethod
    def production_config() -> dict:
        """Production: 100 anonymous / 1000 authenticated requests per hour."""
        return RateLimitConfig._preset(True, 100, 1000, 3600)

    @staticmethod
    def api_config() -> dict:
        """API access: 60 anonymous / 300 authenticated requests per minute."""
        return RateLimitConfig._preset(True, 60, 300, 60)

    @staticmethod
    def strict_config() -> dict:
        """Strict: 10 anonymous / 100 authenticated requests per minute."""
        return RateLimitConfig._preset(True, 10, 100, 60)
|
|
|
|
|
|
def validate_rate_limit_config(settings: Settings) -> list:
    """Return a list of human-readable problems with the rate-limit settings.

    The checks only run when ``settings.enable_rate_limiting`` is truthy; a
    disabled configuration always yields an empty list. Messages appear in
    the order the checks are listed below.
    """
    if not settings.enable_rate_limiting:
        return []

    anonymous = settings.rate_limit_requests
    authenticated = settings.rate_limit_authenticated_requests
    checks = [
        (anonymous <= 0, "Rate limit requests must be positive"),
        (authenticated <= 0, "Authenticated rate limit requests must be positive"),
        (settings.rate_limit_window <= 0, "Rate limit window must be positive"),
        (authenticated < anonymous, "Authenticated rate limit should be higher than default rate limit"),
    ]
    return [message for failed, message in checks if failed]