feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
Claude
2026-01-13 03:11:16 +00:00
parent 5101504b72
commit 6ed69a3d48
427 changed files with 90993 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
"""
Configuration management package
"""
from .settings import get_settings, Settings
from .domains import DomainConfig, get_domain_config
__all__ = ["get_settings", "Settings", "DomainConfig", "get_domain_config"]

481
v1/src/config/domains.py Normal file
View File

@@ -0,0 +1,481 @@
"""
Domain-specific configuration for WiFi-DensePose
"""
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from pydantic import BaseModel, Field, validator
class ZoneType(str, Enum):
    """Zone types for pose detection.

    Inherits from ``str`` so members compare equal to, and JSON-serialize
    as, their plain string values (used by ``DomainConfig.to_dict``).
    """
    ROOM = "room"
    HALLWAY = "hallway"
    ENTRANCE = "entrance"
    OUTDOOR = "outdoor"
    OFFICE = "office"
    MEETING_ROOM = "meeting_room"
    KITCHEN = "kitchen"
    BATHROOM = "bathroom"
    BEDROOM = "bedroom"
    LIVING_ROOM = "living_room"
class ActivityType(str, Enum):
    """Activity types for pose classification.

    ``str``-valued enum; ``FALLING`` is the default alert trigger in
    ``AlertConfig``; ``UNKNOWN`` is the catch-all for unclassified poses.
    """
    STANDING = "standing"
    SITTING = "sitting"
    WALKING = "walking"
    LYING = "lying"
    RUNNING = "running"
    JUMPING = "jumping"
    FALLING = "falling"
    UNKNOWN = "unknown"
class HardwareType(str, Enum):
    """Hardware types for WiFi devices.

    ``str``-valued enum used by ``RouterConfig.hardware_type``.
    """
    ROUTER = "router"
    ACCESS_POINT = "access_point"
    REPEATER = "repeater"
    MESH_NODE = "mesh_node"
    CUSTOM = "custom"
@dataclass
class ZoneConfig:
    """Configuration for a detection zone.

    A zone is an axis-aligned 3D box (meters) with detection, hardware,
    processing, and alert settings. Identified by ``zone_id``.
    """
    zone_id: str
    name: str
    zone_type: ZoneType
    description: Optional[str] = None
    # Physical boundaries (in meters) — axis-aligned bounding box
    x_min: float = 0.0
    x_max: float = 10.0
    y_min: float = 0.0
    y_max: float = 10.0
    z_min: float = 0.0
    z_max: float = 3.0
    # Detection settings
    enabled: bool = True
    confidence_threshold: float = 0.5
    max_persons: int = 5
    activity_detection: bool = True
    # Hardware assignments — router IDs looked up in DomainConfig.routers
    primary_router: Optional[str] = None
    secondary_routers: List[str] = field(default_factory=list)
    # Processing settings
    processing_interval: float = 0.1  # seconds
    data_retention_hours: int = 24
    # Alert settings
    enable_alerts: bool = False
    alert_threshold: float = 0.8
    alert_activities: List[ActivityType] = field(default_factory=list)
