updates
This commit is contained in:
465
src/middleware/rate_limit.py
Normal file
465
src/middleware/rate_limit.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Rate limiting middleware for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, Optional, Callable, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""Rate limit information."""
|
||||
requests: int
|
||||
window_start: float
|
||||
window_size: int
|
||||
limit: int
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
return max(0, self.limit - self.requests)
|
||||
|
||||
@property
|
||||
def reset_time(self) -> float:
|
||||
"""Get time when window resets."""
|
||||
return self.window_start + self.window_size
|
||||
|
||||
@property
|
||||
def is_exceeded(self) -> bool:
|
||||
"""Check if rate limit is exceeded."""
|
||||
return self.requests >= self.limit
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
"""Token bucket algorithm for rate limiting."""
|
||||
|
||||
def __init__(self, capacity: int, refill_rate: float):
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.refill_rate = refill_rate
|
||||
self.last_refill = time.time()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def consume(self, tokens: int = 1) -> bool:
|
||||
"""Try to consume tokens from bucket."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
|
||||
# Refill tokens based on time elapsed
|
||||
time_passed = now - self.last_refill
|
||||
tokens_to_add = time_passed * self.refill_rate
|
||||
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
|
||||
self.last_refill = now
|
||||
|
||||
# Check if we have enough tokens
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_info(self) -> Dict[str, Any]:
|
||||
"""Get bucket information."""
|
||||
return {
|
||||
"capacity": self.capacity,
|
||||
"tokens": self.tokens,
|
||||
"refill_rate": self.refill_rate,
|
||||
"last_refill": self.last_refill
|
||||
}
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window counter for rate limiting."""
|
||||
|
||||
def __init__(self, window_size: int, limit: int):
|
||||
self.window_size = window_size
|
||||
self.limit = limit
|
||||
self.requests = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self) -> Tuple[bool, RateLimitInfo]:
|
||||
"""Check if request is allowed."""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
window_start = now - self.window_size
|
||||
|
||||
# Remove old requests outside the window
|
||||
while self.requests and self.requests[0] < window_start:
|
||||
self.requests.popleft()
|
||||
|
||||
# Check if limit is exceeded
|
||||
current_requests = len(self.requests)
|
||||
allowed = current_requests < self.limit
|
||||
|
||||
if allowed:
|
||||
self.requests.append(now)
|
||||
|
||||
rate_limit_info = RateLimitInfo(
|
||||
requests=current_requests + (1 if allowed else 0),
|
||||
window_start=window_start,
|
||||
window_size=self.window_size,
|
||||
limit=self.limit
|
||||
)
|
||||
|
||||
return allowed, rate_limit_info
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter with multiple algorithms."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.enabled = settings.enable_rate_limiting
|
||||
|
||||
# Rate limit configurations
|
||||
self.default_limit = settings.rate_limit_requests
|
||||
self.authenticated_limit = settings.rate_limit_authenticated_requests
|
||||
self.window_size = settings.rate_limit_window
|
||||
|
||||
# Storage for rate limit data
|
||||
self._sliding_windows: Dict[str, SlidingWindowCounter] = {}
|
||||
self._token_buckets: Dict[str, TokenBucket] = {}
|
||||
|
||||
# Cleanup task
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._cleanup_interval = 300 # 5 minutes
|
||||
|
||||
async def start(self):
|
||||
"""Start rate limiter background tasks."""
|
||||
if self.enabled:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("Rate limiter started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop rate limiter background tasks."""
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Rate limiter stopped")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background task to cleanup old rate limit data."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
await self._cleanup_old_data()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in rate limiter cleanup: {e}")
|
||||
|
||||
async def _cleanup_old_data(self):
|
||||
"""Remove old rate limit data."""
|
||||
now = time.time()
|
||||
cutoff = now - (self.window_size * 2) # Keep data for 2 windows
|
||||
|
||||
# Cleanup sliding windows
|
||||
keys_to_remove = []
|
||||
for key, window in self._sliding_windows.items():
|
||||
# Remove old requests
|
||||
while window.requests and window.requests[0] < cutoff:
|
||||
window.requests.popleft()
|
||||
|
||||
# Remove empty windows
|
||||
if not window.requests:
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del self._sliding_windows[key]
|
||||
|
||||
logger.debug(f"Cleaned up {len(keys_to_remove)} old rate limit windows")
|
||||
|
||||
def _get_client_identifier(self, request: Request) -> str:
|
||||
"""Get client identifier for rate limiting."""
|
||||
# Try to get user ID from authenticated request
|
||||
user = getattr(request.state, "user", None)
|
||||
if user:
|
||||
return f"user:{user.get('username', 'unknown')}"
|
||||
|
||||
# Fall back to IP address
|
||||
client_ip = self._get_client_ip(request)
|
||||
return f"ip:{client_ip}"
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get client IP address."""
|
||||
# Check for forwarded headers
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct connection
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _get_rate_limit(self, request: Request) -> int:
|
||||
"""Get rate limit for request."""
|
||||
# Check if user is authenticated
|
||||
user = getattr(request.state, "user", None)
|
||||
if user:
|
||||
return self.authenticated_limit
|
||||
|
||||
return self.default_limit
|
||||
|
||||
def _get_rate_limit_key(self, request: Request) -> str:
|
||||
"""Get rate limit key for request."""
|
||||
client_id = self._get_client_identifier(request)
|
||||
endpoint = f"{request.method}:{request.url.path}"
|
||||
return f"{client_id}:{endpoint}"
|
||||
|
||||
async def check_rate_limit(self, request: Request) -> Tuple[bool, RateLimitInfo]:
|
||||
"""Check if request is within rate limits."""
|
||||
if not self.enabled:
|
||||
# Return dummy info when rate limiting is disabled
|
||||
return True, RateLimitInfo(
|
||||
requests=0,
|
||||
window_start=time.time(),
|
||||
window_size=self.window_size,
|
||||
limit=float('inf')
|
||||
)
|
||||
|
||||
key = self._get_rate_limit_key(request)
|
||||
limit = self._get_rate_limit(request)
|
||||
|
||||
# Get or create sliding window counter
|
||||
if key not in self._sliding_windows:
|
||||
self._sliding_windows[key] = SlidingWindowCounter(self.window_size, limit)
|
||||
|
||||
window = self._sliding_windows[key]
|
||||
|
||||
# Update limit if it changed (e.g., user authenticated)
|
||||
window.limit = limit
|
||||
|
||||
return await window.is_allowed()
|
||||
|
||||
async def check_token_bucket(self, request: Request, tokens: int = 1) -> bool:
|
||||
"""Check rate limit using token bucket algorithm."""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
key = self._get_client_identifier(request)
|
||||
limit = self._get_rate_limit(request)
|
||||
|
||||
# Get or create token bucket
|
||||
if key not in self._token_buckets:
|
||||
# Refill rate: limit per window size
|
||||
refill_rate = limit / self.window_size
|
||||
self._token_buckets[key] = TokenBucket(limit, refill_rate)
|
||||
|
||||
bucket = self._token_buckets[key]
|
||||
return await bucket.consume(tokens)
|
||||
|
||||
def get_rate_limit_headers(self, rate_limit_info: RateLimitInfo) -> Dict[str, str]:
|
||||
"""Get rate limit headers for response."""
|
||||
return {
|
||||
"X-RateLimit-Limit": str(rate_limit_info.limit),
|
||||
"X-RateLimit-Remaining": str(rate_limit_info.remaining),
|
||||
"X-RateLimit-Reset": str(int(rate_limit_info.reset_time)),
|
||||
"X-RateLimit-Window": str(rate_limit_info.window_size),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get rate limiter statistics."""
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"default_limit": self.default_limit,
|
||||
"authenticated_limit": self.authenticated_limit,
|
||||
"window_size": self.window_size,
|
||||
"active_windows": len(self._sliding_windows),
|
||||
"active_buckets": len(self._token_buckets),
|
||||
}
|
||||
|
||||
|
||||
class RateLimitMiddleware:
|
||||
"""Rate limiting middleware for FastAPI."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.rate_limiter = RateLimiter(settings)
|
||||
self.enabled = settings.enable_rate_limiting
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request through rate limiting middleware."""
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip rate limiting for certain paths
|
||||
if self._should_skip_rate_limit(request):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Check rate limit
|
||||
allowed, rate_limit_info = await self.rate_limiter.check_rate_limit(request)
|
||||
|
||||
if not allowed:
|
||||
# Rate limit exceeded
|
||||
logger.warning(
|
||||
f"Rate limit exceeded for {self.rate_limiter._get_client_identifier(request)} "
|
||||
f"on {request.method} {request.url.path}"
|
||||
)
|
||||
|
||||
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
|
||||
headers["Retry-After"] = str(int(rate_limit_info.reset_time - time.time()))
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded",
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers to response
|
||||
headers = self.rate_limiter.get_rate_limit_headers(rate_limit_info)
|
||||
for key, value in headers.items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limiting middleware error: {e}")
|
||||
# Continue without rate limiting on error
|
||||
return await call_next(request)
|
||||
|
||||
def _should_skip_rate_limit(self, request: Request) -> bool:
|
||||
"""Check if rate limiting should be skipped for this request."""
|
||||
path = request.url.path
|
||||
|
||||
# Skip rate limiting for these paths
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/static",
|
||||
]
|
||||
|
||||
return any(path.startswith(skip_path) for skip_path in skip_paths)
|
||||
|
||||
async def start(self):
|
||||
"""Start rate limiting middleware."""
|
||||
await self.rate_limiter.start()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop rate limiting middleware."""
|
||||
await self.rate_limiter.stop()
|
||||
|
||||
|
||||
# Global rate limit middleware instance
|
||||
_rate_limit_middleware: Optional[RateLimitMiddleware] = None
|
||||
|
||||
|
||||
def get_rate_limit_middleware(settings: Settings) -> RateLimitMiddleware:
|
||||
"""Get rate limit middleware instance."""
|
||||
global _rate_limit_middleware
|
||||
if _rate_limit_middleware is None:
|
||||
_rate_limit_middleware = RateLimitMiddleware(settings)
|
||||
return _rate_limit_middleware
|
||||
|
||||
|
||||
def setup_rate_limiting(app: ASGIApp, settings: Settings) -> ASGIApp:
|
||||
"""Setup rate limiting middleware for the application."""
|
||||
if settings.enable_rate_limiting:
|
||||
logger.info("Setting up rate limiting middleware")
|
||||
|
||||
middleware = get_rate_limit_middleware(settings)
|
||||
|
||||
# Add middleware to app
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
return await middleware(request, call_next)
|
||||
|
||||
logger.info(
|
||||
f"Rate limiting enabled - Default: {settings.rate_limit_requests}/"
|
||||
f"{settings.rate_limit_window}s, Authenticated: "
|
||||
f"{settings.rate_limit_authenticated_requests}/{settings.rate_limit_window}s"
|
||||
)
|
||||
else:
|
||||
logger.info("Rate limiting disabled")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class RateLimitConfig:
|
||||
"""Rate limiting configuration helper."""
|
||||
|
||||
@staticmethod
|
||||
def development_config() -> dict:
|
||||
"""Get rate limiting configuration for development."""
|
||||
return {
|
||||
"enable_rate_limiting": False, # Disabled in development
|
||||
"rate_limit_requests": 1000,
|
||||
"rate_limit_authenticated_requests": 5000,
|
||||
"rate_limit_window": 3600, # 1 hour
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def production_config() -> dict:
|
||||
"""Get rate limiting configuration for production."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 100, # 100 requests per hour for unauthenticated
|
||||
"rate_limit_authenticated_requests": 1000, # 1000 requests per hour for authenticated
|
||||
"rate_limit_window": 3600, # 1 hour
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def api_config() -> dict:
|
||||
"""Get rate limiting configuration for API access."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 60, # 60 requests per minute
|
||||
"rate_limit_authenticated_requests": 300, # 300 requests per minute
|
||||
"rate_limit_window": 60, # 1 minute
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def strict_config() -> dict:
|
||||
"""Get strict rate limiting configuration."""
|
||||
return {
|
||||
"enable_rate_limiting": True,
|
||||
"rate_limit_requests": 10, # 10 requests per minute
|
||||
"rate_limit_authenticated_requests": 100, # 100 requests per minute
|
||||
"rate_limit_window": 60, # 1 minute
|
||||
}
|
||||
|
||||
|
||||
def validate_rate_limit_config(settings: Settings) -> list:
|
||||
"""Validate rate limiting configuration."""
|
||||
issues = []
|
||||
|
||||
if settings.enable_rate_limiting:
|
||||
if settings.rate_limit_requests <= 0:
|
||||
issues.append("Rate limit requests must be positive")
|
||||
|
||||
if settings.rate_limit_authenticated_requests <= 0:
|
||||
issues.append("Authenticated rate limit requests must be positive")
|
||||
|
||||
if settings.rate_limit_window <= 0:
|
||||
issues.append("Rate limit window must be positive")
|
||||
|
||||
if settings.rate_limit_authenticated_requests < settings.rate_limit_requests:
|
||||
issues.append("Authenticated rate limit should be higher than default rate limit")
|
||||
|
||||
return issues
|
||||
Reference in New Issue
Block a user