Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
  - Signal processing: 45 tests
  - Neural networks: 21 tests
  - Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
483 lines
17 KiB
Python
Executable File
483 lines
17 KiB
Python
Executable File
#!/usr/bin/env python3
"""
Test script for authentication and rate limiting functionality.

Sends live HTTP requests to the API server at BASE_URL and writes a JSON report.
"""

import asyncio
import json
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import httpx
import jwt

# Configuration
# Server root (without API prefix). Override with TEST_BASE_URL to target
# another host without editing this file.
BASE_URL = os.environ.get("TEST_BASE_URL", "http://localhost:8000")
API_PREFIX = "/api/v1"

# Test credentials
# NOTE(review): not referenced by any check in this script; kept in case other
# code imports it.
TEST_USERS = {
    "admin": {"username": "admin", "password": "admin123"},
    "user": {"username": "user", "password": "user123"},
}

# JWT settings for testing
# Must equal the server's signing key, otherwise every locally generated token
# is rejected. Override with TEST_SECRET_KEY rather than editing the default.
SECRET_KEY = os.environ.get("TEST_SECRET_KEY", "your-secret-key-here")
JWT_ALGORITHM = "HS256"


class AuthRateLimitTester:
    """Live HTTP checks for authentication, rate limiting, CORS and security headers.

    Every ``test_*`` method talks to the server at ``base_url`` and records its
    outcome(s) through :meth:`log_result`. :meth:`run_all_tests` runs them all,
    prints a summary, writes a JSON report and closes the HTTP clients.
    """

    def __init__(self, base_url: str = BASE_URL):
        """Create the HTTP clients.

        Args:
            base_url: Server root, without the API prefix.
        """
        self.base_url = base_url
        # Sync client for the sequential checks; async client for the request
        # burst in test_rate_limiting(). Both are closed by run_all_tests().
        self.client = httpx.Client(timeout=30.0)
        self.async_client = httpx.AsyncClient(timeout=30.0)
        self.results: List[Dict[str, Any]] = []

    @staticmethod
    def _safe_json(response: "httpx.Response") -> Any:
        """Decode a response body for logging without raising.

        Returns None for an empty body, the decoded JSON when the body is JSON,
        and the raw text otherwise, so a 200 with an HTML/plain-text body is
        not misreported as a failed request.
        """
        if not response.content:
            return None
        try:
            return response.json()
        except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
            return response.text

    def log_result(self, test_name: str, success: bool, message: str,
                   details: Optional[Dict[str, Any]] = None):
        """Record one test outcome and print a one-line summary.

        Args:
            test_name: Human-readable name of the check.
            success: Whether the check passed.
            message: Short status text printed after the name.
            details: Extra context. Always stored in ``self.results``; printed
                as JSON only when the check failed.
        """
        result = {
            "test": test_name,
            "success": success,
            "message": message,
            "timestamp": datetime.now().isoformat(),
            "details": details or {},
        }
        self.results.append(result)

        status = "✓" if success else "✗"
        print(f"{status} {test_name}: {message}")
        if details and not success:
            print(f" Details: {json.dumps(details, indent=2)}")

    def generate_test_token(self, username: str, expired: bool = False) -> str:
        """Return a JWT for *username* signed with SECRET_KEY / JWT_ALGORITHM.

        Args:
            username: Used for the ``sub`` and ``username`` claims and to build
                the email. ``"admin"`` gets the ``admin`` role, anything else
                gets ``user``.
            expired: If True, ``exp`` is one hour in the past so the server
                should reject the token.

        Returns:
            The encoded token as ``str``.
        """
        # Aware UTC time: datetime.utcnow() is deprecated since Python 3.12.
        now = datetime.now(timezone.utc)
        payload = {
            "sub": username,
            "username": username,
            "email": f"{username}@example.com",
            "roles": ["admin"] if username == "admin" else ["user"],
            "iat": now,
            "exp": now + (timedelta(hours=-1) if expired else timedelta(hours=24)),
        }
        token = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
        # PyJWT < 2.0 returns bytes; without decoding, "Bearer {token}" would
        # become "Bearer b'...'".
        return token.decode("utf-8") if isinstance(token, bytes) else token

    def test_public_endpoints(self):
        """Check that public endpoints answer 200/204 without credentials."""
        print("\n=== Testing Public Endpoints ===")

        public_endpoints = [
            "/",
            "/health",
            f"{API_PREFIX}/info",
            f"{API_PREFIX}/status",
            f"{API_PREFIX}/pose/current",
        ]

        for endpoint in public_endpoints:
            name = f"Public endpoint {endpoint}"
            try:
                response = self.client.get(f"{self.base_url}{endpoint}")
                self.log_result(
                    name,
                    response.status_code in (200, 204),
                    f"Status: {response.status_code}",
                    {"response": self._safe_json(response)},
                )
            except Exception as e:
                # Best effort: record the failure and keep checking the rest.
                self.log_result(name, False, str(e))

    def test_protected_endpoints(self):
        """Check that protected endpoints reject unauthenticated calls with 401."""
        print("\n=== Testing Protected Endpoints (No Auth) ===")

        protected_endpoints = [
            (f"{API_PREFIX}/pose/analyze", "POST"),
            (f"{API_PREFIX}/pose/calibrate", "POST"),
            (f"{API_PREFIX}/stream/start", "POST"),
            (f"{API_PREFIX}/stream/stop", "POST"),
        ]
        expected_status = 401

        for endpoint, method in protected_endpoints:
            name = f"Protected endpoint {endpoint} without auth"
            url = f"{self.base_url}{endpoint}"
            try:
                if method == "GET":
                    response = self.client.get(url)
                else:
                    response = self.client.post(url, json={})
                self.log_result(
                    name,
                    response.status_code == expected_status,
                    f"Status: {response.status_code} (expected {expected_status})",
                    {"response": self._safe_json(response)},
                )
            except Exception as e:
                self.log_result(name, False, str(e))

    def test_authentication_headers(self):
        """POST to the analyze endpoint with a range of Authorization headers.

        Only the valid token should get past authentication. It is accepted
        with 200 or 422 (presumably request validation rejecting the empty
        ``{}`` body). Every other variant must get 401.
        """
        print("\n=== Testing Authentication Headers ===")

        endpoint = f"{self.base_url}{API_PREFIX}/pose/analyze"
        try:
            expired_token = self.generate_test_token("user", expired=True)
            valid_token = self.generate_test_token("user")
        except Exception as e:
            # Without tokens the cases below are meaningless; record the
            # problem instead of aborting the whole suite.
            self.log_result("Auth header test: token generation", False, str(e))
            return

        test_cases = [
            ("No header", {}),
            ("Invalid format", {"Authorization": "InvalidFormat"}),
            ("Wrong scheme", {"Authorization": "Basic dGVzdDp0ZXN0"}),
            ("Invalid token", {"Authorization": "Bearer invalid.token.here"}),
            ("Expired token", {"Authorization": f"Bearer {expired_token}"}),
            ("Valid token", {"Authorization": f"Bearer {valid_token}"}),
        ]

        for test_name, headers in test_cases:
            name = f"Auth header test: {test_name}"
            try:
                response = self.client.post(endpoint, headers=headers, json={})
                if test_name == "Valid token":
                    passed = response.status_code in (200, 422)  # 422 for validation errors
                else:
                    passed = response.status_code == 401
                self.log_result(
                    name,
                    passed,
                    f"Status: {response.status_code}",
                    {"headers": headers},
                )
            except Exception as e:
                self.log_result(name, False, str(e))

    async def test_rate_limiting(self):
        """Send bursts past each endpoint's expected per-minute limit; expect 429s.

        The limits in ``test_configs`` (60/min, 10/min) are what this script
        assumes the server enforces; they are not read from the server.
        """
        print("\n=== Testing Rate Limiting ===")

        test_configs = [
            {
                "endpoint": f"{API_PREFIX}/pose/current",
                "method": "GET",
                "requests": 70,  # More than 60/min limit
                "window": 60,
                "description": "Current pose endpoint (60/min)",
            },
            {
                "endpoint": f"{API_PREFIX}/pose/analyze",
                "method": "POST",
                "requests": 15,  # More than 10/min limit
                "window": 60,
                "description": "Analyze endpoint (10/min)",
                "auth": True,
            },
        ]

        for config in test_configs:
            print(f"\nTesting: {config['description']}")

            headers = {}
            if config.get("auth"):
                headers["Authorization"] = f"Bearer {self.generate_test_token('user')}"

            url = f"{self.base_url}{config['endpoint']}"
            responses = []
            # Monotonic clock: the elapsed time must not jump with wall-clock changes.
            start_time = time.monotonic()

            for i in range(config["requests"]):
                try:
                    if config["method"] == "GET":
                        response = await self.async_client.get(url, headers=headers)
                    else:
                        response = await self.async_client.post(url, headers=headers, json={})

                    responses.append({
                        "request": i + 1,
                        "status": response.status_code,
                        "headers": dict(response.headers),
                    })

                    # Progress output only when the server exposes rate-limit headers.
                    if "X-RateLimit-Limit" in response.headers and i % 10 == 0:
                        remaining = response.headers.get("X-RateLimit-Remaining", "N/A")
                        print(f" Request {i+1}: Status {response.status_code}, Remaining: {remaining}")

                    # Small delay to avoid overwhelming the server
                    await asyncio.sleep(0.1)
                except Exception as e:
                    responses.append({"request": i + 1, "error": str(e)})

            elapsed = time.monotonic() - start_time

            rate_limited = sum(1 for r in responses if r.get("status") == 429)
            successful = sum(1 for r in responses if r.get("status") in (200, 204))
            errors = sum(1 for r in responses if "error" in r)

            self.log_result(
                f"Rate limit test: {config['description']}",
                rate_limited > 0,  # at least one request must be throttled
                f"Sent {config['requests']} requests in {elapsed:.1f}s. "
                f"Successful: {successful}, Rate limited: {rate_limited}",
                {
                    "total_requests": config["requests"],
                    "successful": successful,
                    "rate_limited": rate_limited,
                    "errors": errors,
                    "elapsed_time": f"{elapsed:.1f}s",
                },
            )

    def test_rate_limit_headers(self):
        """Check for X-RateLimit-* headers and, if present, a 429 with Retry-After.

        Sends up to 100 extra GET requests to provoke throttling. Never seeing
        a 429 is recorded as a failure rather than silently skipped.
        """
        print("\n=== Testing Rate Limit Headers ===")

        endpoint = f"{self.base_url}{API_PREFIX}/pose/current"
        expected_headers = [
            "X-RateLimit-Limit",
            "X-RateLimit-Remaining",
            "X-RateLimit-Window",
        ]
        max_attempts = 100

        try:
            response = self.client.get(endpoint)
            found_headers = {
                h: response.headers[h] for h in expected_headers if h in response.headers
            }
            self.log_result(
                "Rate limit headers",
                len(found_headers) > 0,
                f"Found {len(found_headers)} rate limit headers",
                {"headers": found_headers},
            )

            if not found_headers:
                # No rate limiting advertised; nothing to provoke.
                return

            for _ in range(max_attempts):
                r = self.client.get(endpoint)
                if r.status_code == 429:
                    retry_after = r.headers.get("Retry-After")
                    self.log_result(
                        "Rate limit 429 response",
                        retry_after is not None,
                        f"Got 429 with Retry-After: {retry_after}",
                        {"headers": dict(r.headers)},
                    )
                    break
            else:
                self.log_result(
                    "Rate limit 429 response",
                    False,
                    f"No 429 received after {max_attempts} requests",
                )
        except Exception as e:
            self.log_result("Rate limit headers", False, str(e))

    def test_cors_headers(self):
        """Send simple and preflight requests to /health from several origins.

        A simple request passes when the response carries any
        ``Access-Control-*`` header; a preflight passes on 200/204.
        """
        print("\n=== Testing CORS Headers ===")

        # NOTE(review): "http://malicious.site" is held to the same expectation
        # as the other origins, so a server that allows every origin passes.
        # Confirm whether this origin should instead be rejected.
        test_origins = [
            "http://localhost:3000",
            "https://example.com",
            "http://malicious.site",
        ]
        endpoint = f"{self.base_url}/health"

        for origin in test_origins:
            try:
                # Regular request with Origin header
                response = self.client.get(endpoint, headers={"Origin": origin})
                cors_headers = {
                    k: v for k, v in response.headers.items()
                    if k.lower().startswith("access-control-")
                }
                self.log_result(
                    f"CORS headers for origin {origin}",
                    len(cors_headers) > 0,
                    f"Found {len(cors_headers)} CORS headers",
                    {"headers": cors_headers},
                )

                # Preflight request
                preflight_response = self.client.options(
                    endpoint,
                    headers={
                        "Origin": origin,
                        "Access-Control-Request-Method": "POST",
                        "Access-Control-Request-Headers": "Content-Type,Authorization",
                    },
                )
                self.log_result(
                    f"CORS preflight for origin {origin}",
                    preflight_response.status_code in (200, 204),
                    f"Status: {preflight_response.status_code}",
                    {"headers": dict(preflight_response.headers)},
                )
            except Exception as e:
                self.log_result(f"CORS test for origin {origin}", False, str(e))

    def test_security_headers(self):
        """Require at least 3 of 5 common security headers on /health."""
        print("\n=== Testing Security Headers ===")

        endpoint = f"{self.base_url}/health"
        security_headers = [
            "X-Content-Type-Options",
            "X-Frame-Options",
            "X-XSS-Protection",
            "Referrer-Policy",
            "Content-Security-Policy",
        ]
        min_required = 3

        try:
            response = self.client.get(endpoint)
            found_headers = {
                h: response.headers[h] for h in security_headers if h in response.headers
            }
            self.log_result(
                "Security headers",
                len(found_headers) >= min_required,
                f"Found {len(found_headers)}/{len(security_headers)} security headers",
                {"headers": found_headers},
            )
        except Exception as e:
            self.log_result("Security headers", False, str(e))

    def test_authentication_states(self):
        """Report the server's authentication / rate-limiting feature flags.

        Reads ``features.authentication`` and ``features.rate_limiting`` from
        the info endpoint (each defaults to False when absent). A non-200
        answer is logged as a failure.
        """
        print("\n=== Testing Authentication States ===")

        try:
            info_response = self.client.get(f"{self.base_url}{API_PREFIX}/info")
            if info_response.status_code != 200:
                self.log_result(
                    "Feature flags",
                    False,
                    f"Status: {info_response.status_code}",
                    {"response": self._safe_json(info_response)},
                )
                return

            features = info_response.json().get("features", {})
            auth_enabled = features.get("authentication", False)
            rate_limit_enabled = features.get("rate_limiting", False)
            self.log_result(
                "Feature flags",
                True,
                f"Authentication: {auth_enabled}, Rate Limiting: {rate_limit_enabled}",
                {
                    "authentication": auth_enabled,
                    "rate_limiting": rate_limit_enabled,
                },
            )
        except Exception as e:
            self.log_result("Feature flags", False, str(e))

    async def run_all_tests(self):
        """Run every check, print a summary, save a JSON report and close clients.

        Returns:
            True if every recorded result passed (also True when nothing was
            recorded), False otherwise.
        """
        print("=" * 60)
        print("WiFi-DensePose Authentication & Rate Limiting Test Suite")
        print("=" * 60)

        try:
            # Synchronous checks. test_rate_limit_headers() hits /pose/current
            # up to 100 times, so it may already use up that endpoint's budget
            # before test_rate_limiting() runs.
            self.test_public_endpoints()
            self.test_protected_endpoints()
            self.test_authentication_headers()
            self.test_rate_limit_headers()
            self.test_cors_headers()
            self.test_security_headers()
            self.test_authentication_states()

            # Async burst test
            await self.test_rate_limiting()
        finally:
            # Release connections even if a check raised unexpectedly.
            await self.async_client.aclose()
            self.client.close()

        print("\n" + "=" * 60)
        print("Test Summary")
        print("=" * 60)

        total = len(self.results)
        passed = sum(1 for r in self.results if r["success"])
        failed = total - passed

        print(f"Total tests: {total}")
        print(f"Passed: {passed}")
        print(f"Failed: {failed}")
        print(f"Success rate: {passed / total * 100:.1f}%" if total > 0 else "Success rate: N/A")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"auth_rate_limit_test_results_{timestamp}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump({
                "test_run": {
                    "timestamp": datetime.now().isoformat(),
                    "base_url": self.base_url,
                    "total_tests": total,
                    "passed": passed,
                    "failed": failed,
                },
                "results": self.results,
            }, f, indent=2)

        print(f"\nResults saved to: {filename}")

        return passed == total


async def main():
    """Run the whole suite and terminate the process with its result.

    Exit status is 0 when every recorded check passed and 1 otherwise.
    """
    all_passed = await AuthRateLimitTester().run_all_tests()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    # main() finishes with sys.exit, whose SystemExit propagates out of asyncio.run.
    entry = main()
    asyncio.run(entry)