Files
wifi-densepose/v1/tests/integration/test_rate_limiting.py
Claude 6ed69a3d48 feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
2026-01-13 03:11:16 +00:00

565 lines
22 KiB
Python

"""
Integration tests for rate limiting functionality.
Tests rate limit behavior, throttling, and quota management.
"""
import asyncio
import math
import time
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import HTTPException, Request, Response, status
class MockRateLimiter:
    """In-memory rolling-window rate limiter used as a test double.

    Request timestamps are tracked per client key and checked against two
    independent limits: requests per rolling minute and requests per rolling
    hour. Clients can additionally be blocked outright, regardless of counts.
    """

    def __init__(self, requests_per_minute: int = 60, requests_per_hour: int = 1000):
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        # client key -> timezone-aware UTC timestamps of accepted requests
        self.request_history: Dict[str, List[datetime]] = {}
        # Blocked client ids; blocking applies to every endpoint of that client.
        self.blocked_clients = set()

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Aware datetimes are used so ``.timestamp()`` yields a correct Unix
        epoch regardless of the host's local timezone.
        """
        return datetime.now(timezone.utc)

    def _get_client_key(self, client_id: str, endpoint: Optional[str] = None) -> str:
        """Return the history key: ``client_id`` or ``client_id:endpoint``."""
        return f"{client_id}:{endpoint}" if endpoint else client_id

    def _cleanup_old_requests(self, client_key: str) -> None:
        """Discard timestamps older than one hour for ``client_key``.

        A key whose history becomes empty is removed entirely so idle clients
        do not accumulate in memory; ``check_rate_limit`` recreates it on demand.
        """
        history = self.request_history.get(client_key)
        if history is None:
            return
        hour_ago = self._now() - timedelta(hours=1)
        still_recent = [req_time for req_time in history if req_time > hour_ago]
        if still_recent:
            self.request_history[client_key] = still_recent
        else:
            del self.request_history[client_key]

    def check_rate_limit(self, client_id: str, endpoint: Optional[str] = None) -> Dict[str, Any]:
        """Check a request for ``client_id`` (optionally per ``endpoint``) and record it if allowed.

        Returns:
            A dict with ``allowed``. Rejections carry ``reason`` and
            ``retry_after`` (seconds); limit rejections also carry
            ``current_requests`` and ``limit``. Accepted requests carry
            ``remaining_minute``, ``remaining_hour`` and ``reset_time`` -- the
            aware UTC datetime at which the oldest request in the current
            minute window expires and a slot frees up.
        """
        if client_id in self.blocked_clients:
            return {
                "allowed": False,
                "reason": "Client blocked",
                "retry_after": 3600,  # 1 hour
            }

        client_key = self._get_client_key(client_id, endpoint)
        self._cleanup_old_requests(client_key)
        history = self.request_history.setdefault(client_key, [])

        now = self._now()
        minute_ago = now - timedelta(minutes=1)
        recent_requests = [req_time for req_time in history if req_time > minute_ago]
        # After cleanup every remaining entry lies within the last hour.
        hour_requests = len(history)

        if len(recent_requests) >= self.requests_per_minute:
            return {
                "allowed": False,
                "reason": "Rate limit exceeded (per minute)",
                "retry_after": 60,
                "current_requests": len(recent_requests),
                "limit": self.requests_per_minute,
            }

        if hour_requests >= self.requests_per_hour:
            return {
                "allowed": False,
                "reason": "Rate limit exceeded (per hour)",
                "retry_after": 3600,
                "current_requests": hour_requests,
                "limit": self.requests_per_hour,
            }

        history.append(now)
        # The rolling minute window frees a slot when its oldest request ages out;
        # if this is the only request in the window, that is this request.
        window_start = min(recent_requests, default=now)
        return {
            "allowed": True,
            "remaining_minute": self.requests_per_minute - len(recent_requests) - 1,
            "remaining_hour": self.requests_per_hour - hour_requests - 1,
            "reset_time": window_start + timedelta(minutes=1),
        }

    def block_client(self, client_id: str) -> None:
        """Block ``client_id`` on every endpoint until unblocked."""
        self.blocked_clients.add(client_id)

    def unblock_client(self, client_id: str) -> None:
        """Remove ``client_id`` from the block list (no-op if not blocked)."""
        self.blocked_clients.discard(client_id)
class TestRateLimitingBasic:
    """Per-minute, per-hour, blocking and per-endpoint behavior of MockRateLimiter."""

    @pytest.fixture
    def rate_limiter(self):
        """Limiter with a small per-minute quota (5) so limits are hit quickly."""
        return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)

    def test_rate_limit_within_bounds_should_fail_initially(self, rate_limiter):
        """Requests under the quota are allowed and decrement both remaining counters."""
        client_id = "test-client-001"

        for sent in range(1, 4):
            result = rate_limiter.check_rate_limit(client_id)
            assert result["allowed"] is True
            assert result["remaining_minute"] == 5 - sent
            assert result["remaining_hour"] == 100 - sent

    def test_rate_limit_per_minute_exceeded_should_fail_initially(self, rate_limiter):
        """The request after the per-minute quota is rejected with a 60 s retry."""
        client_id = "test-client-002"

        for _ in range(5):
            result = rate_limiter.check_rate_limit(client_id)
            assert result["allowed"] is True

        result = rate_limiter.check_rate_limit(client_id)
        assert result["allowed"] is False
        assert "per minute" in result["reason"]
        assert result["retry_after"] == 60
        assert result["current_requests"] == 5
        assert result["limit"] == 5

    def test_rate_limit_per_hour_exceeded_should_fail_initially(self, rate_limiter):
        """The request after the per-hour quota is rejected with a 1 h retry.

        Uses a dedicated limiter whose hourly quota (3) is below its
        per-minute quota, so the hourly limit is the one that trips.
        """
        limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=3)
        client_id = "test-client-003"

        for _ in range(3):
            result = limiter.check_rate_limit(client_id)
            assert result["allowed"] is True

        result = limiter.check_rate_limit(client_id)
        assert result["allowed"] is False
        assert "per hour" in result["reason"]
        assert result["retry_after"] == 3600
        assert result["current_requests"] == 3
        assert result["limit"] == 3

    def test_blocked_client_should_fail_initially(self, rate_limiter):
        """A blocked client is rejected on every endpoint and allowed again once unblocked."""
        client_id = "blocked-client"

        rate_limiter.block_client(client_id)

        result = rate_limiter.check_rate_limit(client_id)
        assert result["allowed"] is False
        assert result["reason"] == "Client blocked"
        assert result["retry_after"] == 3600
        # Blocking is keyed on the client id, so endpoint-scoped checks are blocked too.
        assert rate_limiter.check_rate_limit(client_id, "/api/pose/current")["allowed"] is False

        rate_limiter.unblock_client(client_id)
        result = rate_limiter.check_rate_limit(client_id)
        assert result["allowed"] is True

    def test_endpoint_specific_rate_limiting_should_fail_initially(self, rate_limiter):
        """Each endpoint has its own quota for the same client."""
        client_id = "test-client-004"
        pose_endpoint = "/api/pose/current"
        stream_endpoint = "/api/stream/status"

        assert rate_limiter.check_rate_limit(client_id, pose_endpoint)["allowed"] is True
        assert rate_limiter.check_rate_limit(client_id, stream_endpoint)["allowed"] is True

        # Four more pose requests bring that endpoint to its quota of 5.
        for _ in range(4):
            rate_limiter.check_rate_limit(client_id, pose_endpoint)

        pose_result = rate_limiter.check_rate_limit(client_id, pose_endpoint)
        stream_result = rate_limiter.check_rate_limit(client_id, stream_endpoint)
        assert pose_result["allowed"] is False
        assert stream_result["allowed"] is True
class TestRateLimitMiddleware:
    """Tests for a minimal rate-limiting HTTP middleware built on MockRateLimiter."""

    @pytest.fixture
    def rate_limiter(self):
        """Limiter backing the middleware (5 requests/minute, 100/hour).

        Fixtures defined on another test class are not visible here, so this
        class needs its own ``rate_limiter`` for ``rate_limit_middleware``.
        """
        return MockRateLimiter(requests_per_minute=5, requests_per_hour=100)

    @pytest.fixture
    def mock_request(self):
        """Factory for lightweight stand-ins for a FastAPI ``Request``."""

        class MockRequest:
            def __init__(self, client_ip="127.0.0.1", path="/api/test", method="GET"):
                self.client = MagicMock()
                self.client.host = client_ip
                self.url = MagicMock()
                self.url.path = path
                self.method = method
                self.headers = {}
                # A plain namespace rather than MagicMock: MagicMock fabricates any
                # attribute, which would make every request look authenticated.
                self.state = SimpleNamespace()

        return MockRequest

    @pytest.fixture
    def mock_response(self):
        """A minimal successful response object with a mutable header dict."""

        class MockResponse:
            def __init__(self):
                self.status_code = 200
                self.headers = {}

        return MockResponse()

    @pytest.fixture
    def rate_limit_middleware(self, rate_limiter):
        """Middleware that rejects over-limit requests with 429 and adds X-RateLimit-* headers."""

        class RateLimitMiddleware:
            def __init__(self, rate_limiter):
                self.rate_limiter = rate_limiter

            async def __call__(self, request, call_next):
                client_id = self._get_client_id(request)
                endpoint = request.url.path

                limit_result = self.rate_limiter.check_rate_limit(client_id, endpoint)

                if not limit_result["allowed"]:
                    response = Response(
                        content=f"Rate limit exceeded: {limit_result['reason']}",
                        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    )
                    response.headers["Retry-After"] = str(limit_result["retry_after"])
                    response.headers["X-RateLimit-Limit"] = str(limit_result.get("limit", "unknown"))
                    response.headers["X-RateLimit-Remaining"] = "0"
                    return response

                response = await call_next(request)

                response.headers["X-RateLimit-Limit"] = str(self.rate_limiter.requests_per_minute)
                response.headers["X-RateLimit-Remaining"] = str(limit_result.get("remaining_minute", 0))
                response.headers["X-RateLimit-Reset"] = str(
                    self._epoch_seconds(limit_result.get("reset_time"))
                )
                return response

            @staticmethod
            def _epoch_seconds(reset_time):
                """Convert a reset datetime to whole Unix seconds.

                Naive datetimes are treated as UTC; calling ``.timestamp()`` on
                them directly would interpret them as local time.
                """
                if reset_time is None:
                    reset_time = datetime.now(timezone.utc)
                elif reset_time.tzinfo is None:
                    reset_time = reset_time.replace(tzinfo=timezone.utc)
                return int(reset_time.timestamp())

            def _get_client_id(self, request):
                """Identify the caller: API key, then authenticated user id, then client IP."""
                api_key = request.headers.get("X-API-Key")
                if api_key:
                    return f"api:{api_key}"

                user = getattr(request.state, "user", None)
                if user:
                    return f"user:{user.get('id', 'unknown')}"

                return f"ip:{request.client.host}"

        return RateLimitMiddleware(rate_limiter)

    @pytest.mark.asyncio
    async def test_middleware_allows_normal_requests_should_fail_initially(
        self, rate_limit_middleware, mock_request, mock_response
    ):
        """A request under the limit passes through and gets rate-limit headers."""
        request = mock_request()

        async def mock_call_next(req):
            return mock_response

        response = await rate_limit_middleware(request, mock_call_next)

        assert response.status_code == 200
        assert response.headers["X-RateLimit-Limit"] == "5"
        assert response.headers["X-RateLimit-Remaining"] == "4"
        # Reset must be a real Unix epoch close to now (within the 1-minute window).
        assert abs(int(response.headers["X-RateLimit-Reset"]) - time.time()) <= 120

    @pytest.mark.asyncio
    async def test_middleware_blocks_excessive_requests_should_fail_initially(
        self, rate_limit_middleware, mock_request
    ):
        """Once the per-minute quota is used up, the middleware answers 429."""
        request = mock_request()

        async def mock_call_next(req):
            return Response(content="OK", status_code=200)

        for _ in range(5):
            response = await rate_limit_middleware(request, mock_call_next)
            assert response.status_code == 200

        response = await rate_limit_middleware(request, mock_call_next)
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
        assert "Retry-After" in response.headers
        assert response.headers["X-RateLimit-Remaining"] == "0"

    @pytest.mark.asyncio
    async def test_middleware_client_identification_should_fail_initially(
        self, rate_limit_middleware, mock_request
    ):
        """API key, authenticated user and bare IP each map to a distinct client id."""
        request_with_api_key = mock_request()
        request_with_api_key.headers["X-API-Key"] = "test-api-key-123"

        request_with_user = mock_request()
        request_with_user.state.user = {"id": "user-123"}

        request_with_ip = mock_request(client_ip="192.168.1.100")

        assert rate_limit_middleware._get_client_id(request_with_api_key) == "api:test-api-key-123"
        assert rate_limit_middleware._get_client_id(request_with_user) == "user:user-123"
        assert rate_limit_middleware._get_client_id(request_with_ip) == "ip:192.168.1.100"

        async def mock_call_next(req):
            return Response(content="OK", status_code=200)

        for request in (request_with_api_key, request_with_user, request_with_ip):
            response = await rate_limit_middleware(request, mock_call_next)
            assert response.status_code == 200
class TestRateLimitingStrategies:
    """Sliding-window and token-bucket limiter strategies."""

    @pytest.fixture
    def sliding_window_limiter(self):
        """Sliding-window limiter allowing 3 requests in any 10-second window."""

        class SlidingWindowLimiter:
            def __init__(self, window_size_seconds: int = 60, max_requests: int = 10):
                self.window_size = window_size_seconds
                self.max_requests = max_requests
                # client id -> time.time() stamps of accepted requests
                self.request_times = {}

            def check_limit(self, client_id: str) -> Dict[str, Any]:
                now = time.time()
                cutoff_time = now - self.window_size
                # Keep only requests that are still inside the window.
                window = [t for t in self.request_times.get(client_id, []) if t > cutoff_time]
                self.request_times[client_id] = window

                if len(window) >= self.max_requests:
                    # A slot frees up when the oldest request leaves the window.
                    # Round up so a client honoring retry_after never retries early.
                    retry_after = math.ceil(min(window) + self.window_size - now)
                    return {
                        "allowed": False,
                        "retry_after": max(retry_after, 1),
                        "current_requests": len(window),
                        "limit": self.max_requests,
                    }

                window.append(now)
                return {
                    "allowed": True,
                    "remaining": self.max_requests - len(window),
                    "window_reset": int(now + self.window_size),
                }

        return SlidingWindowLimiter(window_size_seconds=10, max_requests=3)

    @pytest.fixture
    def token_bucket_limiter(self):
        """Token-bucket limiter: capacity 5, refilled at 0.5 tokens per second."""

        class TokenBucketLimiter:
            def __init__(self, capacity: int = 10, refill_rate: float = 1.0):
                self.capacity = capacity
                self.refill_rate = refill_rate  # tokens per second
                self.buckets = {}

            def check_limit(self, client_id: str) -> Dict[str, Any]:
                now = time.time()
                bucket = self.buckets.setdefault(
                    client_id, {"tokens": self.capacity, "last_refill": now}
                )

                # Refill in proportion to the time since the last check, capped at capacity.
                elapsed = now - bucket["last_refill"]
                bucket["tokens"] = min(self.capacity, bucket["tokens"] + elapsed * self.refill_rate)
                bucket["last_refill"] = now

                if bucket["tokens"] < 1:
                    return {
                        "allowed": False,
                        # Seconds until one whole token is available, rounded up.
                        "retry_after": math.ceil((1 - bucket["tokens"]) / self.refill_rate),
                        "tokens_remaining": bucket["tokens"],
                    }

                bucket["tokens"] -= 1
                return {"allowed": True, "tokens_remaining": bucket["tokens"]}

        return TokenBucketLimiter(capacity=5, refill_rate=0.5)

    def test_sliding_window_limiter_should_fail_initially(self, sliding_window_limiter):
        """The request after max_requests inside the window is rejected."""
        client_id = "sliding-test-client"

        for _ in range(3):
            result = sliding_window_limiter.check_limit(client_id)
            assert result["allowed"] is True
            assert "remaining" in result

        result = sliding_window_limiter.check_limit(client_id)
        assert result["allowed"] is False
        assert result["current_requests"] == 3
        assert result["limit"] == 3
        assert 1 <= result["retry_after"] <= 10

    def test_token_bucket_limiter_should_fail_initially(self, token_bucket_limiter):
        """Requests beyond the bucket capacity are rejected."""
        client_id = "bucket-test-client"

        for _ in range(5):
            result = token_bucket_limiter.check_limit(client_id)
            assert result["allowed"] is True
            assert "tokens_remaining" in result

        result = token_bucket_limiter.check_limit(client_id)
        assert result["allowed"] is False
        assert result["tokens_remaining"] < 1
        assert result["retry_after"] >= 1

    @pytest.mark.asyncio
    async def test_token_bucket_refill_should_fail_initially(self, token_bucket_limiter):
        """An exhausted bucket admits requests again once enough time has passed."""
        client_id = "refill-test-client"

        for _ in range(5):
            token_bucket_limiter.check_limit(client_id)

        result = token_bucket_limiter.check_limit(client_id)
        assert result["allowed"] is False

        # Advance the clock instead of sleeping: 2.1 s at 0.5 tokens/s refills > 1 token.
        later = time.time() + 2.1
        with patch.object(time, "time", return_value=later):
            result = token_bucket_limiter.check_limit(client_id)

        assert result["allowed"] is True
class TestRateLimitingPerformance:
    """Concurrency and memory-retention behavior of MockRateLimiter."""

    @pytest.mark.asyncio
    async def test_concurrent_rate_limit_checks_should_fail_initially(self):
        """Many checks scheduled together for distinct clients are all allowed."""
        rate_limiter = MockRateLimiter(requests_per_minute=100, requests_per_hour=1000)

        async def make_request(client_id: str, request_id: int):
            result = rate_limiter.check_rate_limit(f"{client_id}-{request_id}")
            return result["allowed"]

        tasks = [make_request("concurrent-client", i) for i in range(50)]
        results = await asyncio.gather(*tasks)

        assert len(results) == 50
        assert all(results)  # every task uses a different client id

    @pytest.mark.asyncio
    async def test_rate_limiter_memory_cleanup_should_fail_initially(self):
        """Cleanup discards request history once it is older than the hourly window."""
        rate_limiter = MockRateLimiter(requests_per_minute=10, requests_per_hour=100)

        for i in range(100):
            rate_limiter.check_rate_limit(f"client-{i}")

        initial_memory_size = len(rate_limiter.request_history)
        assert initial_memory_size == 100

        # Simulate two hours passing by back-dating every recorded timestamp
        # (reusing the stored values keeps their naive/aware kind unchanged).
        for client_key in list(rate_limiter.request_history):
            rate_limiter.request_history[client_key] = [
                req_time - timedelta(hours=2)
                for req_time in rate_limiter.request_history[client_key]
            ]

        for client_key in list(rate_limiter.request_history):
            rate_limiter._cleanup_old_requests(client_key)

        live_histories = [
            history for history in rate_limiter.request_history.values() if history
        ]
        assert live_histories == []
        assert len(rate_limiter.request_history) <= initial_memory_size