feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
8
v1/src/config/__init__.py
Normal file
8
v1/src/config/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Configuration management package
|
||||
"""
|
||||
|
||||
from .settings import get_settings, Settings
|
||||
from .domains import DomainConfig, get_domain_config
|
||||
|
||||
__all__ = ["get_settings", "Settings", "DomainConfig", "get_domain_config"]
|
||||
481
v1/src/config/domains.py
Normal file
481
v1/src/config/domains.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""
|
||||
Domain-specific configuration for WiFi-DensePose
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class ZoneType(str, Enum):
|
||||
"""Zone types for pose detection."""
|
||||
ROOM = "room"
|
||||
HALLWAY = "hallway"
|
||||
ENTRANCE = "entrance"
|
||||
OUTDOOR = "outdoor"
|
||||
OFFICE = "office"
|
||||
MEETING_ROOM = "meeting_room"
|
||||
KITCHEN = "kitchen"
|
||||
BATHROOM = "bathroom"
|
||||
BEDROOM = "bedroom"
|
||||
LIVING_ROOM = "living_room"
|
||||
|
||||
|
||||
class ActivityType(str, Enum):
|
||||
"""Activity types for pose classification."""
|
||||
STANDING = "standing"
|
||||
SITTING = "sitting"
|
||||
WALKING = "walking"
|
||||
LYING = "lying"
|
||||
RUNNING = "running"
|
||||
JUMPING = "jumping"
|
||||
FALLING = "falling"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class HardwareType(str, Enum):
|
||||
"""Hardware types for WiFi devices."""
|
||||
ROUTER = "router"
|
||||
ACCESS_POINT = "access_point"
|
||||
REPEATER = "repeater"
|
||||
MESH_NODE = "mesh_node"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZoneConfig:
|
||||
"""Configuration for a detection zone."""
|
||||
|
||||
zone_id: str
|
||||
name: str
|
||||
zone_type: ZoneType
|
||||
description: Optional[str] = None
|
||||
|
||||
# Physical boundaries (in meters)
|
||||
x_min: float = 0.0
|
||||
x_max: float = 10.0
|
||||
y_min: float = 0.0
|
||||
y_max: float = 10.0
|
||||
z_min: float = 0.0
|
||||
z_max: float = 3.0
|
||||
|
||||
# Detection settings
|
||||
enabled: bool = True
|
||||
confidence_threshold: float = 0.5
|
||||
max_persons: int = 5
|
||||
activity_detection: bool = True
|
||||
|
||||
# Hardware assignments
|
||||
primary_router: Optional[str] = None
|
||||
secondary_routers: List[str] = field(default_factory=list)
|
||||
|
||||
# Processing settings
|
||||
processing_interval: float = 0.1 # seconds
|
||||
data_retention_hours: int = 24
|
||||
|
||||
# Alert settings
|
||||
enable_alerts: bool = False
|
||||
alert_threshold: float = 0.8
|
||||
alert_activities: List[ActivityType] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouterConfig:
|
||||
"""Configuration for a WiFi router/device."""
|
||||
|
||||
router_id: str
|
||||
name: str
|
||||
hardware_type: HardwareType
|
||||
|
||||
# Network settings
|
||||
ip_address: str
|
||||
mac_address: str
|
||||
interface: str = "wlan0"
|
||||
channel: int = 6
|
||||
frequency: float = 2.4 # GHz
|
||||
|
||||
# CSI settings
|
||||
csi_enabled: bool = True
|
||||
csi_rate: int = 100 # Hz
|
||||
csi_subcarriers: int = 56
|
||||
antenna_count: int = 3
|
||||
|
||||
# Position (in meters)
|
||||
x_position: float = 0.0
|
||||
y_position: float = 0.0
|
||||
z_position: float = 2.5 # typical ceiling mount
|
||||
|
||||
# Calibration
|
||||
calibrated: bool = False
|
||||
calibration_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Status
|
||||
enabled: bool = True
|
||||
last_seen: Optional[str] = None
|
||||
|
||||
# Performance settings
|
||||
max_connections: int = 50
|
||||
power_level: int = 20 # dBm
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"router_id": self.router_id,
|
||||
"name": self.name,
|
||||
"hardware_type": self.hardware_type.value,
|
||||
"ip_address": self.ip_address,
|
||||
"mac_address": self.mac_address,
|
||||
"interface": self.interface,
|
||||
"channel": self.channel,
|
||||
"frequency": self.frequency,
|
||||
"csi_enabled": self.csi_enabled,
|
||||
"csi_rate": self.csi_rate,
|
||||
"csi_subcarriers": self.csi_subcarriers,
|
||||
"antenna_count": self.antenna_count,
|
||||
"position": {
|
||||
"x": self.x_position,
|
||||
"y": self.y_position,
|
||||
"z": self.z_position
|
||||
},
|
||||
"calibrated": self.calibrated,
|
||||
"calibration_data": self.calibration_data,
|
||||
"enabled": self.enabled,
|
||||
"last_seen": self.last_seen,
|
||||
"max_connections": self.max_connections,
|
||||
"power_level": self.power_level
|
||||
}
|
||||
|
||||
|
||||
class PoseModelConfig(BaseModel):
|
||||
"""Configuration for pose estimation models."""
|
||||
|
||||
model_name: str = Field(..., description="Model name")
|
||||
model_path: str = Field(..., description="Path to model file")
|
||||
model_type: str = Field(default="densepose", description="Model type")
|
||||
|
||||
# Input settings
|
||||
input_width: int = Field(default=256, description="Input image width")
|
||||
input_height: int = Field(default=256, description="Input image height")
|
||||
input_channels: int = Field(default=3, description="Input channels")
|
||||
|
||||
# Processing settings
|
||||
batch_size: int = Field(default=1, description="Batch size for inference")
|
||||
confidence_threshold: float = Field(default=0.5, description="Confidence threshold")
|
||||
nms_threshold: float = Field(default=0.4, description="NMS threshold")
|
||||
|
||||
# Output settings
|
||||
max_detections: int = Field(default=10, description="Maximum detections per frame")
|
||||
keypoint_count: int = Field(default=17, description="Number of keypoints")
|
||||
|
||||
# Performance settings
|
||||
use_gpu: bool = Field(default=True, description="Use GPU acceleration")
|
||||
gpu_memory_fraction: float = Field(default=0.5, description="GPU memory fraction")
|
||||
num_threads: int = Field(default=4, description="Number of CPU threads")
|
||||
|
||||
@validator("confidence_threshold", "nms_threshold", "gpu_memory_fraction")
|
||||
def validate_thresholds(cls, v):
|
||||
"""Validate threshold values."""
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Threshold must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
|
||||
class StreamingConfig(BaseModel):
|
||||
"""Configuration for real-time streaming."""
|
||||
|
||||
# Stream settings
|
||||
fps: int = Field(default=30, description="Frames per second")
|
||||
resolution: str = Field(default="720p", description="Stream resolution")
|
||||
quality: str = Field(default="medium", description="Stream quality")
|
||||
|
||||
# Buffer settings
|
||||
buffer_size: int = Field(default=100, description="Buffer size")
|
||||
max_latency_ms: int = Field(default=100, description="Maximum latency in milliseconds")
|
||||
|
||||
# Compression settings
|
||||
compression_enabled: bool = Field(default=True, description="Enable compression")
|
||||
compression_level: int = Field(default=5, description="Compression level (1-9)")
|
||||
|
||||
# WebSocket settings
|
||||
ping_interval: int = Field(default=60, description="Ping interval in seconds")
|
||||
timeout: int = Field(default=300, description="Connection timeout in seconds")
|
||||
max_connections: int = Field(default=100, description="Maximum concurrent connections")
|
||||
|
||||
# Data filtering
|
||||
min_confidence: float = Field(default=0.5, description="Minimum confidence for streaming")
|
||||
include_metadata: bool = Field(default=True, description="Include metadata in stream")
|
||||
|
||||
@validator("fps")
|
||||
def validate_fps(cls, v):
|
||||
"""Validate FPS value."""
|
||||
if not 1 <= v <= 60:
|
||||
raise ValueError("FPS must be between 1 and 60")
|
||||
return v
|
||||
|
||||
@validator("compression_level")
|
||||
def validate_compression_level(cls, v):
|
||||
"""Validate compression level."""
|
||||
if not 1 <= v <= 9:
|
||||
raise ValueError("Compression level must be between 1 and 9")
|
||||
return v
|
||||
|
||||
|
||||
class AlertConfig(BaseModel):
|
||||
"""Configuration for alerts and notifications."""
|
||||
|
||||
# Alert types
|
||||
enable_pose_alerts: bool = Field(default=False, description="Enable pose-based alerts")
|
||||
enable_activity_alerts: bool = Field(default=False, description="Enable activity-based alerts")
|
||||
enable_zone_alerts: bool = Field(default=False, description="Enable zone-based alerts")
|
||||
enable_system_alerts: bool = Field(default=True, description="Enable system alerts")
|
||||
|
||||
# Thresholds
|
||||
confidence_threshold: float = Field(default=0.8, description="Alert confidence threshold")
|
||||
duration_threshold: int = Field(default=5, description="Alert duration threshold in seconds")
|
||||
|
||||
# Activities that trigger alerts
|
||||
alert_activities: List[ActivityType] = Field(
|
||||
default=[ActivityType.FALLING],
|
||||
description="Activities that trigger alerts"
|
||||
)
|
||||
|
||||
# Notification settings
|
||||
email_enabled: bool = Field(default=False, description="Enable email notifications")
|
||||
webhook_enabled: bool = Field(default=False, description="Enable webhook notifications")
|
||||
sms_enabled: bool = Field(default=False, description="Enable SMS notifications")
|
||||
|
||||
# Rate limiting
|
||||
max_alerts_per_hour: int = Field(default=10, description="Maximum alerts per hour")
|
||||
cooldown_minutes: int = Field(default=5, description="Cooldown between similar alerts")
|
||||
|
||||
|
||||
class DomainConfig:
|
||||
"""Main domain configuration container."""
|
||||
|
||||
def __init__(self):
|
||||
self.zones: Dict[str, ZoneConfig] = {}
|
||||
self.routers: Dict[str, RouterConfig] = {}
|
||||
self.pose_models: Dict[str, PoseModelConfig] = {}
|
||||
self.streaming = StreamingConfig()
|
||||
self.alerts = AlertConfig()
|
||||
|
||||
# Load default configurations
|
||||
self._load_defaults()
|
||||
|
||||
def _load_defaults(self):
|
||||
"""Load default configurations."""
|
||||
# Default pose model
|
||||
self.pose_models["default"] = PoseModelConfig(
|
||||
model_name="densepose_rcnn_R_50_FPN_s1x",
|
||||
model_path="./models/densepose_rcnn_R_50_FPN_s1x.pkl",
|
||||
model_type="densepose"
|
||||
)
|
||||
|
||||
# Example zone
|
||||
self.zones["living_room"] = ZoneConfig(
|
||||
zone_id="living_room",
|
||||
name="Living Room",
|
||||
zone_type=ZoneType.LIVING_ROOM,
|
||||
description="Main living area",
|
||||
x_max=5.0,
|
||||
y_max=4.0,
|
||||
z_max=3.0
|
||||
)
|
||||
|
||||
# Example router
|
||||
self.routers["main_router"] = RouterConfig(
|
||||
router_id="main_router",
|
||||
name="Main Router",
|
||||
hardware_type=HardwareType.ROUTER,
|
||||
ip_address="192.168.1.1",
|
||||
mac_address="00:11:22:33:44:55",
|
||||
x_position=2.5,
|
||||
y_position=2.0,
|
||||
z_position=2.5
|
||||
)
|
||||
|
||||
def add_zone(self, zone: ZoneConfig):
|
||||
"""Add a zone configuration."""
|
||||
self.zones[zone.zone_id] = zone
|
||||
|
||||
def add_router(self, router: RouterConfig):
|
||||
"""Add a router configuration."""
|
||||
self.routers[router.router_id] = router
|
||||
|
||||
def add_pose_model(self, model: PoseModelConfig):
|
||||
"""Add a pose model configuration."""
|
||||
self.pose_models[model.model_name] = model
|
||||
|
||||
def get_zone(self, zone_id: str) -> Optional[ZoneConfig]:
|
||||
"""Get zone configuration by ID."""
|
||||
return self.zones.get(zone_id)
|
||||
|
||||
def get_router(self, router_id: str) -> Optional[RouterConfig]:
|
||||
"""Get router configuration by ID."""
|
||||
return self.routers.get(router_id)
|
||||
|
||||
def get_pose_model(self, model_name: str) -> Optional[PoseModelConfig]:
|
||||
"""Get pose model configuration by name."""
|
||||
return self.pose_models.get(model_name)
|
||||
|
||||
def get_zones_for_router(self, router_id: str) -> List[ZoneConfig]:
|
||||
"""Get zones that use a specific router."""
|
||||
zones = []
|
||||
for zone in self.zones.values():
|
||||
if (zone.primary_router == router_id or
|
||||
router_id in zone.secondary_routers):
|
||||
zones.append(zone)
|
||||
return zones
|
||||
|
||||
def get_routers_for_zone(self, zone_id: str) -> List[RouterConfig]:
|
||||
"""Get routers assigned to a specific zone."""
|
||||
zone = self.get_zone(zone_id)
|
||||
if not zone:
|
||||
return []
|
||||
|
||||
routers = []
|
||||
|
||||
# Add primary router
|
||||
if zone.primary_router and zone.primary_router in self.routers:
|
||||
routers.append(self.routers[zone.primary_router])
|
||||
|
||||
# Add secondary routers
|
||||
for router_id in zone.secondary_routers:
|
||||
if router_id in self.routers:
|
||||
routers.append(self.routers[router_id])
|
||||
|
||||
return routers
|
||||
|
||||
def get_all_routers(self) -> List[RouterConfig]:
|
||||
"""Get all router configurations."""
|
||||
return list(self.routers.values())
|
||||
|
||||
def validate_configuration(self) -> List[str]:
|
||||
"""Validate the entire configuration."""
|
||||
issues = []
|
||||
|
||||
# Validate zones
|
||||
for zone_id, zone in self.zones.items():
|
||||
if zone.primary_router and zone.primary_router not in self.routers:
|
||||
issues.append(f"Zone {zone_id} references unknown primary router: {zone.primary_router}")
|
||||
|
||||
for router_id in zone.secondary_routers:
|
||||
if router_id not in self.routers:
|
||||
issues.append(f"Zone {zone_id} references unknown secondary router: {router_id}")
|
||||
|
||||
# Validate routers
|
||||
for router_id, router in self.routers.items():
|
||||
if not router.ip_address:
|
||||
issues.append(f"Router {router_id} missing IP address")
|
||||
|
||||
if not router.mac_address:
|
||||
issues.append(f"Router {router_id} missing MAC address")
|
||||
|
||||
# Validate pose models
|
||||
for model_name, model in self.pose_models.items():
|
||||
import os
|
||||
if not os.path.exists(model.model_path):
|
||||
issues.append(f"Pose model {model_name} file not found: {model.model_path}")
|
||||
|
||||
return issues
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert configuration to dictionary."""
|
||||
return {
|
||||
"zones": {
|
||||
zone_id: {
|
||||
"zone_id": zone.zone_id,
|
||||
"name": zone.name,
|
||||
"zone_type": zone.zone_type.value,
|
||||
"description": zone.description,
|
||||
"boundaries": {
|
||||
"x_min": zone.x_min,
|
||||
"x_max": zone.x_max,
|
||||
"y_min": zone.y_min,
|
||||
"y_max": zone.y_max,
|
||||
"z_min": zone.z_min,
|
||||
"z_max": zone.z_max
|
||||
},
|
||||
"settings": {
|
||||
"enabled": zone.enabled,
|
||||
"confidence_threshold": zone.confidence_threshold,
|
||||
"max_persons": zone.max_persons,
|
||||
"activity_detection": zone.activity_detection
|
||||
},
|
||||
"hardware": {
|
||||
"primary_router": zone.primary_router,
|
||||
"secondary_routers": zone.secondary_routers
|
||||
}
|
||||
}
|
||||
for zone_id, zone in self.zones.items()
|
||||
},
|
||||
"routers": {
|
||||
router_id: router.to_dict()
|
||||
for router_id, router in self.routers.items()
|
||||
},
|
||||
"pose_models": {
|
||||
model_name: model.dict()
|
||||
for model_name, model in self.pose_models.items()
|
||||
},
|
||||
"streaming": self.streaming.dict(),
|
||||
"alerts": self.alerts.dict()
|
||||
}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_domain_config() -> DomainConfig:
|
||||
"""Get cached domain configuration instance."""
|
||||
return DomainConfig()
|
||||
|
||||
|
||||
def load_domain_config_from_file(file_path: str) -> DomainConfig:
|
||||
"""Load domain configuration from file."""
|
||||
import json
|
||||
|
||||
config = DomainConfig()
|
||||
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Load zones
|
||||
for zone_data in data.get("zones", []):
|
||||
zone = ZoneConfig(**zone_data)
|
||||
config.add_zone(zone)
|
||||
|
||||
# Load routers
|
||||
for router_data in data.get("routers", []):
|
||||
router = RouterConfig(**router_data)
|
||||
config.add_router(router)
|
||||
|
||||
# Load pose models
|
||||
for model_data in data.get("pose_models", []):
|
||||
model = PoseModelConfig(**model_data)
|
||||
config.add_pose_model(model)
|
||||
|
||||
# Load streaming config
|
||||
if "streaming" in data:
|
||||
config.streaming = StreamingConfig(**data["streaming"])
|
||||
|
||||
# Load alerts config
|
||||
if "alerts" in data:
|
||||
config.alerts = AlertConfig(**data["alerts"])
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load domain configuration: {e}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def save_domain_config_to_file(config: DomainConfig, file_path: str):
|
||||
"""Save domain configuration to file."""
|
||||
import json
|
||||
|
||||
try:
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to save domain configuration: {e}")
|
||||
435
v1/src/config/settings.py
Normal file
435
v1/src/config/settings.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
Pydantic settings for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings with environment variable support."""
|
||||
|
||||
# Application settings
|
||||
app_name: str = Field(default="WiFi-DensePose API", description="Application name")
|
||||
version: str = Field(default="1.0.0", description="Application version")
|
||||
environment: str = Field(default="development", description="Environment (development, staging, production)")
|
||||
debug: bool = Field(default=False, description="Debug mode")
|
||||
|
||||
# Server settings
|
||||
host: str = Field(default="0.0.0.0", description="Server host")
|
||||
port: int = Field(default=8000, description="Server port")
|
||||
reload: bool = Field(default=False, description="Auto-reload on code changes")
|
||||
workers: int = Field(default=1, description="Number of worker processes")
|
||||
|
||||
# Security settings
|
||||
secret_key: str = Field(..., description="Secret key for JWT tokens")
|
||||
jwt_algorithm: str = Field(default="HS256", description="JWT algorithm")
|
||||
jwt_expire_hours: int = Field(default=24, description="JWT token expiration in hours")
|
||||
allowed_hosts: List[str] = Field(default=["*"], description="Allowed hosts")
|
||||
cors_origins: List[str] = Field(default=["*"], description="CORS allowed origins")
|
||||
|
||||
# Rate limiting settings
|
||||
rate_limit_requests: int = Field(default=100, description="Rate limit requests per window")
|
||||
rate_limit_authenticated_requests: int = Field(default=1000, description="Rate limit for authenticated users")
|
||||
rate_limit_window: int = Field(default=3600, description="Rate limit window in seconds")
|
||||
|
||||
# Database settings
|
||||
database_url: Optional[str] = Field(default=None, description="Database connection URL")
|
||||
database_pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
database_max_overflow: int = Field(default=20, description="Database max overflow connections")
|
||||
|
||||
# Database connection pool settings (alternative naming for compatibility)
|
||||
db_pool_size: int = Field(default=10, description="Database connection pool size")
|
||||
db_max_overflow: int = Field(default=20, description="Database max overflow connections")
|
||||
db_pool_timeout: int = Field(default=30, description="Database pool timeout in seconds")
|
||||
db_pool_recycle: int = Field(default=3600, description="Database pool recycle time in seconds")
|
||||
|
||||
# Database connection settings
|
||||
db_host: Optional[str] = Field(default=None, description="Database host")
|
||||
db_port: int = Field(default=5432, description="Database port")
|
||||
db_name: Optional[str] = Field(default=None, description="Database name")
|
||||
db_user: Optional[str] = Field(default=None, description="Database user")
|
||||
db_password: Optional[str] = Field(default=None, description="Database password")
|
||||
db_echo: bool = Field(default=False, description="Enable database query logging")
|
||||
|
||||
# Redis settings (for caching and rate limiting)
|
||||
redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
|
||||
redis_password: Optional[str] = Field(default=None, description="Redis password")
|
||||
redis_db: int = Field(default=0, description="Redis database number")
|
||||
redis_enabled: bool = Field(default=True, description="Enable Redis")
|
||||
redis_host: str = Field(default="localhost", description="Redis host")
|
||||
redis_port: int = Field(default=6379, description="Redis port")
|
||||
redis_required: bool = Field(default=False, description="Require Redis connection (fail if unavailable)")
|
||||
redis_max_connections: int = Field(default=10, description="Maximum Redis connections")
|
||||
redis_socket_timeout: int = Field(default=5, description="Redis socket timeout in seconds")
|
||||
redis_connect_timeout: int = Field(default=5, description="Redis connection timeout in seconds")
|
||||
|
||||
# Failsafe settings
|
||||
enable_database_failsafe: bool = Field(default=True, description="Enable automatic SQLite failsafe when PostgreSQL unavailable")
|
||||
enable_redis_failsafe: bool = Field(default=True, description="Enable automatic Redis failsafe (disable when unavailable)")
|
||||
sqlite_fallback_path: str = Field(default="./data/wifi_densepose_fallback.db", description="SQLite fallback database path")
|
||||
|
||||
# Hardware settings
|
||||
wifi_interface: str = Field(default="wlan0", description="WiFi interface name")
|
||||
csi_buffer_size: int = Field(default=1000, description="CSI data buffer size")
|
||||
hardware_polling_interval: float = Field(default=0.1, description="Hardware polling interval in seconds")
|
||||
|
||||
# CSI Processing settings
|
||||
csi_sampling_rate: int = Field(default=1000, description="CSI sampling rate")
|
||||
csi_window_size: int = Field(default=512, description="CSI window size")
|
||||
csi_overlap: float = Field(default=0.5, description="CSI window overlap")
|
||||
csi_noise_threshold: float = Field(default=0.1, description="CSI noise threshold")
|
||||
csi_human_detection_threshold: float = Field(default=0.8, description="CSI human detection threshold")
|
||||
csi_smoothing_factor: float = Field(default=0.9, description="CSI smoothing factor")
|
||||
csi_max_history_size: int = Field(default=500, description="CSI max history size")
|
||||
|
||||
# Pose estimation settings
|
||||
pose_model_path: Optional[str] = Field(default=None, description="Path to pose estimation model")
|
||||
pose_confidence_threshold: float = Field(default=0.5, description="Minimum confidence threshold")
|
||||
pose_processing_batch_size: int = Field(default=32, description="Batch size for pose processing")
|
||||
pose_max_persons: int = Field(default=10, description="Maximum persons to detect per frame")
|
||||
|
||||
# Streaming settings
|
||||
stream_fps: int = Field(default=30, description="Streaming frames per second")
|
||||
stream_buffer_size: int = Field(default=100, description="Stream buffer size")
|
||||
websocket_ping_interval: int = Field(default=60, description="WebSocket ping interval in seconds")
|
||||
websocket_timeout: int = Field(default=300, description="WebSocket timeout in seconds")
|
||||
|
||||
# Logging settings
|
||||
log_level: str = Field(default="INFO", description="Logging level")
|
||||
log_format: str = Field(
|
||||
default="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
description="Log format"
|
||||
)
|
||||
log_file: Optional[str] = Field(default=None, description="Log file path")
|
||||
log_directory: str = Field(default="./logs", description="Log directory path")
|
||||
log_max_size: int = Field(default=10485760, description="Max log file size in bytes (10MB)")
|
||||
log_backup_count: int = Field(default=5, description="Number of log backup files")
|
||||
|
||||
# Monitoring settings
|
||||
metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
|
||||
health_check_interval: int = Field(default=30, description="Health check interval in seconds")
|
||||
performance_monitoring: bool = Field(default=True, description="Enable performance monitoring")
|
||||
monitoring_interval_seconds: int = Field(default=60, description="Monitoring task interval in seconds")
|
||||
cleanup_interval_seconds: int = Field(default=3600, description="Cleanup task interval in seconds")
|
||||
backup_interval_seconds: int = Field(default=86400, description="Backup task interval in seconds")
|
||||
|
||||
# Storage settings
|
||||
data_storage_path: str = Field(default="./data", description="Data storage directory")
|
||||
model_storage_path: str = Field(default="./models", description="Model storage directory")
|
||||
temp_storage_path: str = Field(default="./temp", description="Temporary storage directory")
|
||||
backup_directory: str = Field(default="./backups", description="Backup storage directory")
|
||||
max_storage_size_gb: int = Field(default=100, description="Maximum storage size in GB")
|
||||
|
||||
# API settings
|
||||
api_prefix: str = Field(default="/api/v1", description="API prefix")
|
||||
docs_url: str = Field(default="/docs", description="API documentation URL")
|
||||
redoc_url: str = Field(default="/redoc", description="ReDoc documentation URL")
|
||||
openapi_url: str = Field(default="/openapi.json", description="OpenAPI schema URL")
|
||||
|
||||
# Feature flags
|
||||
enable_authentication: bool = Field(default=True, description="Enable authentication")
|
||||
enable_rate_limiting: bool = Field(default=True, description="Enable rate limiting")
|
||||
enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
|
||||
enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
|
||||
enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
|
||||
cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
|
||||
cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
|
||||
|
||||
# Development settings
|
||||
mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
|
||||
mock_pose_data: bool = Field(default=False, description="Use mock pose data for development")
|
||||
enable_test_endpoints: bool = Field(default=False, description="Enable test endpoints")
|
||||
|
||||
# Cleanup settings
|
||||
csi_data_retention_days: int = Field(default=30, description="CSI data retention in days")
|
||||
pose_detection_retention_days: int = Field(default=30, description="Pose detection retention in days")
|
||||
metrics_retention_days: int = Field(default=7, description="Metrics retention in days")
|
||||
audit_log_retention_days: int = Field(default=90, description="Audit log retention in days")
|
||||
orphaned_session_threshold_days: int = Field(default=7, description="Orphaned session threshold in days")
|
||||
cleanup_batch_size: int = Field(default=1000, description="Cleanup batch size")
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False
|
||||
)
|
||||
|
||||
@field_validator("environment")
|
||||
@classmethod
|
||||
def validate_environment(cls, v):
|
||||
"""Validate environment setting."""
|
||||
allowed_environments = ["development", "staging", "production"]
|
||||
if v not in allowed_environments:
|
||||
raise ValueError(f"Environment must be one of: {allowed_environments}")
|
||||
return v
|
||||
|
||||
@field_validator("log_level")
|
||||
@classmethod
|
||||
def validate_log_level(cls, v):
|
||||
"""Validate log level setting."""
|
||||
allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
|
||||
if v.upper() not in allowed_levels:
|
||||
raise ValueError(f"Log level must be one of: {allowed_levels}")
|
||||
return v.upper()
|
||||
|
||||
@field_validator("pose_confidence_threshold")
|
||||
@classmethod
|
||||
def validate_confidence_threshold(cls, v):
|
||||
"""Validate confidence threshold."""
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Confidence threshold must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
@field_validator("stream_fps")
|
||||
@classmethod
|
||||
def validate_stream_fps(cls, v):
|
||||
"""Validate streaming FPS."""
|
||||
if not 1 <= v <= 60:
|
||||
raise ValueError("Stream FPS must be between 1 and 60")
|
||||
return v
|
||||
|
||||
@field_validator("port")
|
||||
@classmethod
|
||||
def validate_port(cls, v):
|
||||
"""Validate port number."""
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError("Port must be between 1 and 65535")
|
||||
return v
|
||||
|
||||
@field_validator("workers")
|
||||
@classmethod
|
||||
def validate_workers(cls, v):
|
||||
"""Validate worker count."""
|
||||
if v < 1:
|
||||
raise ValueError("Workers must be at least 1")
|
||||
return v
|
||||
|
||||
@field_validator("db_port")
|
||||
@classmethod
|
||||
def validate_db_port(cls, v):
|
||||
"""Validate database port."""
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError("Database port must be between 1 and 65535")
|
||||
return v
|
||||
|
||||
@field_validator("redis_port")
|
||||
@classmethod
|
||||
def validate_redis_port(cls, v):
|
||||
"""Validate Redis port."""
|
||||
if not 1 <= v <= 65535:
|
||||
raise ValueError("Redis port must be between 1 and 65535")
|
||||
return v
|
||||
|
||||
@field_validator("db_pool_size")
|
||||
@classmethod
|
||||
def validate_db_pool_size(cls, v):
|
||||
"""Validate database pool size."""
|
||||
if v < 1:
|
||||
raise ValueError("Database pool size must be at least 1")
|
||||
return v
|
||||
|
||||
@field_validator("monitoring_interval_seconds", "cleanup_interval_seconds", "backup_interval_seconds")
|
||||
@classmethod
|
||||
def validate_interval_seconds(cls, v):
|
||||
"""Validate interval settings."""
|
||||
if v < 0:
|
||||
raise ValueError("Interval seconds must be non-negative")
|
||||
return v
|
||||
@property
|
||||
def is_development(self) -> bool:
|
||||
"""Check if running in development environment."""
|
||||
return self.environment == "development"
|
||||
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
"""Check if running in production environment."""
|
||||
return self.environment == "production"
|
||||
|
||||
@property
|
||||
def is_testing(self) -> bool:
|
||||
"""Check if running in testing environment."""
|
||||
return self.environment == "testing"
|
||||
|
||||
def get_database_url(self) -> str:
|
||||
"""Get database URL with fallback."""
|
||||
if self.database_url:
|
||||
return self.database_url
|
||||
|
||||
# Build URL from individual components if available
|
||||
if self.db_host and self.db_name and self.db_user:
|
||||
password_part = f":{self.db_password}" if self.db_password else ""
|
||||
return f"postgresql://{self.db_user}{password_part}@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||
|
||||
# Default SQLite database for development
|
||||
if self.is_development:
|
||||
return f"sqlite:///{self.data_storage_path}/wifi_densepose.db"
|
||||
|
||||
# SQLite failsafe for production if enabled
|
||||
if self.enable_database_failsafe:
|
||||
return f"sqlite:///{self.sqlite_fallback_path}"
|
||||
|
||||
raise ValueError("Database URL must be configured for non-development environments")
|
||||
|
||||
def get_sqlite_fallback_url(self) -> str:
|
||||
"""Get SQLite fallback database URL."""
|
||||
return f"sqlite:///{self.sqlite_fallback_path}"
|
||||
|
||||
def get_redis_url(self) -> Optional[str]:
|
||||
"""Get Redis URL with fallback."""
|
||||
if not self.redis_enabled:
|
||||
return None
|
||||
|
||||
if self.redis_url:
|
||||
return self.redis_url
|
||||
|
||||
# Build URL from individual components
|
||||
password_part = f":{self.redis_password}@" if self.redis_password else ""
|
||||
return f"redis://{password_part}{self.redis_host}:{self.redis_port}/{self.redis_db}"
|
||||
|
||||
def get_cors_config(self) -> Dict[str, Any]:
|
||||
"""Get CORS configuration."""
|
||||
if self.is_development:
|
||||
return {
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
}
|
||||
|
||||
return {
|
||||
"allow_origins": self.cors_origins,
|
||||
"allow_credentials": True,
|
||||
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
"allow_headers": ["Authorization", "Content-Type"],
|
||||
}
|
||||
|
||||
def get_logging_config(self) -> Dict[str, Any]:
|
||||
"""Get logging configuration."""
|
||||
config = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": self.log_format,
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": self.log_level,
|
||||
"formatter": "default",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"": {
|
||||
"level": self.log_level,
|
||||
"handlers": ["console"],
|
||||
},
|
||||
"uvicorn": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
"fastapi": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Add file handler if log file is specified
|
||||
if self.log_file:
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": self.log_level,
|
||||
"formatter": "detailed",
|
||||
"filename": self.log_file,
|
||||
"maxBytes": self.log_max_size,
|
||||
"backupCount": self.log_backup_count,
|
||||
}
|
||||
|
||||
# Add file handler to all loggers
|
||||
for logger_config in config["loggers"].values():
|
||||
logger_config["handlers"].append("file")
|
||||
|
||||
return config
|
||||
|
||||
def create_directories(self):
|
||||
"""Create necessary directories."""
|
||||
directories = [
|
||||
self.data_storage_path,
|
||||
self.model_storage_path,
|
||||
self.temp_storage_path,
|
||||
self.log_directory,
|
||||
self.backup_directory,
|
||||
]
|
||||
|
||||
for directory in directories:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
settings = Settings()
|
||||
settings.create_directories()
|
||||
return settings
|
||||
|
||||
|
||||
def get_test_settings() -> Settings:
|
||||
"""Get settings for testing."""
|
||||
return Settings(
|
||||
environment="testing",
|
||||
debug=True,
|
||||
secret_key="test-secret-key",
|
||||
database_url="sqlite:///:memory:",
|
||||
mock_hardware=True,
|
||||
mock_pose_data=True,
|
||||
enable_test_endpoints=True,
|
||||
log_level="DEBUG"
|
||||
)
|
||||
|
||||
|
||||
def load_settings_from_file(file_path: str) -> Settings:
|
||||
"""Load settings from a specific file."""
|
||||
return Settings(_env_file=file_path)
|
||||
|
||||
|
||||
def validate_settings(settings: Settings) -> List[str]:
|
||||
"""Validate settings and return list of issues."""
|
||||
issues = []
|
||||
|
||||
# Check required settings for production
|
||||
if settings.is_production:
|
||||
if not settings.secret_key or settings.secret_key == "change-me":
|
||||
issues.append("Secret key must be set for production")
|
||||
|
||||
if not settings.database_url and not (settings.db_host and settings.db_name and settings.db_user):
|
||||
issues.append("Database URL or database connection parameters must be set for production")
|
||||
|
||||
if settings.debug:
|
||||
issues.append("Debug mode should be disabled in production")
|
||||
|
||||
if "*" in settings.allowed_hosts:
|
||||
issues.append("Allowed hosts should be restricted in production")
|
||||
|
||||
if "*" in settings.cors_origins:
|
||||
issues.append("CORS origins should be restricted in production")
|
||||
|
||||
# Check storage paths exist
|
||||
try:
|
||||
settings.create_directories()
|
||||
except Exception as e:
|
||||
issues.append(f"Cannot create storage directories: {e}")
|
||||
|
||||
return issues
|
||||
Reference in New Issue
Block a user