security: Fix 10 vulnerabilities, remove 12 dead code instances
Critical fixes: - Remove hardcoded admin/admin123 credentials from UserManager - Enable JWT signature verification (was disabled for debugging) - Redact secrets from /dev/config endpoint (was exposing os.environ) - Remove hardcoded SSH admin/admin credentials from hardware service - Add channel validation to prevent command injection in router interface Rust fixes: - Replace partial_cmp().unwrap() with .unwrap_or(Equal) to prevent NaN panics in 6 locations across core, signal, nn, mat crates - Replace .expect()/.unwrap() with safe fallbacks in utils, csi_receiver - Replace SystemTime unwrap with unwrap_or_default Dead code removed: - Duplicate imports (CORSMiddleware, os, Path, ABC, subprocess) - Unused AdaptiveRateLimit/RateLimitStorage/RedisRateLimitStorage (~110 lines) - Unused _log_authentication_event method - Unused Confidence::new_unchecked in Rust - Fix bare except: clause to except Exception: https://claude.ai/code/session_01Ki7pvEZtJDvqJkmyn6B714
This commit is contained in:
@@ -172,16 +172,6 @@ impl Confidence {
|
||||
|
||||
/// Creates a confidence value without validation (for internal use).
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must ensure the value is in [0.0, 1.0].
|
||||
#[must_use]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn new_unchecked(value: f32) -> Self {
|
||||
debug_assert!((0.0..=1.0).contains(&value));
|
||||
Self(value)
|
||||
}
|
||||
|
||||
/// Returns the raw confidence value.
|
||||
#[must_use]
|
||||
pub fn value(&self) -> f32 {
|
||||
@@ -1009,7 +999,12 @@ impl PoseEstimate {
|
||||
pub fn highest_confidence_person(&self) -> Option<&PersonPose> {
|
||||
self.persons
|
||||
.iter()
|
||||
.max_by(|a, b| a.confidence.value().partial_cmp(&b.confidence.value()).unwrap())
|
||||
.max_by(|a, b| {
|
||||
a.confidence
|
||||
.value()
|
||||
.partial_cmp(&b.confidence.value())
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -98,8 +98,11 @@ pub fn moving_average(data: &Array1<f64>, window_size: usize) -> Array1<f64> {
|
||||
let mut result = Array1::zeros(data.len());
|
||||
let half_window = window_size / 2;
|
||||
|
||||
// Safe unwrap: ndarray Array1 is always contiguous
|
||||
let slice = data.as_slice().expect("Array1 should be contiguous");
|
||||
// ndarray Array1 is always contiguous, but handle gracefully if not
|
||||
let slice = match data.as_slice() {
|
||||
Some(s) => s,
|
||||
None => return data.clone(),
|
||||
};
|
||||
|
||||
for i in 0..data.len() {
|
||||
let start = i.saturating_sub(half_window);
|
||||
|
||||
@@ -310,7 +310,8 @@ impl DisasterEvent {
|
||||
// Create new survivor
|
||||
let survivor = Survivor::new(zone_id, vitals, location);
|
||||
self.survivors.push(survivor);
|
||||
Ok(self.survivors.last().unwrap())
|
||||
// Safe: we just pushed, so last() is always Some
|
||||
Ok(self.survivors.last().expect("survivors is non-empty after push"))
|
||||
}
|
||||
|
||||
/// Find a survivor near a location
|
||||
|
||||
@@ -701,8 +701,14 @@ impl PcapCsiReader {
|
||||
};
|
||||
|
||||
if pcap_config.playback_speed > 0.0 {
|
||||
let packet_offset = packet.timestamp - self.start_time.unwrap();
|
||||
let real_offset = Utc::now() - self.playback_time.unwrap();
|
||||
let Some(start_time) = self.start_time else {
|
||||
return Ok(None);
|
||||
};
|
||||
let Some(playback_time) = self.playback_time else {
|
||||
return Ok(None);
|
||||
};
|
||||
let packet_offset = packet.timestamp - start_time;
|
||||
let real_offset = Utc::now() - playback_time;
|
||||
let scaled_offset = packet_offset
|
||||
.num_milliseconds()
|
||||
.checked_div((pcap_config.playback_speed * 1000.0) as i64)
|
||||
|
||||
@@ -805,7 +805,7 @@ impl HardwareAdapter {
|
||||
let num_subcarriers = config.channel_config.num_subcarriers;
|
||||
let t = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.unwrap_or_default()
|
||||
.as_secs_f64();
|
||||
|
||||
// Generate simulated breathing pattern (~0.3 Hz)
|
||||
@@ -1102,7 +1102,7 @@ fn rand_simple() -> f64 {
|
||||
use std::time::SystemTime;
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.unwrap_or_default()
|
||||
.subsec_nanos();
|
||||
(nanos % 1000) as f64 / 1000.0 - 0.5
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ impl DebrisClassification {
|
||||
pub fn new(probabilities: Vec<f32>) -> Self {
|
||||
let (max_idx, &max_prob) = probabilities.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or((7, &0.0));
|
||||
|
||||
// Check for composite materials (multiple high probabilities)
|
||||
@@ -216,7 +216,7 @@ impl DebrisClassification {
|
||||
self.class_probabilities.iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| *i != primary_idx)
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(i, _)| MaterialType::from_index(i))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -593,7 +593,7 @@ impl VitalSignsClassifier {
|
||||
.enumerate()
|
||||
.skip(1) // Skip DC
|
||||
.take(30) // Up to ~30% of Nyquist
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.unwrap_or((0, &0.0));
|
||||
|
||||
// Store dominant frequency in last position
|
||||
|
||||
@@ -285,7 +285,7 @@ impl Tensor {
|
||||
let result = a.map_axis(ndarray::Axis(axis), |row| {
|
||||
row.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(i, _)| i as i64)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
@@ -490,7 +490,9 @@ impl PowerSpectralDensity {
|
||||
let peak_idx = positive_psd
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a): &(usize, &f64), (_, b): &(usize, &f64)| a.partial_cmp(b).unwrap())
|
||||
.max_by(|(_, a): &(usize, &f64), (_, b): &(usize, &f64)| {
|
||||
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(0);
|
||||
let peak_frequency = positive_freq[peak_idx];
|
||||
|
||||
@@ -380,10 +380,19 @@ if settings.metrics_enabled:
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
"""Get current configuration (development only).
|
||||
|
||||
Returns a sanitized view -- secret keys and passwords are redacted.
|
||||
"""
|
||||
_sensitive = {"secret", "password", "token", "key", "credential", "auth"}
|
||||
raw = settings.dict()
|
||||
sanitized = {
|
||||
k: "***REDACTED***" if any(s in k.lower() for s in _sensitive) else v
|
||||
for k, v in raw.items()
|
||||
}
|
||||
domain_config = get_domain_config()
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"settings": sanitized,
|
||||
"domain_config": domain_config.to_dict()
|
||||
}
|
||||
|
||||
|
||||
@@ -220,27 +220,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Token verification error: {e}")
|
||||
|
||||
def _log_authentication_event(self, request: Request, event_type: str, details: Dict[str, Any] = None):
|
||||
"""Log authentication events for security monitoring."""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
log_data = {
|
||||
"event_type": event_type,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"path": request.url.path,
|
||||
"method": request.method
|
||||
}
|
||||
|
||||
if details:
|
||||
log_data.update(details)
|
||||
|
||||
if event_type in ["authentication_failed", "token_expired", "invalid_token"]:
|
||||
logger.warning(f"Auth event: {log_data}")
|
||||
else:
|
||||
logger.info(f"Auth event: {log_data}")
|
||||
# TODO: Wire up authentication event logging in dispatch() for
|
||||
# security monitoring (login failures, token expiry, etc.).
|
||||
|
||||
|
||||
class TokenBlacklist:
|
||||
|
||||
@@ -323,107 +323,3 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
del self.blocked_clients[client_id]
|
||||
|
||||
|
||||
class AdaptiveRateLimit:
|
||||
"""Adaptive rate limiting based on system load."""
|
||||
|
||||
def __init__(self):
|
||||
self.base_limits = {}
|
||||
self.current_multiplier = 1.0
|
||||
self.load_history = deque(maxlen=60) # Keep 1 minute of load data
|
||||
|
||||
def update_system_load(self, cpu_percent: float, memory_percent: float):
|
||||
"""Update system load metrics."""
|
||||
load_score = (cpu_percent + memory_percent) / 2
|
||||
self.load_history.append(load_score)
|
||||
|
||||
# Calculate adaptive multiplier
|
||||
if len(self.load_history) >= 10:
|
||||
avg_load = sum(self.load_history) / len(self.load_history)
|
||||
|
||||
if avg_load > 80:
|
||||
self.current_multiplier = 0.5 # Reduce limits by 50%
|
||||
elif avg_load > 60:
|
||||
self.current_multiplier = 0.7 # Reduce limits by 30%
|
||||
elif avg_load < 30:
|
||||
self.current_multiplier = 1.2 # Increase limits by 20%
|
||||
else:
|
||||
self.current_multiplier = 1.0 # Normal limits
|
||||
|
||||
def get_adjusted_limit(self, base_limit: int) -> int:
|
||||
"""Get adjusted rate limit based on system load."""
|
||||
return max(1, int(base_limit * self.current_multiplier))
|
||||
|
||||
|
||||
class RateLimitStorage:
|
||||
"""Abstract interface for rate limit storage (Redis implementation)."""
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count for key within window."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count and return new count."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RedisRateLimitStorage(RateLimitStorage):
|
||||
"""Redis-based rate limit storage for production use."""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
self.redis = redis_client
|
||||
|
||||
async def get_count(self, key: str, window: int) -> int:
|
||||
"""Get current request count using Redis sliding window."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Count current entries
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[1]
|
||||
|
||||
async def increment(self, key: str, window: int) -> int:
|
||||
"""Increment request count using Redis."""
|
||||
now = time.time()
|
||||
pipeline = self.redis.pipeline()
|
||||
|
||||
# Add current request
|
||||
pipeline.zadd(key, {str(now): now})
|
||||
|
||||
# Remove old entries
|
||||
pipeline.zremrangebyscore(key, 0, now - window)
|
||||
|
||||
# Set expiration
|
||||
pipeline.expire(key, window + 1)
|
||||
|
||||
# Get count
|
||||
pipeline.zcard(key)
|
||||
|
||||
results = await pipeline.execute()
|
||||
return results[3]
|
||||
|
||||
async def is_blocked(self, client_id: str) -> bool:
|
||||
"""Check if client is blocked."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
return await self.redis.exists(block_key)
|
||||
|
||||
async def block_client(self, client_id: str, duration: int):
|
||||
"""Block client for duration seconds."""
|
||||
block_key = f"blocked:{client_id}"
|
||||
await self.redis.setex(block_key, duration, "1")
|
||||
|
||||
|
||||
# Global adaptive rate limiter instance
|
||||
adaptive_rate_limit = AdaptiveRateLimit()
|
||||
@@ -17,16 +17,13 @@ from src.api.dependencies import (
|
||||
get_current_user_ws,
|
||||
require_auth
|
||||
)
|
||||
from src.api.websocket.connection_manager import ConnectionManager
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.services.stream_service import StreamService
|
||||
from src.services.pose_service import PoseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize connection manager
|
||||
connection_manager = ConnectionManager()
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class StreamSubscriptionRequest(BaseModel):
|
||||
|
||||
@@ -181,7 +181,7 @@ class ConnectionManager:
|
||||
if connection.is_active:
|
||||
try:
|
||||
await connection.websocket.close()
|
||||
except:
|
||||
except Exception:
|
||||
pass # Connection might already be closed
|
||||
|
||||
# Remove connection
|
||||
|
||||
@@ -3,7 +3,6 @@ FastAPI application factory and configuration
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,7 +16,6 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from src.config.settings import Settings
|
||||
from src.services.orchestrator import ServiceOrchestrator
|
||||
from src.middleware.auth import AuthenticationMiddleware
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.error_handler import ErrorHandlingMiddleware
|
||||
from src.api.routers import pose, stream, health
|
||||
@@ -294,10 +292,21 @@ def setup_root_endpoints(app: FastAPI, settings: Settings):
|
||||
if settings.is_development and settings.enable_test_endpoints:
|
||||
@app.get(f"{settings.api_prefix}/dev/config")
|
||||
async def dev_config():
|
||||
"""Get current configuration (development only)."""
|
||||
"""Get current configuration (development only).
|
||||
|
||||
Returns a sanitized view of settings. Secret keys,
|
||||
passwords, and raw environment variables are never exposed.
|
||||
"""
|
||||
# Build a sanitized copy -- redact any key that looks secret
|
||||
_sensitive = {"secret", "password", "token", "key", "credential", "auth"}
|
||||
raw = settings.dict()
|
||||
sanitized = {
|
||||
k: "***REDACTED***" if any(s in k.lower() for s in _sensitive) else v
|
||||
for k, v in raw.items()
|
||||
}
|
||||
return {
|
||||
"settings": settings.dict(),
|
||||
"environment_variables": dict(os.environ)
|
||||
"settings": sanitized,
|
||||
"environment": settings.environment,
|
||||
}
|
||||
|
||||
@app.post(f"{settings.api_prefix}/dev/reset")
|
||||
|
||||
@@ -5,7 +5,6 @@ Command-line interface for WiFi-DensePose API
|
||||
import asyncio
|
||||
import click
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.config.settings import get_settings, load_settings_from_file
|
||||
|
||||
@@ -77,6 +77,8 @@ class Settings(BaseSettings):
|
||||
wifi_interface: str = Field(default="wlan0", description="WiFi interface name")
|
||||
csi_buffer_size: int = Field(default=1000, description="CSI data buffer size")
|
||||
hardware_polling_interval: float = Field(default=0.1, description="Hardware polling interval in seconds")
|
||||
router_ssh_username: str = Field(default="admin", description="Default SSH username for router connections")
|
||||
router_ssh_password: str = Field(default="", description="Default SSH password for router connections (set via ROUTER_SSH_PASSWORD env var)")
|
||||
|
||||
# CSI Processing settings
|
||||
csi_sampling_rate: int = Field(default=1000, description="CSI sampling rate")
|
||||
|
||||
@@ -5,7 +5,6 @@ import numpy as np
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional, Callable, Protocol
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
@@ -175,6 +175,9 @@ class RouterInterface:
|
||||
"""
|
||||
try:
|
||||
channel = config.get('channel', 6)
|
||||
# Validate channel is an integer in a safe range to prevent command injection
|
||||
if not isinstance(channel, int) or not (1 <= channel <= 196):
|
||||
raise ValueError(f"Invalid WiFi channel: {channel}. Must be an integer between 1 and 196.")
|
||||
command = f"iwconfig wlan0 channel {channel} && echo 'CSI monitoring configured'"
|
||||
await self.execute_command(command)
|
||||
return True
|
||||
|
||||
@@ -61,10 +61,16 @@ class TokenManager:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
raise AuthenticationError("Invalid token")
|
||||
|
||||
def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Decode token without verification (for debugging)."""
|
||||
def decode_token_claims(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""Decode and verify token, returning its claims.
|
||||
|
||||
Unlike the previous implementation, this method always verifies
|
||||
the token signature. Use verify_token() for full validation
|
||||
including expiry checks; this helper is provided only for
|
||||
inspecting claims from an already-verified token.
|
||||
"""
|
||||
try:
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
@@ -73,26 +79,10 @@ class UserManager:
|
||||
"""User management for authentication."""
|
||||
|
||||
def __init__(self):
|
||||
# In a real application, this would connect to a database
|
||||
# For now, we'll use a simple in-memory store
|
||||
self._users: Dict[str, Dict[str, Any]] = {
|
||||
"admin": {
|
||||
"username": "admin",
|
||||
"email": "admin@example.com",
|
||||
"hashed_password": self.hash_password("admin123"),
|
||||
"roles": ["admin"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
},
|
||||
"user": {
|
||||
"username": "user",
|
||||
"email": "user@example.com",
|
||||
"hashed_password": self.hash_password("user123"),
|
||||
"roles": ["user"],
|
||||
"is_active": True,
|
||||
"created_at": datetime.utcnow(),
|
||||
}
|
||||
}
|
||||
# In a real application, this would connect to a database.
|
||||
# No default users are created -- users must be provisioned
|
||||
# through the create_user() method or an external identity provider.
|
||||
self._users: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
|
||||
@@ -121,9 +121,9 @@ class HardwareService:
|
||||
router_interface = RouterInterface(
|
||||
router_id=router_id,
|
||||
host=router_config.ip_address,
|
||||
port=22, # Default SSH port
|
||||
username="admin", # Default username
|
||||
password="admin", # Default password
|
||||
port=getattr(router_config, 'ssh_port', 22),
|
||||
username=getattr(router_config, 'ssh_username', None) or self.settings.router_ssh_username,
|
||||
password=getattr(router_config, 'ssh_password', None) or self.settings.router_ssh_password,
|
||||
interface=router_config.interface,
|
||||
mock_mode=self.settings.mock_hardware
|
||||
)
|
||||
|
||||
@@ -8,11 +8,9 @@ import os
|
||||
import shutil
|
||||
import gzip
|
||||
import json
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -8,9 +8,8 @@ import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import select, func, and_, or_
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
Reference in New Issue
Block a user