Files
wifi-densepose/v1/test_auth_rate_limit.py
Claude 6ed69a3d48 feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
2026-01-13 03:11:16 +00:00

483 lines
17 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Test script for authentication and rate limiting functionality
"""
import asyncio
import json
import sys
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import httpx
import jwt
# Where the API under test is served and the path prefix of its versioned routes.
BASE_URL = "http://localhost:8000"
API_PREFIX = "/api/v1"

# Credentials for the two test accounts, keyed by role name.
TEST_USERS = {
    "admin": dict(username="admin", password="admin123"),
    "user": dict(username="user", password="user123"),
}

# Signing parameters for the JWTs this script generates.
# SECRET_KEY should match the settings of the server under test.
SECRET_KEY = "your-secret-key-here"
JWT_ALGORITHM = "HS256"
class AuthRateLimitTester:
    """Probe a running API for authentication, rate limiting, CORS and security headers.

    Each check records its outcome through ``log_result``. ``run_all_tests``
    runs every check, prints a summary, writes the collected results to a
    timestamped JSON file and closes the HTTP clients.
    """

    def __init__(self, base_url: str = BASE_URL):
        """Create the HTTP clients and an empty result list.

        Args:
            base_url: Scheme, host and port of the API under test.
        """
        self.base_url = base_url
        self.client = httpx.Client(timeout=30.0)
        self.async_client = httpx.AsyncClient(timeout=30.0)
        self.results: List[Dict[str, Any]] = []

    @staticmethod
    def _response_body(response) -> Any:
        """Return the response body for logging without raising on non-JSON.

        Returns ``None`` for an empty body, the decoded JSON when the body
        parses, and the raw text otherwise. A non-JSON body (for example an
        HTML landing page) must not turn a correct status code into a failure.
        """
        if not response.content:
            return None
        try:
            return response.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            return response.text

    def log_result(
        self,
        test_name: str,
        success: bool,
        message: str,
        details: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record one test outcome and echo it to stdout.

        Args:
            test_name: Label identifying the check.
            success: Whether the check passed.
            message: Short human-readable outcome.
            details: Optional extra data; printed only when the check failed.
        """
        result = {
            "test": test_name,
            "success": success,
            "message": message,
            "timestamp": datetime.now().isoformat(),
            "details": details or {},
        }
        self.results.append(result)

        # Distinct markers so passes and failures can be told apart in the console.
        status = "[PASS]" if success else "[FAIL]"
        print(f"{status} {test_name}: {message}")
        if details and not success:
            print(f" Details: {json.dumps(details, indent=2)}")

    def generate_test_token(self, username: str, expired: bool = False) -> str:
        """Generate a JWT for ``username`` signed with the module's test secret.

        Args:
            username: Subject of the token; ``"admin"`` gets the admin role,
                any other name gets the user role.
            expired: When true, the token's expiry lies one hour in the past;
                otherwise it is valid for 24 hours.

        Returns:
            The encoded token as a string.
        """
        # Timezone-aware "now"; datetime.utcnow() is deprecated and naive.
        now = datetime.now(timezone.utc)
        lifetime = timedelta(hours=-1) if expired else timedelta(hours=24)
        payload = {
            "sub": username,
            "username": username,
            "email": f"{username}@example.com",
            "roles": ["admin"] if username == "admin" else ["user"],
            "iat": now,
            "exp": now + lifetime,
        }
        token = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
        # PyJWT 1.x returns bytes; normalise so the Bearer header is well-formed.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        return token

    def test_public_endpoints(self):
        """Test access to public endpoints without authentication (expects 200/204)."""
        print("\n=== Testing Public Endpoints ===")
        public_endpoints = [
            "/",
            "/health",
            f"{API_PREFIX}/info",
            f"{API_PREFIX}/status",
            f"{API_PREFIX}/pose/current",
        ]
        for endpoint in public_endpoints:
            try:
                response = self.client.get(f"{self.base_url}{endpoint}")
                self.log_result(
                    f"Public endpoint {endpoint}",
                    response.status_code in [200, 204],
                    f"Status: {response.status_code}",
                    {"response": self._response_body(response)},
                )
            except Exception as e:
                # Connection errors and the like count as a failed check.
                self.log_result(f"Public endpoint {endpoint}", False, str(e))

    def test_protected_endpoints(self):
        """Test that protected endpoints reject unauthenticated requests with 401."""
        print("\n=== Testing Protected Endpoints (No Auth) ===")
        protected_endpoints = [
            (f"{API_PREFIX}/pose/analyze", "POST"),
            (f"{API_PREFIX}/pose/calibrate", "POST"),
            (f"{API_PREFIX}/stream/start", "POST"),
            (f"{API_PREFIX}/stream/stop", "POST"),
        ]
        expected_status = 401  # Unauthorized
        for endpoint, method in protected_endpoints:
            # Same label on the success and the exception path.
            test_name = f"Protected endpoint {endpoint} without auth"
            try:
                if method == "GET":
                    response = self.client.get(f"{self.base_url}{endpoint}")
                else:
                    response = self.client.post(f"{self.base_url}{endpoint}", json={})
                self.log_result(
                    test_name,
                    response.status_code == expected_status,
                    f"Status: {response.status_code} (expected {expected_status})",
                    {"response": self._response_body(response)},
                )
            except Exception as e:
                self.log_result(test_name, False, str(e))

    def test_authentication_headers(self):
        """Test different authentication header formats against a protected endpoint."""
        print("\n=== Testing Authentication Headers ===")
        endpoint = f"{self.base_url}{API_PREFIX}/pose/analyze"
        test_cases = [
            ("No header", {}),
            ("Invalid format", {"Authorization": "InvalidFormat"}),
            ("Wrong scheme", {"Authorization": "Basic dGVzdDp0ZXN0"}),
            ("Invalid token", {"Authorization": "Bearer invalid.token.here"}),
            ("Expired token", {"Authorization": f"Bearer {self.generate_test_token('user', expired=True)}"}),
            ("Valid token", {"Authorization": f"Bearer {self.generate_test_token('user')}"}),
        ]
        for test_name, headers in test_cases:
            try:
                response = self.client.post(endpoint, headers=headers, json={})
                if test_name == "Valid token":
                    # The empty JSON body may be rejected by validation (422);
                    # that still proves authentication succeeded.
                    expected = response.status_code in [200, 422]
                else:
                    expected = response.status_code == 401
                self.log_result(
                    f"Auth header test: {test_name}",
                    expected,
                    f"Status: {response.status_code}",
                    {"headers": headers},
                )
            except Exception as e:
                self.log_result(f"Auth header test: {test_name}", False, str(e))

    async def test_rate_limiting(self):
        """Send more requests than each endpoint's limit and expect some 429s."""
        print("\n=== Testing Rate Limiting ===")
        # Endpoints with different rate limits
        test_configs = [
            {
                "endpoint": f"{API_PREFIX}/pose/current",
                "method": "GET",
                "requests": 70,  # More than 60/min limit
                "window": 60,
                "description": "Current pose endpoint (60/min)",
            },
            {
                "endpoint": f"{API_PREFIX}/pose/analyze",
                "method": "POST",
                "requests": 15,  # More than 10/min limit
                "window": 60,
                "description": "Analyze endpoint (10/min)",
                "auth": True,
            },
        ]
        for config in test_configs:
            print(f"\nTesting: {config['description']}")
            url = f"{self.base_url}{config['endpoint']}"

            headers = {}
            if config.get("auth"):
                headers["Authorization"] = f"Bearer {self.generate_test_token('user')}"

            responses = []
            # Monotonic clock: elapsed time is unaffected by wall-clock changes.
            start_time = time.monotonic()
            for i in range(config["requests"]):
                try:
                    if config["method"] == "GET":
                        response = await self.async_client.get(url, headers=headers)
                    else:
                        response = await self.async_client.post(url, headers=headers, json={})
                    responses.append({
                        "request": i + 1,
                        "status": response.status_code,
                        "headers": dict(response.headers),
                    })
                    if "X-RateLimit-Limit" in response.headers:
                        remaining = response.headers.get("X-RateLimit-Remaining", "N/A")
                        if i % 10 == 0:  # Print every 10th request
                            print(f" Request {i+1}: Status {response.status_code}, Remaining: {remaining}")
                    # Small delay to avoid overwhelming the server
                    await asyncio.sleep(0.1)
                except Exception as e:
                    # Keep going; a failed request is recorded without a status.
                    responses.append({"request": i + 1, "error": str(e)})
            elapsed = time.monotonic() - start_time

            rate_limited = sum(1 for r in responses if r.get("status") == 429)
            successful = sum(1 for r in responses if r.get("status") in [200, 204])
            self.log_result(
                f"Rate limit test: {config['description']}",
                rate_limited > 0,  # Should have some rate limited requests
                f"Sent {config['requests']} requests in {elapsed:.1f}s. "
                f"Successful: {successful}, Rate limited: {rate_limited}",
                {
                    "total_requests": config["requests"],
                    "successful": successful,
                    "rate_limited": rate_limited,
                    "elapsed_time": f"{elapsed:.1f}s",
                },
            )

    def test_rate_limit_headers(self):
        """Test rate limit response headers and the Retry-After header on 429."""
        print("\n=== Testing Rate Limit Headers ===")
        endpoint = f"{self.base_url}{API_PREFIX}/pose/current"
        max_probe_requests = 100
        try:
            response = self.client.get(endpoint)
            expected_headers = [
                "X-RateLimit-Limit",
                "X-RateLimit-Remaining",
                "X-RateLimit-Window",
            ]
            found_headers = {
                h: response.headers.get(h)
                for h in expected_headers
                if h in response.headers
            }
            self.log_result(
                "Rate limit headers",
                len(found_headers) > 0,
                f"Found {len(found_headers)} rate limit headers",
                {"headers": found_headers},
            )

            # Only probe for a 429 when the server advertises rate limiting.
            if found_headers:
                for _ in range(max_probe_requests):
                    r = self.client.get(endpoint)
                    if r.status_code == 429:
                        retry_after = r.headers.get("Retry-After")
                        self.log_result(
                            "Rate limit 429 response",
                            retry_after is not None,
                            f"Got 429 with Retry-After: {retry_after}",
                            {"headers": dict(r.headers)},
                        )
                        break
                else:
                    # Informational only: the limit may exceed the probe count,
                    # so this is not recorded as a failure.
                    print(
                        f" No 429 received after {max_probe_requests} requests; "
                        "Retry-After check skipped"
                    )
        except Exception as e:
            self.log_result("Rate limit headers", False, str(e))

    def test_cors_headers(self):
        """Test CORS headers on a plain request and on a preflight request."""
        print("\n=== Testing CORS Headers ===")
        test_origins = [
            "http://localhost:3000",
            "https://example.com",
            "http://malicious.site",
        ]
        endpoint = f"{self.base_url}/health"
        for origin in test_origins:
            try:
                # Regular request with Origin header
                response = self.client.get(endpoint, headers={"Origin": origin})
                cors_headers = {
                    k: v
                    for k, v in response.headers.items()
                    if k.lower().startswith("access-control-")
                }
                self.log_result(
                    f"CORS headers for origin {origin}",
                    len(cors_headers) > 0,
                    f"Found {len(cors_headers)} CORS headers",
                    {"headers": cors_headers},
                )

                # Preflight request
                preflight_response = self.client.options(
                    endpoint,
                    headers={
                        "Origin": origin,
                        "Access-Control-Request-Method": "POST",
                        "Access-Control-Request-Headers": "Content-Type,Authorization",
                    },
                )
                self.log_result(
                    f"CORS preflight for origin {origin}",
                    preflight_response.status_code in [200, 204],
                    f"Status: {preflight_response.status_code}",
                    {"headers": dict(preflight_response.headers)},
                )
            except Exception as e:
                self.log_result(f"CORS test for origin {origin}", False, str(e))

    def test_security_headers(self):
        """Test that the health endpoint returns at least 3 of 5 security headers."""
        print("\n=== Testing Security Headers ===")
        endpoint = f"{self.base_url}/health"
        try:
            response = self.client.get(endpoint)
            security_headers = [
                "X-Content-Type-Options",
                "X-Frame-Options",
                "X-XSS-Protection",
                "Referrer-Policy",
                "Content-Security-Policy",
            ]
            found_headers = {
                h: response.headers.get(h)
                for h in security_headers
                if h in response.headers
            }
            self.log_result(
                "Security headers",
                len(found_headers) >= 3,  # At least 3 security headers
                f"Found {len(found_headers)}/{len(security_headers)} security headers",
                {"headers": found_headers},
            )
        except Exception as e:
            self.log_result("Security headers", False, str(e))

    def test_authentication_states(self):
        """Report the authentication / rate limiting feature flags from the info endpoint."""
        print("\n=== Testing Authentication States ===")
        try:
            info_response = self.client.get(f"{self.base_url}{API_PREFIX}/info")
            if info_response.status_code == 200:
                features = info_response.json().get("features", {})
                auth_enabled = features.get("authentication", False)
                rate_limit_enabled = features.get("rate_limiting", False)
                self.log_result(
                    "Feature flags",
                    True,
                    f"Authentication: {auth_enabled}, Rate Limiting: {rate_limit_enabled}",
                    {
                        "authentication": auth_enabled,
                        "rate_limiting": rate_limit_enabled,
                    },
                )
            else:
                # Record the failure instead of silently dropping the check.
                self.log_result(
                    "Feature flags",
                    False,
                    f"Status: {info_response.status_code}",
                )
        except Exception as e:
            self.log_result("Feature flags", False, str(e))

    async def run_all_tests(self):
        """Run all tests, print a summary and save the results as JSON.

        Returns:
            True when every recorded check passed, False otherwise.
        """
        print("=" * 60)
        print("WiFi-DensePose Authentication & Rate Limiting Test Suite")
        print("=" * 60)
        try:
            # Synchronous tests
            self.test_public_endpoints()
            self.test_protected_endpoints()
            self.test_authentication_headers()
            self.test_rate_limit_headers()
            self.test_cors_headers()
            self.test_security_headers()
            self.test_authentication_states()
            # Async tests
            await self.test_rate_limiting()

            # Summary
            print("\n" + "=" * 60)
            print("Test Summary")
            print("=" * 60)
            total = len(self.results)
            passed = sum(1 for r in self.results if r["success"])
            failed = total - passed
            success_rate = f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
            print(f"Total tests: {total}")
            print(f"Passed: {passed}")
            print(f"Failed: {failed}")
            print(f"Success rate: {success_rate}")

            # Save results
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"auth_rate_limit_test_results_{timestamp}.json"
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(
                    {
                        "test_run": {
                            "timestamp": datetime.now().isoformat(),
                            "base_url": self.base_url,
                            "total_tests": total,
                            "passed": passed,
                            "failed": failed,
                        },
                        "results": self.results,
                    },
                    f,
                    indent=2,
                )
            print(f"\nResults saved to: {filename}")
        finally:
            # Always release the HTTP clients, even if a test or the file write raised.
            await self.async_client.aclose()
            self.client.close()
        return passed == total
async def main():
    """Run the full suite and exit with 0 if everything passed, 1 otherwise."""
    all_passed = await AuthRateLimitTester().run_all_tests()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)


# Script entry point: drive the async main() to completion.
if __name__ == "__main__":
    asyncio.run(main())