Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
322 lines
11 KiB
Python
322 lines
11 KiB
Python
"""
JWT Authentication middleware for WiFi-DensePose API
"""
|
|
|
|
import logging
import time
from datetime import datetime, timezone
from typing import Any, Dict, Optional

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from jose import JWTError, jwt
from starlette.middleware.base import BaseHTTPMiddleware

from src.config.settings import get_settings
|
|
|
|
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    How a request is handled depends on its path:

    * public paths (exact ``public_paths`` or a ``PUBLIC_PATH_PREFIXES``
      prefix) are forwarded untouched, with no token handling at all;
    * protected paths (``protected_paths``, exact or prefix match) get a 401
      JSON error unless a valid token is supplied;
    * every other path accepts an optional token and proceeds anonymously
      when it is missing or invalid.

    For non-public paths, ``request.state.user`` (claims dict or ``None``) and
    ``request.state.authenticated`` are set. ``X-Authenticated`` is added to
    the response, plus ``X-User-ID`` when a user is attached.
    """

    # Plain ``str.startswith`` prefixes for endpoints readable without a token.
    # NOTE(review): these match without a path-segment boundary (e.g.
    # "/healthz" also matches "/health"); confirm that is intended.
    PUBLIC_PATH_PREFIXES = (
        "/health",
        "/metrics",
        "/api/v1/pose/current",  # Allow anonymous access to current pose data
        "/api/v1/pose/zones/",  # Allow anonymous access to zone data
        "/api/v1/pose/activities",  # Allow anonymous access to activities
        "/api/v1/pose/stats",  # Allow anonymous access to stats
        "/api/v1/stream/status",  # Allow anonymous access to stream status
    )

    # Event types logged at WARNING level by _log_authentication_event.
    _FAILURE_EVENTS = frozenset(
        {"authentication_failed", "token_expired", "invalid_token"}
    )

    def __init__(self, app):
        super().__init__(app)
        self.settings = get_settings()

        # Paths that don't require authentication (exact match).
        self.public_paths = {
            "/",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/health",
            "/ready",
            "/live",
            "/version",
            "/metrics",
        }

        # Paths that require authentication (exact or prefix match).
        self.protected_paths = {
            "/api/v1/pose/analyze",
            "/api/v1/pose/calibrate",
            "/api/v1/pose/historical",
            "/api/v1/stream/start",
            "/api/v1/stream/stop",
            "/api/v1/stream/clients",
            "/api/v1/stream/broadcast",
        }

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request where required, then forward it.

        Returns a 401 ``JSONResponse`` (with a ``WWW-Authenticate`` challenge)
        for protected paths without a valid token. Otherwise it returns the
        downstream response with the ``X-Authenticated`` / ``X-User-ID``
        headers added.
        """
        path = request.url.path

        # Public endpoints skip token handling entirely.
        if self._is_public_path(path):
            return await call_next(request)

        # Anonymous by default; overwritten below on successful verification.
        request.state.user = None
        request.state.authenticated = False

        token = self._extract_token(request)
        if token:
            try:
                user_data = await self._verify_token(token)
            except ValueError as e:
                # _verify_token reports every failure as ValueError.
                self._log_authentication_event(
                    request, "invalid_token", {"reason": str(e)}
                )
                if self._is_protected_path(path):
                    return self._unauthorized(
                        "Invalid or expired token",
                        'Bearer error="invalid_token"',
                    )
                # Non-protected paths continue without authentication.
            else:
                request.state.user = user_data
                request.state.authenticated = True
                logger.debug("Authenticated user: %s", user_data.get("id"))
        elif self._is_protected_path(path):
            return self._unauthorized("Authentication required", "Bearer")

        response = await call_next(request)

        # Re-read state after the handler ran, as the original code did.
        user = getattr(request.state, "user", None)
        if user:
            response.headers["X-User-ID"] = user.get("id", "")
            response.headers["X-Authenticated"] = "true"
        else:
            response.headers["X-Authenticated"] = "false"

        return response

    @staticmethod
    def _unauthorized(message: str, challenge: str) -> JSONResponse:
        """Build the 401 error response returned for protected paths.

        *challenge* becomes the ``WWW-Authenticate`` header value (RFC 6750).
        """
        return JSONResponse(
            status_code=401,
            content={
                "error": {
                    "code": 401,
                    "message": message,
                    "type": "authentication_error",
                }
            },
            headers={"WWW-Authenticate": challenge},
        )

    def _is_public_path(self, path: str) -> bool:
        """Return True if *path* is public (doesn't require authentication)."""
        return path in self.public_paths or path.startswith(
            self.PUBLIC_PATH_PREFIXES
        )

    def _is_protected_path(self, path: str) -> bool:
        """Return True if *path* equals or starts with a protected path."""
        return path in self.protected_paths or path.startswith(
            tuple(self.protected_paths)
        )

    def _extract_token(self, request: Request) -> Optional[str]:
        """Return the raw JWT from the request, or None if none is present.

        Sources are checked in order: ``Authorization: Bearer <token>`` header
        (scheme matched case-insensitively), then the ``token`` query
        parameter (for WebSocket connections), then the ``access_token``
        cookie.
        """
        auth_header = request.headers.get("authorization")
        if auth_header:
            scheme, _, credentials = auth_header.strip().partition(" ")
            credentials = credentials.strip()
            if scheme.lower() == "bearer" and credentials:
                return credentials

        # NOTE(review): tokens in query strings tend to end up in access logs.
        token = request.query_params.get("token")
        if token:
            return token

        token = request.cookies.get("access_token")
        if token:
            return token

        return None

    async def _verify_token(self, token: str) -> Dict[str, Any]:
        """Decode and validate *token*; return the user data built from its claims.

        Raises:
            ValueError: on any failure (bad signature or claims, missing
                ``sub``, expired ``exp``, or unexpected decoding errors).
        """
        try:
            payload = jwt.decode(
                token,
                self.settings.secret_key,
                algorithms=[self.settings.jwt_algorithm],
            )

            user_id = payload.get("sub")
            if not user_id:
                raise ValueError("Token missing user ID")

            # Both sides are timezone-aware UTC. The previous naive
            # utcnow()/fromtimestamp() pair compared UTC with local time.
            exp = payload.get("exp")
            if exp is not None and datetime.now(timezone.utc) > datetime.fromtimestamp(
                exp, tz=timezone.utc
            ):
                raise ValueError("Token expired")

            return {
                "id": user_id,
                "username": payload.get("username"),
                "email": payload.get("email"),
                "is_admin": payload.get("is_admin", False),
                "permissions": payload.get("permissions", []),
                "accessible_zones": payload.get("accessible_zones", []),
                "token_issued_at": payload.get("iat"),
                "token_expires_at": payload.get("exp"),
                "session_id": payload.get("session_id"),
            }

        except ValueError:
            # Already descriptive; don't wrap our own errors a second time.
            raise
        except JWTError as e:
            raise ValueError(f"JWT validation failed: {e}") from e
        except Exception as e:
            raise ValueError(f"Token verification error: {e}") from e

    def _log_authentication_event(
        self,
        request: Request,
        event_type: str,
        details: Optional[Dict[str, Any]] = None,
    ):
        """Log an authentication event for security monitoring.

        Failure event types are logged at WARNING, all others at INFO.
        *details* are merged into the logged record.
        """
        log_data = {
            "event_type": event_type,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "client_ip": request.client.host if request.client else "unknown",
            "user_agent": request.headers.get("user-agent", "unknown"),
            "path": request.url.path,
            "method": request.method,
        }

        if details:
            log_data.update(details)

        level = logging.WARNING if event_type in self._FAILURE_EVENTS else logging.INFO
        logger.log(level, "Auth event: %s", log_data)