@dataclass
class RouterConfig:
    """Configuration for a WiFi router/device.

    Describes the network identity, CSI capture settings, physical
    placement, and status of a single WiFi device used for sensing.
    """
    router_id: str
    name: str
    hardware_type: HardwareType
    # Network settings
    ip_address: str
    mac_address: str
    interface: str = "wlan0"
    channel: int = 6
    frequency: float = 2.4  # GHz
    # CSI settings
    csi_enabled: bool = True
    csi_rate: int = 100  # Hz
    csi_subcarriers: int = 56
    antenna_count: int = 3
    # Position (in meters)
    x_position: float = 0.0
    y_position: float = 0.0
    z_position: float = 2.5  # typical ceiling mount
    # Calibration
    calibrated: bool = False
    calibration_data: Optional[Dict[str, Any]] = None
    # Status
    enabled: bool = True
    # NOTE(review): appears to be a timestamp string; exact format is not
    # established by this file — confirm against whatever writes it.
    last_seen: Optional[str] = None
    # Performance settings
    max_connections: int = 50
    power_level: int = 20  # dBm

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        The three ``*_position`` fields are grouped under a nested
        ``"position"`` key, and ``hardware_type`` is flattened to its
        string value; all other fields map one-to-one.
        """
        return {
            "router_id": self.router_id,
            "name": self.name,
            "hardware_type": self.hardware_type.value,
            "ip_address": self.ip_address,
            "mac_address": self.mac_address,
            "interface": self.interface,
            "channel": self.channel,
            "frequency": self.frequency,
            "csi_enabled": self.csi_enabled,
            "csi_rate": self.csi_rate,
            "csi_subcarriers": self.csi_subcarriers,
            "antenna_count": self.antenna_count,
            "position": {
                "x": self.x_position,
                "y": self.y_position,
                "z": self.z_position
            },
            "calibrated": self.calibrated,
            "calibration_data": self.calibration_data,
            "enabled": self.enabled,
            "last_seen": self.last_seen,
            "max_connections": self.max_connections,
            "power_level": self.power_level
        }
class PoseModelConfig(BaseModel):
    """Configuration for pose estimation models.

    NOTE(review): this module uses the pydantic v1 ``@validator`` API while
    settings.py uses v2 ``@field_validator`` — confirm which pydantic major
    version the project targets and align the two modules.
    """
    model_name: str = Field(..., description="Model name")
    model_path: str = Field(..., description="Path to model file")
    model_type: str = Field(default="densepose", description="Model type")
    # Input settings
    input_width: int = Field(default=256, description="Input image width")
    input_height: int = Field(default=256, description="Input image height")
    input_channels: int = Field(default=3, description="Input channels")
    # Processing settings
    batch_size: int = Field(default=1, description="Batch size for inference")
    confidence_threshold: float = Field(default=0.5, description="Confidence threshold")
    nms_threshold: float = Field(default=0.4, description="NMS threshold")
    # Output settings
    max_detections: int = Field(default=10, description="Maximum detections per frame")
    keypoint_count: int = Field(default=17, description="Number of keypoints")
    # Performance settings
    use_gpu: bool = Field(default=True, description="Use GPU acceleration")
    gpu_memory_fraction: float = Field(default=0.5, description="GPU memory fraction")
    num_threads: int = Field(default=4, description="Number of CPU threads")

    @validator("confidence_threshold", "nms_threshold", "gpu_memory_fraction")
    def validate_thresholds(cls, v):
        """Validate threshold values.

        All three fields are fractions and must lie within [0.0, 1.0].
        """
        if not 0.0 <= v <= 1.0:
            raise ValueError("Threshold must be between 0.0 and 1.0")
        return v
class StreamingConfig(BaseModel):
    """Configuration for real-time streaming.

    Covers frame rate/quality, buffering, compression, WebSocket
    connection limits, and per-frame data filtering.
    """
    # Stream settings
    fps: int = Field(default=30, description="Frames per second")
    resolution: str = Field(default="720p", description="Stream resolution")
    quality: str = Field(default="medium", description="Stream quality")
    # Buffer settings
    buffer_size: int = Field(default=100, description="Buffer size")
    max_latency_ms: int = Field(default=100, description="Maximum latency in milliseconds")
    # Compression settings
    compression_enabled: bool = Field(default=True, description="Enable compression")
    compression_level: int = Field(default=5, description="Compression level (1-9)")
    # WebSocket settings
    ping_interval: int = Field(default=60, description="Ping interval in seconds")
    timeout: int = Field(default=300, description="Connection timeout in seconds")
    max_connections: int = Field(default=100, description="Maximum concurrent connections")
    # Data filtering
    min_confidence: float = Field(default=0.5, description="Minimum confidence for streaming")
    include_metadata: bool = Field(default=True, description="Include metadata in stream")

    @validator("fps")
    def validate_fps(cls, v):
        """Validate FPS value (inclusive range 1-60)."""
        if not 1 <= v <= 60:
            raise ValueError("FPS must be between 1 and 60")
        return v

    @validator("compression_level")
    def validate_compression_level(cls, v):
        """Validate compression level (inclusive range 1-9)."""
        if not 1 <= v <= 9:
            raise ValueError("Compression level must be between 1 and 9")
        return v
class AlertConfig(BaseModel):
    """Configuration for alerts and notifications.

    Toggles per-category alerting, thresholds, notification channels,
    and rate limiting for repeated alerts.
    """
    # Alert types
    enable_pose_alerts: bool = Field(default=False, description="Enable pose-based alerts")
    enable_activity_alerts: bool = Field(default=False, description="Enable activity-based alerts")
    enable_zone_alerts: bool = Field(default=False, description="Enable zone-based alerts")
    enable_system_alerts: bool = Field(default=True, description="Enable system alerts")
    # Thresholds
    confidence_threshold: float = Field(default=0.8, description="Alert confidence threshold")
    duration_threshold: int = Field(default=5, description="Alert duration threshold in seconds")
    # Activities that trigger alerts
    # NOTE(review): mutable list default — pydantic copies field defaults
    # per instance, but `default_factory` would make that explicit; confirm
    # the targeted pydantic version preserves this behavior.
    alert_activities: List[ActivityType] = Field(
        default=[ActivityType.FALLING],
        description="Activities that trigger alerts"
    )
    # Notification settings
    email_enabled: bool = Field(default=False, description="Enable email notifications")
    webhook_enabled: bool = Field(default=False, description="Enable webhook notifications")
    sms_enabled: bool = Field(default=False, description="Enable SMS notifications")
    # Rate limiting
    max_alerts_per_hour: int = Field(default=10, description="Maximum alerts per hour")
    cooldown_minutes: int = Field(default=5, description="Cooldown between similar alerts")
class DomainConfig:
    """Main domain configuration container.

    Aggregates zone, router, and pose-model registries (keyed by their
    IDs/names) together with streaming and alert settings, and provides
    lookup, cross-reference, validation, and serialization helpers.
    """

    def __init__(self):
        self.zones: Dict[str, ZoneConfig] = {}
        self.routers: Dict[str, RouterConfig] = {}
        self.pose_models: Dict[str, PoseModelConfig] = {}
        self.streaming = StreamingConfig()
        self.alerts = AlertConfig()
        # Seed the registries with a usable out-of-the-box configuration.
        self._load_defaults()

    def _load_defaults(self):
        """Load default configurations: one pose model, one zone, one router."""
        # Default pose model
        self.pose_models["default"] = PoseModelConfig(
            model_name="densepose_rcnn_R_50_FPN_s1x",
            model_path="./models/densepose_rcnn_R_50_FPN_s1x.pkl",
            model_type="densepose"
        )
        # Example zone
        self.zones["living_room"] = ZoneConfig(
            zone_id="living_room",
            name="Living Room",
            zone_type=ZoneType.LIVING_ROOM,
            description="Main living area",
            x_max=5.0,
            y_max=4.0,
            z_max=3.0
        )
        # Example router
        self.routers["main_router"] = RouterConfig(
            router_id="main_router",
            name="Main Router",
            hardware_type=HardwareType.ROUTER,
            ip_address="192.168.1.1",
            mac_address="00:11:22:33:44:55",
            x_position=2.5,
            y_position=2.0,
            z_position=2.5
        )

    def add_zone(self, zone: ZoneConfig):
        """Add (or replace) a zone configuration, keyed by its zone_id."""
        self.zones[zone.zone_id] = zone

    def add_router(self, router: RouterConfig):
        """Add (or replace) a router configuration, keyed by its router_id."""
        self.routers[router.router_id] = router

    def add_pose_model(self, model: PoseModelConfig):
        """Add (or replace) a pose model configuration, keyed by its model_name."""
        self.pose_models[model.model_name] = model

    def get_zone(self, zone_id: str) -> Optional[ZoneConfig]:
        """Get zone configuration by ID, or None if unknown."""
        return self.zones.get(zone_id)

    def get_router(self, router_id: str) -> Optional[RouterConfig]:
        """Get router configuration by ID, or None if unknown."""
        return self.routers.get(router_id)

    def get_pose_model(self, model_name: str) -> Optional[PoseModelConfig]:
        """Get pose model configuration by name, or None if unknown."""
        return self.pose_models.get(model_name)

    def get_zones_for_router(self, router_id: str) -> List[ZoneConfig]:
        """Get zones that use a specific router (as primary or secondary)."""
        return [
            zone for zone in self.zones.values()
            if zone.primary_router == router_id or router_id in zone.secondary_routers
        ]

    def get_routers_for_zone(self, zone_id: str) -> List[RouterConfig]:
        """Get routers assigned to a specific zone (primary first, then secondary)."""
        zone = self.get_zone(zone_id)
        if not zone:
            return []
        routers = []
        # Add primary router
        if zone.primary_router and zone.primary_router in self.routers:
            routers.append(self.routers[zone.primary_router])
        # Add secondary routers, silently skipping unknown IDs
        # (validate_configuration reports dangling references).
        for router_id in zone.secondary_routers:
            if router_id in self.routers:
                routers.append(self.routers[router_id])
        return routers

    def get_all_routers(self) -> List[RouterConfig]:
        """Get all router configurations."""
        return list(self.routers.values())

    def validate_configuration(self) -> List[str]:
        """Validate the entire configuration.

        Returns:
            Human-readable descriptions of every problem found (empty list
            when the configuration is consistent): dangling zone->router
            references, routers missing network identity, and pose model
            files missing on disk.
        """
        import os  # hoisted out of the per-model loop below

        issues = []
        # Validate zones: every referenced router must exist.
        for zone_id, zone in self.zones.items():
            if zone.primary_router and zone.primary_router not in self.routers:
                issues.append(f"Zone {zone_id} references unknown primary router: {zone.primary_router}")
            for router_id in zone.secondary_routers:
                if router_id not in self.routers:
                    issues.append(f"Zone {zone_id} references unknown secondary router: {router_id}")
        # Validate routers: network identity is mandatory.
        for router_id, router in self.routers.items():
            if not router.ip_address:
                issues.append(f"Router {router_id} missing IP address")
            if not router.mac_address:
                issues.append(f"Router {router_id} missing MAC address")
        # Validate pose models: the weights file must exist.
        for model_name, model in self.pose_models.items():
            if not os.path.exists(model.model_path):
                issues.append(f"Pose model {model_name} file not found: {model.model_path}")
        return issues

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary (JSON-serializable).

        Zones are re-shaped into nested boundaries/settings/hardware
        sections; routers and pose models use their own serializers.
        """
        return {
            "zones": {
                zone_id: {
                    "zone_id": zone.zone_id,
                    "name": zone.name,
                    "zone_type": zone.zone_type.value,
                    "description": zone.description,
                    "boundaries": {
                        "x_min": zone.x_min,
                        "x_max": zone.x_max,
                        "y_min": zone.y_min,
                        "y_max": zone.y_max,
                        "z_min": zone.z_min,
                        "z_max": zone.z_max
                    },
                    "settings": {
                        "enabled": zone.enabled,
                        "confidence_threshold": zone.confidence_threshold,
                        "max_persons": zone.max_persons,
                        "activity_detection": zone.activity_detection
                    },
                    "hardware": {
                        "primary_router": zone.primary_router,
                        "secondary_routers": zone.secondary_routers
                    }
                }
                for zone_id, zone in self.zones.items()
            },
            "routers": {
                router_id: router.to_dict()
                for router_id, router in self.routers.items()
            },
            "pose_models": {
                model_name: model.dict()
                for model_name, model in self.pose_models.items()
            },
            "streaming": self.streaming.dict(),
            "alerts": self.alerts.dict()
        }
