security: Fix 10 vulnerabilities, remove 12 dead code instances
Critical fixes: - Remove hardcoded admin/admin123 credentials from UserManager - Enable JWT signature verification (was disabled for debugging) - Redact secrets from /dev/config endpoint (was exposing os.environ) - Remove hardcoded SSH admin/admin credentials from hardware service - Add channel validation to prevent command injection in router interface Rust fixes: - Replace partial_cmp().unwrap() with .unwrap_or(Equal) to prevent NaN panics in 6 locations across core, signal, nn, mat crates - Replace .expect()/.unwrap() with safe fallbacks in utils, csi_receiver - Replace SystemTime unwrap with unwrap_or_default Dead code removed: - Duplicate imports (CORSMiddleware, os, Path, ABC, subprocess) - Unused AdaptiveRateLimit/RateLimitStorage/RedisRateLimitStorage (~110 lines) - Unused _log_authentication_event method - Unused Confidence::new_unchecked in Rust - Fix bare except: clause to except Exception: https://claude.ai/code/session_01Ki7pvEZtJDvqJkmyn6B714
This commit is contained in:
@@ -380,10 +380,19 @@ if settings.metrics_enabled:
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
"""Get current configuration (development only).
|
||||
|
||||
Returns a sanitized view -- secret keys and passwords are redacted.
|
||||
"""
|
||||
_sensitive = {"secret", "password", "token", "key", "credential", "auth"}
|
||||
raw = settings.dict()
|
||||
sanitized = {
|
||||
k: "***REDACTED***" if any(s in k.lower() for s in _sensitive) else v
|
||||
for k, v in raw.items()
|
||||
}
|
||||
domain_config = get_domain_config()
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"settings": sanitized,
|
||||
"domain_config": domain_config.to_dict()
|
||||
}
|
||||
|
||||
|
||||
@@ -220,27 +220,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Token verification error: {e}")
|
||||
|
||||
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
|
||||
"""Log authentication events for security monitoring."""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
log_data = {
|
||||
"event_type": event_type,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"path": request.url.path,
|
||||
"method": request.method
|
||||
}
|
||||
|
||||
if details:
|
||||
log_data.update(details)
|
||||
|
||||
if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
|
||||
logger.warning(f"Auth event: {log_data}")
|
||||
else:
|
||||
logger.info(f"Auth event: {log_data}")
|
||||
# TODO: Wire up authentication event logging in dispatch() for
|
||||
# security monitoring (login failures, token expiry, etc.).
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
|
||||
@@ -323,107 +323,3 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
del self.blocked_clients[client_id]
|
||||
|
||||
|
||||
class AdaptiveRateLimit:
|
||||
"""Adaptive rate limiting based on system load."""
|
||||
|
||||
def __init__(self):
|
||||
self.base_limits = {}
|
||||
self.current_multiplier = 1.0
|
||||
self.load_history = deque(maxlen=60) # Keep 1 minute of load data
|
||||
|
||||
def update_system_load(self, cpu_percent: float, memory_percent: float):
|
||||
"""Update system load metrics."""
|
||||
load_score = (cpu_percent + memory_percent) / 2
|
||||
self.load_history.append(load_score)
|
||||
|
||||
# Calculate adaptive multiplier
|
||||
if len(self.load_history) >= 10:
|
||||
avg_load = sum(self.load_history) / len(self.load_history)
|
||||
|
||||
if avg_load > 80:
|
||||
self.current_multiplier = 0.5 # Reduce limits by 50%
|
||||
elif avg_load > 60:
|
||||
self.current_multiplier = 0.7 # Reduce limits by 30%
|
||||
elif avg_load < 30:
|
||||
self.current_multiplier = 1.2 # Increase limits by 20%
|
||||
else:
|
||||
self.current_multiplier = 1.0 # Normal limits
|
||||
|
||||
def get_adjusted_limit(self, base_limit: int) -> int:
|
||||
"""Get adjusted rate limit based on system load."""
|
||||
return max(1, int(base_limit * self.current_multiplier))
|
||||
|
||||
|
||||
class RateLimitStorage:
|
||||
"""Abstract interface for rate limit storage (Redis implementation)."""
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count for key within window."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count and return new count."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RedisRateLimitStorage(RateLimitStorage):
|
||||
"""Redis-based rate limit storage for production use."""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
self.redis = redis_client
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count using Redis sliding window."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Count current entries
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[1]
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count using Redis."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Add current request
|
||||
pipeline.zadd(key, {str(now): now})
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Set expiration
|
||||
pipeline.expire(key, window + 1)
|
||||
|
||||
# Get count
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[3]
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
return await self.redis.exists(block_key)
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
await self.redis.setex(block_key, duration, "1")
|
||||
|
||||
|
||||
# Global adaptive rate limiter instance
|
||||
adaptive_rate_limit = AdaptiveRateLimit()
|
||||
@@ -17,16 +17,13 @@ from src.api.dependencies import (
|
||||
get_current_user_ws,
|
||||
require_auth
|
||||
)
|
||||
from src.api.websocket.connection_manager import ConnectionManager
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.services.stream_service import StreamService
|
||||
from src.services.pose_service import PoseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize connection manager
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class StreamSubscriptionRequest(BaseModel):
|
||||
|
||||
@@ -181,7 +181,7 @@ class ConnectionManager:
|
||||
if connection.is_active:
|
||||
try:
|
||||
await connection.websocket.close()
|
||||
except:
|
||||
except Exception:
|
||||
pass # Connection might already be closed
|
||||
|
||||
# Remove connection
|
||||
|
||||
@@ -3,7 +3,6 @@ FastAPI application factory and configuration
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,7 +16,6 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from src.config.settings import Settings
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.middleware.auth import AuthenticationMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlingMiddleware
|
||||
from src.api.routers import pose, stream, health
|
||||
@@ -294,10 +292,21 @@ def setup_root_endpoints(app: FastAPI, settings: Settings):
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
"""Get current configuration (development only).
|
||||
|
||||
Returns a sanitized view of settings. Secret keys,
|
||||
passwords, and raw environment variables are never exposed.
|
||||
"""
|
||||
# Build a sanitized copy -- redact any key that looks secret
|
||||
_sensitive = {"secret", "password", "token", "key", "credential", "auth"}
|
||||
raw = settings.dict()
|
||||
sanitized = {
|
||||
k: "***REDACTED***" if any(s in k.lower() for s in _sensitive) else v
|
||||
for k, v in raw.items()
|
||||
}
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"environment_variables": dict(os.environ)
|
||||
"settings": sanitized,
|
||||
"environment": settings.environment,
|
||||
}
|
||||
|
||||
@app.post(f"{settings.api_prefix}/dev/reset")
|
||||
|
||||
@@ -5,7 +5,6 @@ Command-line interface for WiFi-DensePose API
|
||||
import asyncio
|
||||
import click
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config.settings import get_settings, load_settings_from_file
|
||||
|
||||
@@ -77,6 +77,8 @@ class Settings(BaseSettings):
|
||||
wifi_interface: str = Field(default="wlan0", description="WiFi interface name")
|
||||
csi_buffer_size: int = Field(default=1000, description="CSI data buffer size")
|
||||
hardware_polling_interval: float = Field(default=0.1, description="Hardware polling interval in seconds")
|
||||
router_ssh_username: str = Field(default="admin", description="Default SSH username for router connections")
|
||||
router_ssh_password: str = Field(default="", description="Default SSH password for router connections (set via ROUTER_SSH_PASSWORD env var)")
|
||||
|
||||
# CSI Processing settings
|
||||
csi_sampling_rate: int = Field(default=1000, description="CSI sampling rate")
|
||||
|
||||
@@ -5,7 +5,6 @@ import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, Callable, Protocol
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
@@ -175,6 +175,9 @@ class RouterInterface:
|
||||
"""
|
||||
try:
|
||||
channel = config.get('channel', 6)
|
||||
# Validate channel is an integer in a safe range to prevent command injection
|
||||
if not isinstance(channel, int) or not (1 <= channel <= 196):
|
||||
raise ValueError(f"Invalid WiFi channel: {channel}. Must be an integer between 1 and 196.")
|
||||
command = f"iwconfig wlan0 channel {channel} && echo 'CSI monitoring configured'"
|
||||
await self.execute_command(command)
|
||||
return True
|
||||
|
||||
@@ -61,10 +61,16 @@ class TokenManager:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
raise AuthenticationError("Invalid token")
|
||||
|
||||
def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Decode token without verification (for debugging)."""
|
||||
def decode_token_claims(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Decode and verify token, returning its claims.
|
||||
|
||||
Unlike the previous implementation, this method always verifies
|
||||
the token signature. Use verify_token() for full validation
|
||||
including expiry checks; this helper is provided only for
|
||||
inspecting claims from an already-verified token.
|
||||
"""
|
||||
try:
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
@@ -73,26 +79,10 @@ class UserManager:
|
||||
"""User management for authentication."""
|
||||
|
||||
def __init__(self):
|
||||
# In a real application, this would connect to a database
|
||||
# For now, we'll use a simple in-memory store
|
||||
self._users: Dict[str, Dict[str, Any]] = {
|
||||
"admin": {
|
||||
"username": "admin",
|
||||
"email": "admin@example.com",
|
||||
"hashed_password": self.hash_password("admin123"),
|
||||
"roles": ["admin"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
},
|
||||
"user": {
|
||||
"username": "user",
|
||||
"email": "user@example.com",
|
||||
"hashed_password": self.hash_password("user123"),
|
||||
"roles": ["user"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
}
|
||||
# In a real application, this would connect to a database.
|
||||
# No default users are created -- users must be provisioned
|
||||
# through the create_user() method or an external identity provider.
|
||||
self._users: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
|
||||
@@ -121,9 +121,9 @@ class HardwareService:
|
||||
router_interface = RouterInterface(
|
||||
router_id=router_id,
|
||||
host=router_config.ip_address,
|
||||
port=22, # Default SSH port
|
||||
username="admin", # Default username
|
||||
password="admin", # Default password
|
||||
port=getattr(router_config, 'ssh_port', 22),
|
||||
username=getattr(router_config, 'ssh_username', None) or self.settings.router_ssh_username,
|
||||
password=getattr(router_config, 'ssh_password', None) or self.settings.router_ssh_password,
|
||||
interface=router_config.interface,
|
||||
mock_mode=self.settings.mock_hardware
|
||||
)
|
||||
|
||||
@@ -8,11 +8,9 @@ import os
|
||||
import shutil
|
||||
import gzip
|
||||
import json
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -8,9 +8,8 @@ import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select, func, and_, or_
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
Reference in New Issue
Block a user