|
|
|
|
|
|
class TokenBlacklist:
    """Simple in-memory token blacklist for logout functionality.

    Each revoked token is stored with the POSIX timestamp after which its
    entry may be discarded. Pass the token's ``exp`` claim as ``expires_at``
    so the entry lives exactly as long as the token. Otherwise it is kept for
    ``default_ttl`` seconds.

    Expired entries are purged at most once every ``cleanup_interval``
    seconds. Previously the whole set was cleared every hour, which could
    drop a token revoked moments before.

    NOTE(review): nothing in this file visibly calls ``is_blacklisted``;
    confirm that the auth path consults it, otherwise revocation has no
    effect.
    """

    def __init__(self, default_ttl: float = 3600.0, cleanup_interval: float = 3600.0):
        # token -> POSIX timestamp after which the entry can be dropped
        self._blacklisted_tokens: Dict[str, float] = {}
        self._default_ttl = default_ttl
        self._cleanup_interval = cleanup_interval  # seconds between purges
        self._last_cleanup = time.time()

    def add_token(self, token: str, expires_at: Optional[float] = None):
        """Add *token* to the blacklist.

        Args:
            token: Raw token string to revoke.
            expires_at: POSIX time when the token itself expires (e.g. its
                JWT ``exp`` claim). Defaults to ``now + default_ttl``.
        """
        now = time.time()
        if expires_at is None:
            expires_at = now + self._default_ttl
        self._blacklisted_tokens[token] = expires_at
        self._cleanup_if_needed(now)

    def is_blacklisted(self, token: str) -> bool:
        """Return True if *token* is blacklisted and its entry has not expired."""
        now = time.time()
        self._cleanup_if_needed(now)
        expires_at = self._blacklisted_tokens.get(token)
        return expires_at is not None and expires_at > now

    def _cleanup_if_needed(self, now: Optional[float] = None):
        """Drop expired entries if ``cleanup_interval`` has elapsed since the last purge."""
        if now is None:
            now = time.time()
        if now - self._last_cleanup > self._cleanup_interval:
            self._blacklisted_tokens = {
                tok: exp for tok, exp in self._blacklisted_tokens.items() if exp > now
            }
            self._last_cleanup = now