@lru_cache()
def get_domain_config() -> DomainConfig:
    """Return the process-wide DomainConfig instance.

    The first call constructs the configuration (including defaults);
    lru_cache makes every later call return the same object.
    """
    config = DomainConfig()
    return config
def load_domain_config_from_file(file_path: str) -> DomainConfig:
    """Load domain configuration from a JSON file.

    Args:
        file_path: Path to a JSON document with optional "zones",
            "routers", "pose_models", "streaming", and "alerts" sections;
            the first three are expected to be lists of keyword mappings.

    Returns:
        A DomainConfig pre-populated with defaults plus the file's entries.

    Raises:
        ValueError: If the file cannot be read, parsed, or validated.
            The original exception is attached as __cause__.

    NOTE(review): save_domain_config_to_file writes zones/routers as
    dicts keyed by ID with nested sections, which this loader does not
    accept (it expects flat lists) — the two formats are not round-trip
    compatible; confirm the intended on-disk schema.
    """
    import json

    config = DomainConfig()
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Load zones
        for zone_data in data.get("zones", []):
            config.add_zone(ZoneConfig(**zone_data))
        # Load routers
        for router_data in data.get("routers", []):
            config.add_router(RouterConfig(**router_data))
        # Load pose models
        for model_data in data.get("pose_models", []):
            config.add_pose_model(PoseModelConfig(**model_data))
        # Load streaming config
        if "streaming" in data:
            config.streaming = StreamingConfig(**data["streaming"])
        # Load alerts config
        if "alerts" in data:
            config.alerts = AlertConfig(**data["alerts"])
    except Exception as e:
        # Chain the original error so the root cause stays visible.
        raise ValueError(f"Failed to load domain configuration: {e}") from e
    return config
