feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
"""
FastAPI middleware package
"""
from .auth import AuthMiddleware
from .rate_limit import RateLimitMiddleware
__all__ = ["AuthMiddleware", "RateLimitMiddleware"]

View File

@@ -0,0 +1,322 @@
"""
JWT Authentication middleware for WiFi-DensePose API
"""
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from jose import JWTError, jwt
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
class AuthMiddleware(BaseHTTPMiddleware):
"""JWT Authentication middleware."""
def __init__(self, app):
super().__init__(app)
self.settings = get_settings()
# Paths that don't require authentication
self.public_paths = {
"/",
"/docs",
"/redoc",
"/openapi.json",
"/health",
"/ready",
"/live",
"/version",
"/metrics"
}
# Paths that require authentication
self.protected_paths = {
"/api/v1/pose/analyze",
"/api/v1/pose/calibrate",
"/api/v1/pose/historical",
"/api/v1/stream/start",
"/api/v1/stream/stop",
"/api/v1/stream/clients",
"/api/v1/stream/broadcast"
}
async def dispatch(self, request: Request, call_next):
"""Process request through authentication middleware."""
# Skip authentication for public paths
if self._is_public_path(request.url.path):
return await call_next(request)
# Extract and validate token
token = self._extract_token(request)
if token:
try:
# Verify token and add user info to request state
user_data = await self._verify_token(token)
request.state.user = user_data
request.state.authenticated = True
logger.debug(f"Authenticated user: {user_data.get('id')}")
except Exception as e:
logger.warning(f"Token validation failed: {e}")
# For protected paths, return 401
if self._is_protected_path(request.url.path):
return JSONResponse(
status_code=401,
content={
"error": {
"code": 401,
"message": "Invalid or expired token",
"type": "authentication_error"
}
}
)
# For other paths, continue without authentication
request.state.user = None
request.state.authenticated = False
else:
# No token provided
if self._is_protected_path(request.url.path):
return JSONResponse(
status_code=401,
content={
"error": {
"code": 401,
"message": "Authentication required",
"type": "authentication_error"
}
},
headers={"WWW-Authenticate": "Bearer"}
)
request.state.user = None
request.state.authenticated = False
# Continue with request processing
response = await call_next(request)
# Add authentication headers to response
if hasattr(request.state, 'user') and request.state.user:
response.headers["X-User-ID"] = request.state.user.get("id", "")
response.headers["X-Authenticated"] = "true"
else:
response.headers["X-Authenticated"] = "false"
return response
def _is_public_path(self, path: str) -> bool:
"""Check if path is public (doesn't require authentication)."""
# Exact match
if path in self.public_paths:
return True
# Pattern matching for public paths
public_patterns = [
"/health",
"/metrics",
"/api/v1/pose/current", # Allow anonymous access to current pose data
"/api/v1/pose/zones/", # Allow anonymous access to zone data
"/api/v1/pose/activities", # Allow anonymous access to activities
"/api/v1/pose/stats", # Allow anonymous access to stats
"/api/v1/stream/status" # Allow anonymous access to stream status
]
for pattern in public_patterns:
if path.startswith(pattern):
return True
return False
def _is_protected_path(self, path: str) -> bool:
"""Check if path requires authentication."""
# Exact match
if path in self.protected_paths:
return True
# Pattern matching for protected paths
protected_patterns = [
"/api/v1/pose/analyze",
"/api/v1/pose/calibrate",
"/api/v1/pose/historical",
"/api/v1/stream/start",
"/api/v1/stream/stop",
"/api/v1/stream/clients",
"/api/v1/stream/broadcast"
]
for pattern in protected_patterns:
if path.startswith(pattern):
return True
return False
def _extract_token(self, request: Request) -> Optional[str]:
"""Extract JWT token from request."""
# Check Authorization header
auth_header = request.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header.split(" ")[1]
# Check query parameter (for WebSocket connections)
token = request.query_params.get("token")
if token:
return token
# Check cookie
token = request.cookies.get("access_token")
if token:
return token
return None
async def _verify_token(self, token: str) -> Dict[str, Any]:
"""Verify JWT token and return user data."""
try:
# Decode JWT token
payload = jwt.decode(
token,
self.settings.secret_key,
algorithms=[self.settings.jwt_algorithm]
)
# Extract user information
user_id = payload.get("sub")
if not user_id:
raise ValueError("Token missing user ID")
# Check token expiration
exp = payload.get("exp")
if exp and datetime.utcnow() > datetime.fromtimestamp(exp):
raise ValueError("Token expired")
# Build user object
user_data = {
"id": user_id,
"username": payload.get("username"),
"email": payload.get("email"),
"is_admin": payload.get("is_admin", False),
"permissions": payload.get("permissions", []),
"accessible_zones": payload.get("accessible_zones", []),
"token_issued_at": payload.get("iat"),
"token_expires_at": payload.get("exp"),
"session_id": payload.get("session_id")
}
return user_data
except JWTError as e:
raise ValueError(f"JWT validation failed: {e}")
except Exception as e:
raise ValueError(f"Token verification error: {e}")
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
"""Log authentication events for security monitoring."""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
log_data = {
"event_type": event_type,
"timestamp": datetime.utcnow().isoformat(),
"client_ip": client_ip,
"user_agent": user_agent,
"path": request.url.path,
"method": request.method
}
if details:
log_data.update(details)
if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
logger.warning(f"Auth event: {log_data}")
else:
logger.info(f"Auth event: {log_data}")
class TokenBlacklist:
"""Simple in-memory token blacklist for logout functionality."""
def __init__(self):
self._blacklisted_tokens = set()
self._cleanup_interval = 3600 # 1 hour
self._last_cleanup = datetime.utcnow()
def add_token(self, token: str):
"""Add token to blacklist."""
self._blacklisted_tokens.add(token)
self._cleanup_if_needed()
def is_blacklisted(self, token: str) -> bool:
"""Check if token is blacklisted."""
self._cleanup_if_needed()
return token in self._blacklisted_tokens
def _cleanup_if_needed(self):
"""Clean up expired tokens from blacklist."""
now = datetime.utcnow()
if (now - self._last_cleanup).total_seconds() > self._cleanup_interval:
# In a real implementation, you would check token expiration
# For now, we'll just clear old tokens periodically
self._blacklisted_tokens.clear()
self._last_cleanup = now
# Global token blacklist instance
token_blacklist = TokenBlacklist()
class SecurityHeaders:
"""Security headers for API responses."""
@staticmethod
def add_security_headers(response: Response) -> Response:
"""Add security headers to response."""
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data:; "
"connect-src 'self' ws: wss:;"
)
return response
class APIKeyAuth:
"""Alternative API key authentication for service-to-service communication."""
def __init__(self, api_keys: Dict[str, Dict[str, Any]] = None):
self.api_keys = api_keys or {}
def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Verify API key and return associated service info."""
if api_key in self.api_keys:
return self.api_keys[api_key]
return None
def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
"""Add new API key."""
self.api_keys[api_key] = service_info
def revoke_api_key(self, api_key: str):
"""Revoke API key."""
if api_key in self.api_keys:
del self.api_keys[api_key]
# Global API key auth instance
api_key_auth = APIKeyAuth()

View File

@@ -0,0 +1,429 @@
"""
Rate limiting middleware for WiFi-DensePose API
"""
import logging
import time
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from collections import defaultdict, deque
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware with sliding window algorithm."""
def __init__(self, app):
super().__init__(app)
self.settings = get_settings()
# Rate limit storage (in production, use Redis)
self.request_counts = defaultdict(lambda: deque())
self.blocked_clients = {}
# Rate limit configurations
self.rate_limits = {
"anonymous": {
"requests": self.settings.rate_limit_requests,
"window": self.settings.rate_limit_window,
"burst": 10 # Allow burst of 10 requests
},
"authenticated": {
"requests": self.settings.rate_limit_authenticated_requests,
"window": self.settings.rate_limit_window,
"burst": 50
},
"admin": {
"requests": 10000, # Very high limit for admins
"window": self.settings.rate_limit_window,
"burst": 100
}
}
# Path-specific rate limits
self.path_limits = {
"/api/v1/pose/current": {"requests": 60, "window": 60}, # 1 per second
"/api/v1/pose/analyze": {"requests": 10, "window": 60}, # 10 per minute
"/api/v1/pose/calibrate": {"requests": 1, "window": 300}, # 1 per 5 minutes
"/api/v1/stream/start": {"requests": 5, "window": 60}, # 5 per minute
"/api/v1/stream/stop": {"requests": 5, "window": 60}, # 5 per minute
}
# Exempt paths from rate limiting
self.exempt_paths = {
"/health",
"/ready",
"/live",
"/version",
"/metrics"
}
async def dispatch(self, request: Request, call_next):
"""Process request through rate limiting middleware."""
# Skip rate limiting for exempt paths
if self._is_exempt_path(request.url.path):
return await call_next(request)
# Get client identifier
client_id = self._get_client_id(request)
# Check if client is temporarily blocked
if self._is_client_blocked(client_id):
return self._create_rate_limit_response(
"Client temporarily blocked due to excessive requests"
)
# Get user type for rate limiting
user_type = self._get_user_type(request)
# Check rate limits
rate_limit_result = self._check_rate_limits(
client_id,
request.url.path,
user_type
)
if not rate_limit_result["allowed"]:
# Log rate limit violation
self._log_rate_limit_violation(request, client_id, rate_limit_result)
# Check if client should be temporarily blocked
if rate_limit_result.get("violations", 0) > 5:
self._block_client(client_id, duration=300) # 5 minutes
return self._create_rate_limit_response(
rate_limit_result["message"],
retry_after=rate_limit_result.get("retry_after", 60)
)
# Record the request
self._record_request(client_id, request.url.path)
# Process request
response = await call_next(request)
# Add rate limit headers
self._add_rate_limit_headers(response, client_id, user_type)
return response
def _is_exempt_path(self, path: str) -> bool:
"""Check if path is exempt from rate limiting."""
return path in self.exempt_paths
def _get_client_id(self, request: Request) -> str:
"""Get unique client identifier for rate limiting."""
# Try to get user ID from request state (set by auth middleware)
if hasattr(request.state, 'user') and request.state.user:
return f"user:{request.state.user['id']}"
# Fall back to IP address
client_ip = request.client.host if request.client else "unknown"
# Include user agent for better identification
user_agent = request.headers.get("user-agent", "")
user_agent_hash = str(hash(user_agent))[:8]
return f"ip:{client_ip}:{user_agent_hash}"
def _get_user_type(self, request: Request) -> str:
"""Determine user type for rate limiting."""
if hasattr(request.state, 'user') and request.state.user:
if request.state.user.get("is_admin", False):
return "admin"
return "authenticated"
return "anonymous"
def _check_rate_limits(self, client_id: str, path: str, user_type: str) -> Dict:
"""Check if request is within rate limits."""
now = time.time()
# Get applicable rate limits
general_limit = self.rate_limits[user_type]
path_limit = self.path_limits.get(path)
# Check general rate limit
general_result = self._check_limit(
client_id,
"general",
general_limit["requests"],
general_limit["window"],
now
)
if not general_result["allowed"]:
return general_result
# Check path-specific rate limit if exists
if path_limit:
path_result = self._check_limit(
client_id,
f"path:{path}",
path_limit["requests"],
path_limit["window"],
now
)
if not path_result["allowed"]:
return path_result
return {"allowed": True}
def _check_limit(self, client_id: str, limit_type: str, max_requests: int, window: int, now: float) -> Dict:
"""Check specific rate limit using sliding window."""
key = f"{client_id}:{limit_type}"
requests = self.request_counts[key]
# Remove old requests outside the window
cutoff = now - window
while requests and requests[0] <= cutoff:
requests.popleft()
# Check if limit exceeded
if len(requests) >= max_requests:
# Calculate retry after time
oldest_request = requests[0] if requests else now
retry_after = int(oldest_request + window - now) + 1
return {
"allowed": False,
"message": f"Rate limit exceeded: {max_requests} requests per {window} seconds",
"retry_after": retry_after,
"current_count": len(requests),
"limit": max_requests,
"window": window
}
return {
"allowed": True,
"current_count": len(requests),
"limit": max_requests,
"window": window
}
def _record_request(self, client_id: str, path: str):
"""Record a request for rate limiting."""
now = time.time()
# Record general request
general_key = f"{client_id}:general"
self.request_counts[general_key].append(now)
# Record path-specific request if path has specific limits
if path in self.path_limits:
path_key = f"{client_id}:path:{path}"
self.request_counts[path_key].append(now)
def _is_client_blocked(self, client_id: str) -> bool:
"""Check if client is temporarily blocked."""
if client_id in self.blocked_clients:
block_until = self.blocked_clients[client_id]
if time.time() < block_until:
return True
else:
# Block expired, remove it
del self.blocked_clients[client_id]
return False
def _block_client(self, client_id: str, duration: int):
"""Temporarily block a client."""
self.blocked_clients[client_id] = time.time() + duration
logger.warning(f"Client {client_id} blocked for {duration} seconds due to rate limit violations")
def _create_rate_limit_response(self, message: str, retry_after: int = 60) -> JSONResponse:
"""Create rate limit exceeded response."""
return JSONResponse(
status_code=429,
content={
"error": {
"code": 429,
"message": message,
"type": "rate_limit_exceeded"
}
},
headers={
"Retry-After": str(retry_after),
"X-RateLimit-Limit": "Exceeded",
"X-RateLimit-Remaining": "0"
}
)
def _add_rate_limit_headers(self, response: Response, client_id: str, user_type: str):
"""Add rate limit headers to response."""
try:
general_limit = self.rate_limits[user_type]
general_key = f"{client_id}:general"
current_requests = len(self.request_counts[general_key])
remaining = max(0, general_limit["requests"] - current_requests)
response.headers["X-RateLimit-Limit"] = str(general_limit["requests"])
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Window"] = str(general_limit["window"])
# Add reset time
if self.request_counts[general_key]:
oldest_request = self.request_counts[general_key][0]
reset_time = int(oldest_request + general_limit["window"])
response.headers["X-RateLimit-Reset"] = str(reset_time)
except Exception as e:
logger.error(f"Error adding rate limit headers: {e}")
def _log_rate_limit_violation(self, request: Request, client_id: str, result: Dict):
"""Log rate limit violations for monitoring."""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
log_data = {
"event_type": "rate_limit_violation",
"timestamp": datetime.utcnow().isoformat(),
"client_id": client_id,
"client_ip": client_ip,
"user_agent": user_agent,
"path": request.url.path,
"method": request.method,
"current_count": result.get("current_count"),
"limit": result.get("limit"),
"window": result.get("window")
}
logger.warning(f"Rate limit violation: {log_data}")
def cleanup_old_data(self):
"""Clean up old rate limiting data (call periodically)."""
now = time.time()
cutoff = now - 3600 # Keep data for 1 hour
# Clean up request counts
for key in list(self.request_counts.keys()):
requests = self.request_counts[key]
while requests and requests[0] <= cutoff:
requests.popleft()
# Remove empty deques
if not requests:
del self.request_counts[key]
# Clean up expired blocks
expired_blocks = [
client_id for client_id, block_until in self.blocked_clients.items()
if now >= block_until
]
for client_id in expired_blocks:
del self.blocked_clients[client_id]
class AdaptiveRateLimit:
"""Adaptive rate limiting based on system load."""
def __init__(self):
self.base_limits = {}
self.current_multiplier = 1.0
self.load_history = deque(maxlen=60) # Keep 1 minute of load data
def update_system_load(self, cpu_percent: float, memory_percent: float):
"""Update system load metrics."""
load_score = (cpu_percent + memory_percent) / 2
self.load_history.append(load_score)
# Calculate adaptive multiplier
if len(self.load_history) >= 10:
avg_load = sum(self.load_history) / len(self.load_history)
if avg_load > 80:
self.current_multiplier = 0.5 # Reduce limits by 50%
elif avg_load > 60:
self.current_multiplier = 0.7 # Reduce limits by 30%
elif avg_load < 30:
self.current_multiplier = 1.2 # Increase limits by 20%
else:
self.current_multiplier = 1.0 # Normal limits
def get_adjusted_limit(self, base_limit: int) -> int:
"""Get adjusted rate limit based on system load."""
return max(1, int(base_limit * self.current_multiplier))
class RateLimitStorage:
"""Abstract interface for rate limit storage (Redis implementation)."""
async def get_count(self, key: str, window: int) -> int:
"""Get current request count for key within window."""
raise NotImplementedError
async def increment(self, key: str, window: int) -> int:
"""Increment request count and return new count."""
raise NotImplementedError
async def is_blocked(self, client_id: str) -> bool:
"""Check if client is blocked."""
raise NotImplementedError
async def block_client(self, client_id: str, duration: int):
"""Block client for duration seconds."""
raise NotImplementedError
class RedisRateLimitStorage(RateLimitStorage):
"""Redis-based rate limit storage for production use."""
def __init__(self, redis_client):
self.redis = redis_client
async def get_count(self, key: str, window: int) -> int:
"""Get current request count using Redis sliding window."""
now = time.time()
pipeline = self.redis.pipeline()
# Remove old entries
pipeline.zremrangebyscore(key, 0, now - window)
# Count current entries
pipeline.zcard(key)
results = await pipeline.execute()
return results[1]
async def increment(self, key: str, window: int) -> int:
"""Increment request count using Redis."""
now = time.time()
pipeline = self.redis.pipeline()
# Add current request
pipeline.zadd(key, {str(now): now})
# Remove old entries
pipeline.zremrangebyscore(key, 0, now - window)
# Set expiration
pipeline.expire(key, window + 1)
# Get count
pipeline.zcard(key)
results = await pipeline.execute()
return results[3]
async def is_blocked(self, client_id: str) -> bool:
"""Check if client is blocked."""
block_key = f"blocked:{client_id}"
return await self.redis.exists(block_key)
async def block_client(self, client_id: str, duration: int):
"""Block client for duration seconds."""
block_key = f"blocked:{client_id}"
await self.redis.setex(block_key, duration, "1")
# Global adaptive rate limiter instance
adaptive_rate_limit = AdaptiveRateLimit()