|
|
|
|
|
|
# Single module-level blacklist instance, shared by everything that imports
# it from this module.
token_blacklist: TokenBlacklist = TokenBlacklist()
|
|
|
|
|
|
class SecurityHeaders:
    """Security headers for API responses."""

    @staticmethod
    def add_security_headers(response: Response) -> Response:
        """Assign a fixed set of security headers to *response* and return it.

        The response is modified in place. It is also returned so callers can
        chain the call.

        ``X-XSS-Protection`` is set to ``"0"``, which disables the legacy
        browser XSS auditor. ``"1; mode=block"`` is deprecated, and the
        auditor can itself introduce vulnerabilities. The
        Content-Security-Policy below is the modern replacement.
        """
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "0"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # NOTE(review): 'unsafe-inline' weakens script-src; ws:/wss: are
        # presumably allowed for the streaming WebSocket endpoints — confirm.
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline'; "
            "style-src 'self' 'unsafe-inline'; "
            "img-src 'self' data:; "
            "connect-src 'self' ws: wss:;"
        )

        return response
|
|
|
|
|
|
class APIKeyAuth:
    """Alternative API key authentication for service-to-service communication.

    Keeps an in-memory mapping of API key -> service info dict.
    """

    def __init__(self, api_keys: Optional[Dict[str, Dict[str, Any]]] = None):
        # ``is None`` rather than ``or``: an empty dict supplied by the caller
        # is a valid (possibly shared) registry and must not be replaced.
        self.api_keys = {} if api_keys is None else api_keys

    def verify_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Return the service info registered for *api_key*, or None if unknown."""
        return self.api_keys.get(api_key)

    def add_api_key(self, api_key: str, service_info: Dict[str, Any]):
        """Register *api_key* with *service_info*, replacing any existing entry."""
        self.api_keys[api_key] = service_info

    def revoke_api_key(self, api_key: str):
        """Remove *api_key*; unknown keys are ignored."""
        self.api_keys.pop(api_key, None)
|
|
|
|
|
|
# Module-level API key registry; it starts empty until keys are added.
api_key_auth = APIKeyAuth()