def save_domain_config_to_file(config: DomainConfig, file_path: str):
    """Save domain configuration to a JSON file (pretty-printed, UTF-8).

    Args:
        config: The configuration to serialize via ``config.to_dict()``.
        file_path: Destination path; overwritten if it exists.

    Raises:
        ValueError: If serialization or writing fails. The original
            exception is attached as __cause__.
    """
    import json

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(config.to_dict(), f, indent=2)
    except Exception as e:
        # Chain the original error so the root cause stays visible.
        raise ValueError(f"Failed to save domain configuration: {e}") from e

435
v1/src/config/settings.py Normal file
View File

@@ -0,0 +1,435 @@
"""
Pydantic settings for WiFi-DensePose API
"""
import os
from typing import List, Optional, Dict, Any
from functools import lru_cache
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Every field can be overridden via environment variables
    (case-insensitive) or a ``.env`` file; see ``model_config`` below.

    Fix: "testing" is now an accepted environment — ``is_testing`` and
    ``get_test_settings()`` (which constructs ``environment="testing"``)
    both depend on it, but the previous validator rejected it.
    """
    # Application settings
    app_name: str = Field(default="WiFi-DensePose API", description="Application name")
    version: str = Field(default="1.0.0", description="Application version")
    environment: str = Field(default="development", description="Environment (development, staging, production, testing)")
    debug: bool = Field(default=False, description="Debug mode")
    # Server settings
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8000, description="Server port")
    reload: bool = Field(default=False, description="Auto-reload on code changes")
    workers: int = Field(default=1, description="Number of worker processes")
    # Security settings
    secret_key: str = Field(..., description="Secret key for JWT tokens")
    jwt_algorithm: str = Field(default="HS256", description="JWT algorithm")
    jwt_expire_hours: int = Field(default=24, description="JWT token expiration in hours")
    allowed_hosts: List[str] = Field(default=["*"], description="Allowed hosts")
    cors_origins: List[str] = Field(default=["*"], description="CORS allowed origins")
    # Rate limiting settings
    rate_limit_requests: int = Field(default=100, description="Rate limit requests per window")
    rate_limit_authenticated_requests: int = Field(default=1000, description="Rate limit for authenticated users")
    rate_limit_window: int = Field(default=3600, description="Rate limit window in seconds")
    # Database settings
    database_url: Optional[str] = Field(default=None, description="Database connection URL")
    database_pool_size: int = Field(default=10, description="Database connection pool size")
    database_max_overflow: int = Field(default=20, description="Database max overflow connections")
    # Database connection pool settings (alternative naming for compatibility)
    db_pool_size: int = Field(default=10, description="Database connection pool size")
    db_max_overflow: int = Field(default=20, description="Database max overflow connections")
    db_pool_timeout: int = Field(default=30, description="Database pool timeout in seconds")
    db_pool_recycle: int = Field(default=3600, description="Database pool recycle time in seconds")
    # Database connection settings
    db_host: Optional[str] = Field(default=None, description="Database host")
    db_port: int = Field(default=5432, description="Database port")
    db_name: Optional[str] = Field(default=None, description="Database name")
    db_user: Optional[str] = Field(default=None, description="Database user")
    db_password: Optional[str] = Field(default=None, description="Database password")
    db_echo: bool = Field(default=False, description="Enable database query logging")
    # Redis settings (for caching and rate limiting)
    redis_url: Optional[str] = Field(default=None, description="Redis connection URL")
    redis_password: Optional[str] = Field(default=None, description="Redis password")
    redis_db: int = Field(default=0, description="Redis database number")
    redis_enabled: bool = Field(default=True, description="Enable Redis")
    redis_host: str = Field(default="localhost", description="Redis host")
    redis_port: int = Field(default=6379, description="Redis port")
    redis_required: bool = Field(default=False, description="Require Redis connection (fail if unavailable)")
    redis_max_connections: int = Field(default=10, description="Maximum Redis connections")
    redis_socket_timeout: int = Field(default=5, description="Redis socket timeout in seconds")
    redis_connect_timeout: int = Field(default=5, description="Redis connection timeout in seconds")
    # Failsafe settings
    enable_database_failsafe: bool = Field(default=True, description="Enable automatic SQLite failsafe when PostgreSQL unavailable")
    enable_redis_failsafe: bool = Field(default=True, description="Enable automatic Redis failsafe (disable when unavailable)")
    sqlite_fallback_path: str = Field(default="./data/wifi_densepose_fallback.db", description="SQLite fallback database path")
    # Hardware settings
    wifi_interface: str = Field(default="wlan0", description="WiFi interface name")
    csi_buffer_size: int = Field(default=1000, description="CSI data buffer size")
    hardware_polling_interval: float = Field(default=0.1, description="Hardware polling interval in seconds")
    # CSI Processing settings
    csi_sampling_rate: int = Field(default=1000, description="CSI sampling rate")
    csi_window_size: int = Field(default=512, description="CSI window size")
    csi_overlap: float = Field(default=0.5, description="CSI window overlap")
    csi_noise_threshold: float = Field(default=0.1, description="CSI noise threshold")
    csi_human_detection_threshold: float = Field(default=0.8, description="CSI human detection threshold")
    csi_smoothing_factor: float = Field(default=0.9, description="CSI smoothing factor")
    csi_max_history_size: int = Field(default=500, description="CSI max history size")
    # Pose estimation settings
    pose_model_path: Optional[str] = Field(default=None, description="Path to pose estimation model")
    pose_confidence_threshold: float = Field(default=0.5, description="Minimum confidence threshold")
    pose_processing_batch_size: int = Field(default=32, description="Batch size for pose processing")
    pose_max_persons: int = Field(default=10, description="Maximum persons to detect per frame")
    # Streaming settings
    stream_fps: int = Field(default=30, description="Streaming frames per second")
    stream_buffer_size: int = Field(default=100, description="Stream buffer size")
    websocket_ping_interval: int = Field(default=60, description="WebSocket ping interval in seconds")
    websocket_timeout: int = Field(default=300, description="WebSocket timeout in seconds")
    # Logging settings
    log_level: str = Field(default="INFO", description="Logging level")
    log_format: str = Field(
        default="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        description="Log format"
    )
    log_file: Optional[str] = Field(default=None, description="Log file path")
    log_directory: str = Field(default="./logs", description="Log directory path")
    log_max_size: int = Field(default=10485760, description="Max log file size in bytes (10MB)")
    log_backup_count: int = Field(default=5, description="Number of log backup files")
    # Monitoring settings
    metrics_enabled: bool = Field(default=True, description="Enable metrics collection")
    health_check_interval: int = Field(default=30, description="Health check interval in seconds")
    performance_monitoring: bool = Field(default=True, description="Enable performance monitoring")
    monitoring_interval_seconds: int = Field(default=60, description="Monitoring task interval in seconds")
    cleanup_interval_seconds: int = Field(default=3600, description="Cleanup task interval in seconds")
    backup_interval_seconds: int = Field(default=86400, description="Backup task interval in seconds")
    # Storage settings
    data_storage_path: str = Field(default="./data", description="Data storage directory")
    model_storage_path: str = Field(default="./models", description="Model storage directory")
    temp_storage_path: str = Field(default="./temp", description="Temporary storage directory")
    backup_directory: str = Field(default="./backups", description="Backup storage directory")
    max_storage_size_gb: int = Field(default=100, description="Maximum storage size in GB")
    # API settings
    api_prefix: str = Field(default="/api/v1", description="API prefix")
    docs_url: str = Field(default="/docs", description="API documentation URL")
    redoc_url: str = Field(default="/redoc", description="ReDoc documentation URL")
    openapi_url: str = Field(default="/openapi.json", description="OpenAPI schema URL")
    # Feature flags
    enable_authentication: bool = Field(default=True, description="Enable authentication")
    enable_rate_limiting: bool = Field(default=True, description="Enable rate limiting")
    enable_websockets: bool = Field(default=True, description="Enable WebSocket support")
    enable_historical_data: bool = Field(default=True, description="Enable historical data storage")
    enable_real_time_processing: bool = Field(default=True, description="Enable real-time processing")
    cors_enabled: bool = Field(default=True, description="Enable CORS middleware")
    cors_allow_credentials: bool = Field(default=True, description="Allow credentials in CORS")
    # Development settings
    mock_hardware: bool = Field(default=False, description="Use mock hardware for development")
    mock_pose_data: bool = Field(default=False, description="Use mock pose data for development")
    enable_test_endpoints: bool = Field(default=False, description="Enable test endpoints")
    # Cleanup settings
    csi_data_retention_days: int = Field(default=30, description="CSI data retention in days")
    pose_detection_retention_days: int = Field(default=30, description="Pose detection retention in days")
    metrics_retention_days: int = Field(default=7, description="Metrics retention in days")
    audit_log_retention_days: int = Field(default=90, description="Audit log retention in days")
    orphaned_session_threshold_days: int = Field(default=7, description="Orphaned session threshold in days")
    cleanup_batch_size: int = Field(default=1000, description="Cleanup batch size")

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False
    )

    @field_validator("environment")
    @classmethod
    def validate_environment(cls, v):
        """Validate environment setting.

        "testing" is included because is_testing and get_test_settings()
        rely on it being a legal value.
        """
        allowed_environments = ["development", "staging", "production", "testing"]
        if v not in allowed_environments:
            raise ValueError(f"Environment must be one of: {allowed_environments}")
        return v

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v):
        """Validate log level setting; normalizes to upper case."""
        allowed_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in allowed_levels:
            raise ValueError(f"Log level must be one of: {allowed_levels}")
        return v.upper()

    @field_validator("pose_confidence_threshold")
    @classmethod
    def validate_confidence_threshold(cls, v):
        """Validate confidence threshold (inclusive range 0.0-1.0)."""
        if not 0.0 <= v <= 1.0:
            raise ValueError("Confidence threshold must be between 0.0 and 1.0")
        return v

    @field_validator("stream_fps")
    @classmethod
    def validate_stream_fps(cls, v):
        """Validate streaming FPS (inclusive range 1-60)."""
        if not 1 <= v <= 60:
            raise ValueError("Stream FPS must be between 1 and 60")
        return v

    @field_validator("port")
    @classmethod
    def validate_port(cls, v):
        """Validate server port number (1-65535)."""
        if not 1 <= v <= 65535:
            raise ValueError("Port must be between 1 and 65535")
        return v

    @field_validator("workers")
    @classmethod
    def validate_workers(cls, v):
        """Validate worker count (at least one process)."""
        if v < 1:
            raise ValueError("Workers must be at least 1")
        return v

    @field_validator("db_port")
    @classmethod
    def validate_db_port(cls, v):
        """Validate database port (1-65535)."""
        if not 1 <= v <= 65535:
            raise ValueError("Database port must be between 1 and 65535")
        return v

    @field_validator("redis_port")
    @classmethod
    def validate_redis_port(cls, v):
        """Validate Redis port (1-65535)."""
        if not 1 <= v <= 65535:
            raise ValueError("Redis port must be between 1 and 65535")
        return v

    @field_validator("db_pool_size")
    @classmethod
    def validate_db_pool_size(cls, v):
        """Validate database pool size (at least one connection)."""
        if v < 1:
            raise ValueError("Database pool size must be at least 1")
        return v

    @field_validator("monitoring_interval_seconds", "cleanup_interval_seconds", "backup_interval_seconds")
    @classmethod
    def validate_interval_seconds(cls, v):
        """Validate background-task interval settings (non-negative)."""
        if v < 0:
            raise ValueError("Interval seconds must be non-negative")
        return v

    @property
    def is_development(self) -> bool:
        """Check if running in development environment."""
        return self.environment == "development"

    @property
    def is_production(self) -> bool:
        """Check if running in production environment."""
        return self.environment == "production"

    @property
    def is_testing(self) -> bool:
        """Check if running in testing environment."""
        return self.environment == "testing"

    def get_database_url(self) -> str:
        """Get database URL with fallback.

        Resolution order: explicit database_url; URL built from db_host/
        db_name/db_user (PostgreSQL); SQLite under data_storage_path in
        development; the SQLite failsafe path when enabled.

        Raises:
            ValueError: If no URL can be resolved outside development
                with the failsafe disabled.
        """
        if self.database_url:
            return self.database_url
        # Build URL from individual components if available
        if self.db_host and self.db_name and self.db_user:
            password_part = f":{self.db_password}" if self.db_password else ""
            return f"postgresql://{self.db_user}{password_part}@{self.db_host}:{self.db_port}/{self.db_name}"
        # Default SQLite database for development
        if self.is_development:
            return f"sqlite:///{self.data_storage_path}/wifi_densepose.db"
        # SQLite failsafe for production if enabled
        if self.enable_database_failsafe:
            return f"sqlite:///{self.sqlite_fallback_path}"
        raise ValueError("Database URL must be configured for non-development environments")

    def get_sqlite_fallback_url(self) -> str:
        """Get SQLite fallback database URL."""
        return f"sqlite:///{self.sqlite_fallback_path}"

    def get_redis_url(self) -> Optional[str]:
        """Get Redis URL with fallback; None when Redis is disabled."""
        if not self.redis_enabled:
            return None
        if self.redis_url:
            return self.redis_url
        # Build URL from individual components
        password_part = f":{self.redis_password}@" if self.redis_password else ""
        return f"redis://{password_part}{self.redis_host}:{self.redis_port}/{self.redis_db}"

    def get_cors_config(self) -> Dict[str, Any]:
        """Get CORS configuration.

        Development is wide open; every other environment is restricted to
        the configured origins and a fixed method/header allowlist.
        """
        if self.is_development:
            return {
                "allow_origins": ["*"],
                "allow_credentials": True,
                "allow_methods": ["*"],
                "allow_headers": ["*"],
            }
        return {
            "allow_origins": self.cors_origins,
            "allow_credentials": True,
            "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
            "allow_headers": ["Authorization", "Content-Type"],
        }

    def get_logging_config(self) -> Dict[str, Any]:
        """Get a logging.config.dictConfig-compatible configuration.

        Always logs to stdout; when log_file is set, a rotating file
        handler (detailed format) is added to every configured logger.
        """
        config = {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": self.log_format,
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(module)s:%(lineno)d - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "level": self.log_level,
                    "formatter": "default",
                    "stream": "ext://sys.stdout",
                },
            },
            "loggers": {
                "": {
                    "level": self.log_level,
                    "handlers": ["console"],
                },
                "uvicorn": {
                    "level": "INFO",
                    "handlers": ["console"],
                    "propagate": False,
                },
                "fastapi": {
                    "level": "INFO",
                    "handlers": ["console"],
                    "propagate": False,
                },
            },
        }
        # Add file handler if log file is specified
        if self.log_file:
            config["handlers"]["file"] = {
                "class": "logging.handlers.RotatingFileHandler",
                "level": self.log_level,
                "formatter": "detailed",
                "filename": self.log_file,
                "maxBytes": self.log_max_size,
                "backupCount": self.log_backup_count,
            }
            # Add file handler to all loggers
            for logger_config in config["loggers"].values():
                logger_config["handlers"].append("file")
        return config

    def create_directories(self):
        """Create all storage/log/backup directories (idempotent)."""
        directories = [
            self.data_storage_path,
            self.model_storage_path,
            self.temp_storage_path,
            self.log_directory,
            self.backup_directory,
        ]
        for directory in directories:
            os.makedirs(directory, exist_ok=True)
@lru_cache()
def get_settings() -> Settings:
    """Return the cached application settings instance.

    The first call constructs Settings (reading env vars / .env) and
    creates the required storage directories; later calls return the
    same object via lru_cache.
    """
    app_settings = Settings()
    app_settings.create_directories()
    return app_settings
def get_test_settings() -> Settings:
    """Get settings for testing (mocked hardware, in-memory DB, debug on).

    NOTE(review): this constructs Settings(environment="testing"), but
    Settings.validate_environment only accepts "development", "staging"
    and "production" — as written this call raises a validation error.
    "testing" needs to be an allowed environment value; confirm and fix
    the validator.
    """
    return Settings(
        environment="testing",
        debug=True,
        secret_key="test-secret-key",
        database_url="sqlite:///:memory:",
        mock_hardware=True,
        mock_pose_data=True,
        enable_test_endpoints=True,
        log_level="DEBUG"
    )
def load_settings_from_file(file_path: str) -> Settings:
    """Load settings from a specific file.

    ``_env_file`` is the pydantic-settings constructor argument that
    points the instance at an alternative ``.env`` file instead of the
    default one declared in ``model_config``.
    """
    return Settings(_env_file=file_path)
def validate_settings(settings: Settings) -> List[str]:
    """Validate settings and return a list of human-readable issues.

    Production deployments get extra hardening checks; all environments
    verify that the storage directories can be created.
    """
    issues: List[str] = []

    if settings.is_production:
        # (condition that indicates a problem, message to report)
        production_checks = [
            (
                not settings.secret_key or settings.secret_key == "change-me",
                "Secret key must be set for production",
            ),
            (
                not settings.database_url
                and not (settings.db_host and settings.db_name and settings.db_user),
                "Database URL or database connection parameters must be set for production",
            ),
            (
                settings.debug,
                "Debug mode should be disabled in production",
            ),
            (
                "*" in settings.allowed_hosts,
                "Allowed hosts should be restricted in production",
            ),
            (
                "*" in settings.cors_origins,
                "CORS origins should be restricted in production",
            ),
        ]
        issues.extend(message for failed, message in production_checks if failed)

    # Storage paths must be creatable in every environment.
    try:
        settings.create_directories()
    except Exception as e:
        issues.append(f"Cannot create storage directories: {e}")

    return issues