feat: Complete Rust port of WiFi-DensePose with modular crates

Major changes:
- Organized Python v1 implementation into v1/ subdirectory
- Created Rust workspace with 9 modular crates:
  - wifi-densepose-core: Core types, traits, errors
  - wifi-densepose-signal: CSI processing, phase sanitization, FFT
  - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch)
  - wifi-densepose-api: Axum-based REST/WebSocket API
  - wifi-densepose-db: SQLx database layer
  - wifi-densepose-config: Configuration management
  - wifi-densepose-hardware: Hardware abstraction
  - wifi-densepose-wasm: WebAssembly bindings
  - wifi-densepose-cli: Command-line interface

Documentation:
- ADR-001: Workspace structure
- ADR-002: Signal processing library selection
- ADR-003: Neural network inference strategy
- DDD domain model with bounded contexts

Testing:
- 69 tests passing across all crates
- Signal processing: 45 tests
- Neural networks: 21 tests
- Core: 3 doc tests

Performance targets:
- 10x faster CSI processing (~0.5ms vs ~5ms)
- 5x lower memory usage (~100MB vs ~500MB)
- WASM support for browser deployment
This commit is contained in:
19
v1/src/services/__init__.py
Normal file
19
v1/src/services/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Services package for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
from .orchestrator import ServiceOrchestrator
|
||||
from .health_check import HealthCheckService
|
||||
from .metrics import MetricsService
|
||||
from .pose_service import PoseService
|
||||
from .stream_service import StreamService
|
||||
from .hardware_service import HardwareService
|
||||
|
||||
__all__ = [
|
||||
'ServiceOrchestrator',
|
||||
'HealthCheckService',
|
||||
'MetricsService',
|
||||
'PoseService',
|
||||
'StreamService',
|
||||
'HardwareService'
|
||||
]
|
||||
482
v1/src/services/hardware_service.py
Normal file
482
v1/src/services/hardware_service.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""
|
||||
Hardware interface service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.router_interface import RouterInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HardwareService:
    """Service for hardware interface operations.

    Manages a pool of :class:`RouterInterface` connections, runs background
    asyncio tasks that poll the routers for CSI data, keeps a bounded
    in-memory buffer of recent samples, and exposes status / metrics /
    health views over what it has collected.
    """

    def __init__(self, settings: Settings, domain_config: DomainConfig):
        """Initialize hardware service.

        Args:
            settings: Application settings (mock-hardware flag, polling
                interval, buffer sizes, ...).
            domain_config: Domain configuration providing the router list.
        """
        self.settings = settings
        self.domain_config = domain_config
        self.logger = logging.getLogger(__name__)

        # Router interfaces, keyed by router_id.
        self.router_interfaces: Dict[str, RouterInterface] = {}

        # Service state
        self.is_running = False
        self.last_error = None  # last error message (str) or None

        # Data collection statistics (exposed via get_status/get_metrics).
        self.stats = {
            "total_samples": 0,
            "successful_samples": 0,
            "failed_samples": 0,
            "average_sample_rate": 0.0,
            "last_sample_time": None,
            "connected_routers": 0
        }

        # Background tasks (created in start(), cancelled in stop()).
        self.collection_task = None
        self.monitoring_task = None

        # Bounded buffer of recent CSI samples (oldest first).
        self.recent_samples = []
        self.max_recent_samples = 1000

    async def initialize(self):
        """Initialize the hardware service (alias for :meth:`start`)."""
        await self.start()

    async def start(self):
        """Start the hardware service.

        Connects to all enabled routers and launches the background
        collection/monitoring tasks. Idempotent: a second call while
        running is a no-op.

        Raises:
            Exception: re-raised from router initialization; the message is
                also recorded in ``self.last_error``.
        """
        if self.is_running:
            return

        try:
            self.logger.info("Starting hardware service...")

            # Initialize router interfaces
            await self._initialize_routers()

            self.is_running = True

            # Start background tasks. In mock mode only monitoring runs;
            # real data collection is skipped.
            if not self.settings.mock_hardware:
                self.collection_task = asyncio.create_task(self._data_collection_loop())

            self.monitoring_task = asyncio.create_task(self._monitoring_loop())

            self.logger.info("Hardware service started successfully")

        except Exception as e:
            self.last_error = str(e)
            self.logger.error(f"Failed to start hardware service: {e}")
            raise

    async def stop(self):
        """Stop the hardware service.

        Cancels background tasks, awaits their termination, and disconnects
        from all routers.
        """
        self.is_running = False

        # Cancel background tasks and swallow the resulting CancelledError.
        if self.collection_task:
            self.collection_task.cancel()
            try:
                await self.collection_task
            except asyncio.CancelledError:
                pass

        if self.monitoring_task:
            self.monitoring_task.cancel()
            try:
                await self.monitoring_task
            except asyncio.CancelledError:
                pass

        # Disconnect from routers
        await self._disconnect_routers()

        self.logger.info("Hardware service stopped")

    async def _initialize_routers(self):
        """Initialize router interfaces for every enabled router.

        Raises:
            Exception: propagated if any router cannot be created/connected.
        """
        try:
            # Get router configurations from domain config
            routers = self.domain_config.get_all_routers()

            for router_config in routers:
                if not router_config.enabled:
                    continue

                router_id = router_config.router_id

                # Create router interface.
                # SECURITY NOTE(review): SSH port/username/password are
                # hard-coded defaults here — they should come from
                # router_config or a secrets store; confirm before shipping.
                router_interface = RouterInterface(
                    router_id=router_id,
                    host=router_config.ip_address,
                    port=22,  # Default SSH port
                    username="admin",  # Default username
                    password="admin",  # Default password
                    interface=router_config.interface,
                    mock_mode=self.settings.mock_hardware
                )

                # Connect to router (always connect, even in mock mode)
                await router_interface.connect()

                self.router_interfaces[router_id] = router_interface
                self.logger.info(f"Router interface initialized: {router_id}")

            self.stats["connected_routers"] = len(self.router_interfaces)

            if not self.router_interfaces:
                self.logger.warning("No router interfaces configured")

        except Exception as e:
            self.logger.error(f"Failed to initialize routers: {e}")
            raise

    async def _disconnect_routers(self):
        """Disconnect from all routers (best-effort; errors are logged)."""
        for router_id, interface in self.router_interfaces.items():
            try:
                await interface.disconnect()
                self.logger.info(f"Disconnected from router: {router_id}")
            except Exception as e:
                self.logger.error(f"Error disconnecting from router {router_id}: {e}")

        self.router_interfaces.clear()
        self.stats["connected_routers"] = 0

    async def _data_collection_loop(self):
        """Background loop for data collection.

        Polls all routers, then sleeps the remainder of the configured
        polling interval. NOTE(review): an unexpected exception ends the
        loop permanently (only logged + recorded in last_error) — confirm
        that is the intended failure mode.
        """
        try:
            while self.is_running:
                start_time = time.time()

                # Collect data from all routers
                await self._collect_data_from_routers()

                # Calculate sleep time to maintain polling interval
                elapsed = time.time() - start_time
                sleep_time = max(0, self.settings.hardware_polling_interval - elapsed)

                if sleep_time > 0:
                    await asyncio.sleep(sleep_time)

        except asyncio.CancelledError:
            self.logger.info("Data collection loop cancelled")
        except Exception as e:
            self.logger.error(f"Error in data collection loop: {e}")
            self.last_error = str(e)

    async def _monitoring_loop(self):
        """Background loop for hardware monitoring (health + stats)."""
        try:
            while self.is_running:
                # Monitor router connections
                await self._monitor_router_health()

                # Update statistics
                self._update_sample_rate_stats()

                # Wait before next check
                await asyncio.sleep(30)  # Check every 30 seconds

        except asyncio.CancelledError:
            self.logger.info("Monitoring loop cancelled")
        except Exception as e:
            self.logger.error(f"Error in monitoring loop: {e}")

    async def _collect_data_from_routers(self):
        """Collect CSI data from all connected routers.

        Per-router failures are counted and logged but never abort the
        sweep over the remaining routers.
        """
        for router_id, interface in self.router_interfaces.items():
            try:
                # Get CSI data from router
                csi_data = await interface.get_csi_data()

                if csi_data is not None:
                    # Process the collected data
                    await self._process_collected_data(router_id, csi_data)

                    self.stats["successful_samples"] += 1
                    # NOTE(review): naive local time; consider UTC-aware
                    # timestamps if samples are compared across hosts.
                    self.stats["last_sample_time"] = datetime.now().isoformat()
                else:
                    self.stats["failed_samples"] += 1

                self.stats["total_samples"] += 1

            except Exception as e:
                self.logger.error(f"Error collecting data from router {router_id}: {e}")
                self.stats["failed_samples"] += 1
                self.stats["total_samples"] += 1

    async def _process_collected_data(self, router_id: str, csi_data: np.ndarray):
        """Process collected CSI data.

        Wraps the raw array with metadata and appends it to the bounded
        ``recent_samples`` buffer.

        Args:
            router_id: Source router identifier.
            csi_data: Raw CSI sample (expected to be an ndarray, but any
                object without a ``shape`` attribute is tolerated).
        """
        try:
            # Create sample metadata
            metadata = {
                "router_id": router_id,
                "timestamp": datetime.now().isoformat(),
                "sample_rate": self.stats["average_sample_rate"],
                "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None
            }

            # Add to recent samples buffer
            sample = {
                "router_id": router_id,
                "timestamp": metadata["timestamp"],
                "data": csi_data,
                "metadata": metadata
            }

            self.recent_samples.append(sample)

            # Maintain buffer size (drop oldest).
            if len(self.recent_samples) > self.max_recent_samples:
                self.recent_samples.pop(0)

            # Notify other services (this would typically be done through an event system)
            # For now, we'll just log the data collection
            self.logger.debug(f"Collected CSI data from {router_id}: shape {csi_data.shape if hasattr(csi_data, 'shape') else 'unknown'}")

        except Exception as e:
            self.logger.error(f"Error processing collected data: {e}")

    async def _monitor_router_health(self):
        """Monitor health of router connections, reconnecting as needed."""
        healthy_routers = 0

        for router_id, interface in self.router_interfaces.items():
            try:
                is_healthy = await interface.check_health()

                if is_healthy:
                    healthy_routers += 1
                else:
                    self.logger.warning(f"Router {router_id} is unhealthy")

                    # Try to reconnect if not in mock mode
                    if not self.settings.mock_hardware:
                        try:
                            await interface.reconnect()
                            self.logger.info(f"Reconnected to router {router_id}")
                        except Exception as e:
                            self.logger.error(f"Failed to reconnect to router {router_id}: {e}")

            except Exception as e:
                self.logger.error(f"Error checking health of router {router_id}: {e}")

        self.stats["connected_routers"] = healthy_routers

    def _update_sample_rate_stats(self):
        """Update the average sample rate from recent sample timestamps.

        Uses up to the last 100 samples; inverse of the mean inter-sample
        interval. NOTE(review): samples from all routers are mixed into one
        rate — confirm a per-router rate isn't needed.
        """
        if len(self.recent_samples) < 2:
            return

        # Calculate sample rate from recent samples
        recent_count = min(100, len(self.recent_samples))
        recent_samples = self.recent_samples[-recent_count:]

        if len(recent_samples) >= 2:
            # Calculate time differences between consecutive samples.
            time_diffs = []
            for i in range(1, len(recent_samples)):
                try:
                    t1 = datetime.fromisoformat(recent_samples[i-1]["timestamp"])
                    t2 = datetime.fromisoformat(recent_samples[i]["timestamp"])
                    diff = (t2 - t1).total_seconds()
                    if diff > 0:
                        time_diffs.append(diff)
                except Exception:
                    continue  # skip malformed timestamps

            if time_diffs:
                avg_interval = sum(time_diffs) / len(time_diffs)
                self.stats["average_sample_rate"] = 1.0 / avg_interval if avg_interval > 0 else 0.0

    async def get_router_status(self, router_id: str) -> Dict[str, Any]:
        """Get status of a specific router.

        Args:
            router_id: Router identifier.

        Returns:
            Status dict; on failure a degraded dict containing the error.

        Raises:
            ValueError: if the router is unknown.
        """
        if router_id not in self.router_interfaces:
            raise ValueError(f"Router {router_id} not found")

        interface = self.router_interfaces[router_id]

        try:
            is_healthy = await interface.check_health()
            status = await interface.get_status()

            return {
                "router_id": router_id,
                "healthy": is_healthy,
                "connected": status.get("connected", False),
                "last_data_time": status.get("last_data_time"),
                "error_count": status.get("error_count", 0),
                "configuration": status.get("configuration", {})
            }

        except Exception as e:
            return {
                "router_id": router_id,
                "healthy": False,
                "connected": False,
                "error": str(e)
            }

    async def get_all_router_status(self) -> List[Dict[str, Any]]:
        """Get status of all routers (per-router failures are embedded)."""
        statuses = []

        for router_id in self.router_interfaces:
            try:
                status = await self.get_router_status(router_id)
                statuses.append(status)
            except Exception as e:
                statuses.append({
                    "router_id": router_id,
                    "healthy": False,
                    "error": str(e)
                })

        return statuses

    async def get_recent_data(self, router_id: Optional[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
        """Get recent CSI data samples.

        Args:
            router_id: If given, only samples from this router are returned.
            limit: Maximum number of (matching) samples to return; 0/falsy
                means no limit.

        Returns:
            JSON-serializable sample dicts (ndarray data converted to lists),
            oldest first.
        """
        # BUGFIX: filter by router *before* applying the limit. Previously
        # the last `limit` samples overall were taken first and then
        # filtered, so a router-specific query could return far fewer than
        # `limit` matches even when more were buffered.
        samples = self.recent_samples
        if router_id:
            samples = [s for s in samples if s["router_id"] == router_id]
        if limit:
            samples = samples[-limit:]

        # Convert numpy arrays to lists for JSON serialization
        result = []
        for sample in samples:
            sample_copy = sample.copy()
            if isinstance(sample_copy["data"], np.ndarray):
                sample_copy["data"] = sample_copy["data"].tolist()
            result.append(sample_copy)

        return result

    async def get_status(self) -> Dict[str, Any]:
        """Get service status including configuration and per-router status."""
        return {
            "status": "healthy" if self.is_running and not self.last_error else "unhealthy",
            "running": self.is_running,
            "last_error": self.last_error,
            "statistics": self.stats.copy(),
            "configuration": {
                "mock_hardware": self.settings.mock_hardware,
                "wifi_interface": self.settings.wifi_interface,
                "polling_interval": self.settings.hardware_polling_interval,
                "buffer_size": self.settings.csi_buffer_size
            },
            "routers": await self.get_all_router_status()
        }

    async def get_metrics(self) -> Dict[str, Any]:
        """Get service metrics (sample counts, success rate, sample rate)."""
        total_samples = self.stats["total_samples"]
        # max(1, ...) guards the division when no samples were collected yet.
        success_rate = self.stats["successful_samples"] / max(1, total_samples)

        return {
            "hardware_service": {
                "total_samples": total_samples,
                "successful_samples": self.stats["successful_samples"],
                "failed_samples": self.stats["failed_samples"],
                "success_rate": success_rate,
                "average_sample_rate": self.stats["average_sample_rate"],
                "connected_routers": self.stats["connected_routers"],
                "last_sample_time": self.stats["last_sample_time"]
            }
        }

    async def reset(self):
        """Reset statistics, the sample buffer, and the last error."""
        self.stats = {
            "total_samples": 0,
            "successful_samples": 0,
            "failed_samples": 0,
            "average_sample_rate": 0.0,
            "last_sample_time": None,
            "connected_routers": len(self.router_interfaces)
        }

        self.recent_samples.clear()
        self.last_error = None

        self.logger.info("Hardware service reset")

    async def trigger_manual_collection(self, router_id: Optional[str] = None) -> Dict[str, Any]:
        """Manually trigger data collection.

        Args:
            router_id: Collect from this router only; when None, a full
                sweep over all routers is triggered.

        Returns:
            Per-router result dict, or a message for the all-routers case.

        Raises:
            RuntimeError: if the service is not running.
            ValueError: if the given router is unknown.
        """
        if not self.is_running:
            raise RuntimeError("Hardware service is not running")

        results = {}

        if router_id:
            # Collect from specific router
            if router_id not in self.router_interfaces:
                raise ValueError(f"Router {router_id} not found")

            interface = self.router_interfaces[router_id]
            try:
                csi_data = await interface.get_csi_data()
                if csi_data is not None:
                    await self._process_collected_data(router_id, csi_data)
                    results[router_id] = {"success": True, "data_shape": csi_data.shape if hasattr(csi_data, 'shape') else None}
                else:
                    results[router_id] = {"success": False, "error": "No data received"}
            except Exception as e:
                results[router_id] = {"success": False, "error": str(e)}
        else:
            # Collect from all routers
            await self._collect_data_from_routers()
            results = {"message": "Manual collection triggered for all routers"}

        return results

    async def health_check(self) -> Dict[str, Any]:
        """Perform health check.

        Returns:
            Dict with overall status, router connectivity summary, and key
            metrics; never raises (failures yield an unhealthy payload).
        """
        try:
            status = "healthy" if self.is_running and not self.last_error else "unhealthy"

            # Check router health
            healthy_routers = 0
            total_routers = len(self.router_interfaces)

            for router_id, interface in self.router_interfaces.items():
                try:
                    if await interface.check_health():
                        healthy_routers += 1
                except Exception:
                    pass  # unreachable router counts as unhealthy

            return {
                "status": status,
                "message": self.last_error if self.last_error else "Hardware service is running normally",
                "connected_routers": f"{healthy_routers}/{total_routers}",
                "metrics": {
                    "total_samples": self.stats["total_samples"],
                    "success_rate": (
                        self.stats["successful_samples"] / max(1, self.stats["total_samples"])
                    ),
                    "average_sample_rate": self.stats["average_sample_rate"]
                }
            }

        except Exception as e:
            return {
                "status": "unhealthy",
                "message": f"Health check failed: {str(e)}"
            }

    async def is_ready(self) -> bool:
        """Check if service is ready (running with at least one router)."""
        return self.is_running and len(self.router_interfaces) > 0
|
||||
465
v1/src/services/health_check.py
Normal file
465
v1/src/services/health_check.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Health check service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthStatus(Enum):
    """Health status enumeration.

    Values are lowercase strings so they serialize directly into API
    responses via ``.value``.
    """
    HEALTHY = "healthy"      # check passed, service fully operational
    DEGRADED = "degraded"    # responding, but reporting a non-healthy status
    UNHEALTHY = "unhealthy"  # check failed or raised
    UNKNOWN = "unknown"      # no check performed yet / not configured
|
||||
|
||||
|
||||
@dataclass
class HealthCheck:
    """Health check result.

    Immutable-in-spirit record produced by each ``_check_*`` method of
    :class:`HealthCheckService`.
    """
    # Name of the checked service (e.g. "api", "database").
    name: str
    # Outcome of the check.
    status: HealthStatus
    # Human-readable summary (includes the error text on failure).
    message: str
    # When the check completed. NOTE(review): datetime.utcnow() is naive
    # (no tzinfo) and deprecated since Python 3.12 — confirm whether an
    # aware datetime.now(timezone.utc) is safe for all consumers.
    timestamp: datetime = field(default_factory=datetime.utcnow)
    # Wall-clock duration of the check in milliseconds.
    duration_ms: float = 0.0
    # Check-specific extras (connection info, error details, ...).
    details: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ServiceHealth:
    """Service health information.

    Rolling per-service record maintained by :class:`HealthCheckService`;
    aggregates the most recent :class:`HealthCheck` results.
    """
    # Service name (key in the health service's registry).
    name: str
    # Most recently observed status.
    status: HealthStatus
    # Timestamp of the last completed check, or None before the first one.
    last_check: Optional[datetime] = None
    # Recent check history (the owning service trims this to 10 entries).
    checks: List[HealthCheck] = field(default_factory=list)
    # Seconds since the health service started.
    uptime: float = 0.0
    # Count of UNHEALTHY results observed so far.
    error_count: int = 0
    # Message of the most recent UNHEALTHY result, if any.
    last_error: Optional[str] = None
|
||||
|
||||
|
||||
class HealthCheckService:
|
||||
"""Service for monitoring application health."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self._services: Dict[str, ServiceHealth] = {}
|
||||
self._start_time = time.time()
|
||||
self._initialized = False
|
||||
self._running = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize health check service."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing health check service")
|
||||
|
||||
# Initialize service health tracking
|
||||
self._services = {
|
||||
"api": ServiceHealth("api", HealthStatus.UNKNOWN),
|
||||
"database": ServiceHealth("database", HealthStatus.UNKNOWN),
|
||||
"redis": ServiceHealth("redis", HealthStatus.UNKNOWN),
|
||||
"hardware": ServiceHealth("hardware", HealthStatus.UNKNOWN),
|
||||
"pose": ServiceHealth("pose", HealthStatus.UNKNOWN),
|
||||
"stream": ServiceHealth("stream", HealthStatus.UNKNOWN),
|
||||
}
|
||||
|
||||
self._initialized = True
|
||||
logger.info("Health check service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start health check service."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
self._running = True
|
||||
logger.info("Health check service started")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown health check service."""
|
||||
self._running = False
|
||||
logger.info("Health check service shut down")
|
||||
|
||||
async def perform_health_checks(self) -> Dict[str, HealthCheck]:
|
||||
"""Perform all health checks."""
|
||||
if not self._running:
|
||||
return {}
|
||||
|
||||
logger.debug("Performing health checks")
|
||||
results = {}
|
||||
|
||||
# Perform individual health checks
|
||||
checks = [
|
||||
self._check_api_health(),
|
||||
self._check_database_health(),
|
||||
self._check_redis_health(),
|
||||
self._check_hardware_health(),
|
||||
self._check_pose_health(),
|
||||
self._check_stream_health(),
|
||||
]
|
||||
|
||||
# Run checks concurrently
|
||||
check_results = await asyncio.gather(*checks, return_exceptions=True)
|
||||
|
||||
# Process results
|
||||
for i, result in enumerate(check_results):
|
||||
check_name = ["api", "database", "redis", "hardware", "pose", "stream"][i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
health_check = HealthCheck(
|
||||
name=check_name,
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
message=f"Health check failed: {result}"
|
||||
)
|
||||
else:
|
||||
health_check = result
|
||||
|
||||
results[check_name] = health_check
|
||||
self._update_service_health(check_name, health_check)
|
||||
|
||||
logger.debug(f"Completed {len(results)} health checks")
|
||||
return results
|
||||
|
||||
async def _check_api_health(self) -> HealthCheck:
|
||||
"""Check API health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Basic API health check
|
||||
uptime = time.time() - self._start_time
|
||||
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "API is running normally"
|
||||
details = {
|
||||
"uptime_seconds": uptime,
|
||||
"uptime_formatted": str(timedelta(seconds=int(uptime)))
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"API health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="api",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_database_health(self) -> HealthCheck:
|
||||
"""Check database health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.database.connection import get_database_manager
|
||||
|
||||
db_manager = get_database_manager()
|
||||
|
||||
if not db_manager.is_connected():
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = "Database is not connected"
|
||||
details = {"connected": False}
|
||||
else:
|
||||
# Test database connection
|
||||
await db_manager.test_connection()
|
||||
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Database is connected and responsive"
|
||||
details = {
|
||||
"connected": True,
|
||||
"pool_size": db_manager.get_pool_size(),
|
||||
"active_connections": db_manager.get_active_connections()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Database health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="database",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_redis_health(self) -> HealthCheck:
|
||||
"""Check Redis health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
redis_config = self.settings.get_redis_url()
|
||||
|
||||
if not redis_config:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Redis is not configured"
|
||||
details = {"configured": False}
|
||||
else:
|
||||
# Test Redis connection
|
||||
import redis.asyncio as redis
|
||||
|
||||
redis_client = redis.from_url(redis_config)
|
||||
await redis_client.ping()
|
||||
await redis_client.close()
|
||||
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Redis is connected and responsive"
|
||||
details = {"connected": True}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Redis health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="redis",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_hardware_health(self) -> HealthCheck:
|
||||
"""Check hardware service health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.dependencies import get_hardware_service
|
||||
|
||||
hardware_service = get_hardware_service()
|
||||
|
||||
if hasattr(hardware_service, 'get_status'):
|
||||
status_info = await hardware_service.get_status()
|
||||
|
||||
if status_info.get("status") == "healthy":
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Hardware service is operational"
|
||||
else:
|
||||
status = HealthStatus.DEGRADED
|
||||
message = f"Hardware service status: {status_info.get('status', 'unknown')}"
|
||||
|
||||
details = status_info
|
||||
else:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Hardware service status unavailable"
|
||||
details = {}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Hardware health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="hardware",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_pose_health(self) -> HealthCheck:
|
||||
"""Check pose service health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.dependencies import get_pose_service
|
||||
|
||||
pose_service = get_pose_service()
|
||||
|
||||
if hasattr(pose_service, 'get_status'):
|
||||
status_info = await pose_service.get_status()
|
||||
|
||||
if status_info.get("status") == "healthy":
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Pose service is operational"
|
||||
else:
|
||||
status = HealthStatus.DEGRADED
|
||||
message = f"Pose service status: {status_info.get('status', 'unknown')}"
|
||||
|
||||
details = status_info
|
||||
else:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Pose service status unavailable"
|
||||
details = {}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Pose health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="pose",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
async def _check_stream_health(self) -> HealthCheck:
|
||||
"""Check stream service health."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from src.api.dependencies import get_stream_service
|
||||
|
||||
stream_service = get_stream_service()
|
||||
|
||||
if hasattr(stream_service, 'get_status'):
|
||||
status_info = await stream_service.get_status()
|
||||
|
||||
if status_info.get("status") == "healthy":
|
||||
status = HealthStatus.HEALTHY
|
||||
message = "Stream service is operational"
|
||||
else:
|
||||
status = HealthStatus.DEGRADED
|
||||
message = f"Stream service status: {status_info.get('status', 'unknown')}"
|
||||
|
||||
details = status_info
|
||||
else:
|
||||
status = HealthStatus.UNKNOWN
|
||||
message = "Stream service status unavailable"
|
||||
details = {}
|
||||
|
||||
except Exception as e:
|
||||
status = HealthStatus.UNHEALTHY
|
||||
message = f"Stream health check failed: {e}"
|
||||
details = {"error": str(e)}
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return HealthCheck(
|
||||
name="stream",
|
||||
status=status,
|
||||
message=message,
|
||||
duration_ms=duration_ms,
|
||||
details=details
|
||||
)
|
||||
|
||||
def _update_service_health(self, service_name: str, health_check: HealthCheck):
|
||||
"""Update service health information."""
|
||||
if service_name not in self._services:
|
||||
self._services[service_name] = ServiceHealth(service_name, HealthStatus.UNKNOWN)
|
||||
|
||||
service_health = self._services[service_name]
|
||||
service_health.status = health_check.status
|
||||
service_health.last_check = health_check.timestamp
|
||||
service_health.uptime = time.time() - self._start_time
|
||||
|
||||
# Keep last 10 checks
|
||||
service_health.checks.append(health_check)
|
||||
if len(service_health.checks) > 10:
|
||||
service_health.checks.pop(0)
|
||||
|
||||
# Update error tracking
|
||||
if health_check.status == HealthStatus.UNHEALTHY:
|
||||
service_health.error_count += 1
|
||||
service_health.last_error = health_check.message
|
||||
|
||||
async def get_overall_health(self) -> Dict[str, Any]:
|
||||
"""Get overall system health."""
|
||||
if not self._services:
|
||||
return {
|
||||
"status": HealthStatus.UNKNOWN.value,
|
||||
"message": "Health checks not initialized"
|
||||
}
|
||||
|
||||
# Determine overall status
|
||||
statuses = [service.status for service in self._services.values()]
|
||||
|
||||
if all(status == HealthStatus.HEALTHY for status in statuses):
|
||||
overall_status = HealthStatus.HEALTHY
|
||||
message = "All services are healthy"
|
||||
elif any(status == HealthStatus.UNHEALTHY for status in statuses):
|
||||
overall_status = HealthStatus.UNHEALTHY
|
||||
unhealthy_services = [
|
||||
name for name, service in self._services.items()
|
||||
if service.status == HealthStatus.UNHEALTHY
|
||||
]
|
||||
message = f"Unhealthy services: {', '.join(unhealthy_services)}"
|
||||
elif any(status == HealthStatus.DEGRADED for status in statuses):
|
||||
overall_status = HealthStatus.DEGRADED
|
||||
degraded_services = [
|
||||
name for name, service in self._services.items()
|
||||
if service.status == HealthStatus.DEGRADED
|
||||
]
|
||||
message = f"Degraded services: {', '.join(degraded_services)}"
|
||||
else:
|
||||
overall_status = HealthStatus.UNKNOWN
|
||||
message = "System health status unknown"
|
||||
|
||||
return {
|
||||
"status": overall_status.value,
|
||||
"message": message,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"uptime": time.time() - self._start_time,
|
||||
"services": {
|
||||
name: {
|
||||
"status": service.status.value,
|
||||
"last_check": service.last_check.isoformat() if service.last_check else None,
|
||||
"error_count": service.error_count,
|
||||
"last_error": service.last_error
|
||||
}
|
||||
for name, service in self._services.items()
|
||||
}
|
||||
}
|
||||
|
||||
async def get_service_health(self, service_name: str) -> Optional[Dict[str, Any]]:
    """Return a detailed health snapshot for one monitored service.

    Includes current status, last-check timestamp, error bookkeeping,
    and the five most recent check results. Returns None when the
    service name is not registered.
    """
    entry = self._services.get(service_name)
    if not entry:
        return None

    # Summarize only the tail of the check history to keep payloads small.
    recent = [
        {
            "timestamp": item.timestamp.isoformat(),
            "status": item.status.value,
            "message": item.message,
            "duration_ms": item.duration_ms,
            "details": item.details
        }
        for item in entry.checks[-5:]  # Last 5 checks
    ]

    return {
        "name": entry.name,
        "status": entry.status.value,
        "last_check": entry.last_check.isoformat() if entry.last_check else None,
        "uptime": entry.uptime,
        "error_count": entry.error_count,
        "last_error": entry.last_error,
        "recent_checks": recent
    }
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
    """Report the health-check service's own operational state."""
    currently_running = self._running
    # "healthy" here means the monitoring loop itself is active,
    # not that the monitored services are healthy.
    return {
        "status": "healthy" if currently_running else "stopped",
        "initialized": self._initialized,
        "running": currently_running,
        "services_monitored": len(self._services),
        "uptime": time.time() - self._start_time
    }
|
||||
431
v1/src/services/metrics.py
Normal file
431
v1/src/services/metrics.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Metrics collection service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import psutil
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MetricPoint:
    """Single metric data point: one sampled value with its capture time."""
    # Capture time; producers in this module stamp it with datetime.utcnow()
    # (naive UTC), so comparisons elsewhere assume naive-UTC datetimes.
    timestamp: datetime
    # The sampled numeric value.
    value: float
    # Optional key/value tags attached to this sample (e.g. per-endpoint labels).
    labels: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class MetricSeries:
    """Bounded time series of metric points.

    Holds at most the 1000 most recent samples; older points are evicted
    automatically by the underlying ``deque(maxlen=1000)``.
    """
    # Metric identifier (e.g. "system_cpu_usage").
    name: str
    # Human-readable description of what the metric measures.
    description: str
    # Unit label (e.g. "percent", "seconds", "count").
    unit: str
    # Sample storage, newest at the right end.
    points: deque = field(default_factory=lambda: deque(maxlen=1000))

    def add_point(self, value: float, labels: Optional[Dict[str, str]] = None) -> None:
        """Append a sample stamped with the current (naive UTC) time."""
        point = MetricPoint(
            timestamp=datetime.utcnow(),
            value=value,
            labels=labels or {}
        )
        self.points.append(point)

    def get_latest(self) -> Optional[MetricPoint]:
        """Return the most recent point, or None if the series is empty."""
        return self.points[-1] if self.points else None

    def _points_within(self, duration: timedelta) -> List[MetricPoint]:
        """Return the points recorded within the trailing *duration* window.

        Shared by ``get_average`` and ``get_max``, which previously
        duplicated this filtering logic verbatim.
        """
        cutoff = datetime.utcnow() - duration
        return [point for point in self.points if point.timestamp >= cutoff]

    def get_average(self, duration: timedelta) -> Optional[float]:
        """Mean value over the trailing window, or None if no points fall in it."""
        recent = self._points_within(duration)
        if not recent:
            return None
        return sum(point.value for point in recent) / len(recent)

    def get_max(self, duration: timedelta) -> Optional[float]:
        """Maximum value over the trailing window, or None if no points fall in it."""
        recent = self._points_within(duration)
        if not recent:
            return None
        return max(point.value for point in recent)
|
||||
|
||||
|
||||
class MetricsService:
    """Service for collecting and managing application metrics.

    Maintains three kinds of in-memory metric state:
      * time series (``MetricSeries``) pre-registered in
        ``_initialize_standard_metrics`` for system and application metrics,
      * monotonically increasing counters (``increment_counter``),
      * raw histogram samples, each bounded to the last 1000 values.

    ``collect_metrics`` is expected to be invoked periodically by an
    external loop (the orchestrator's metrics-collection task).
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        # name -> MetricSeries; pre-populated with the standard metrics below.
        self._metrics: Dict[str, MetricSeries] = {}
        # name -> running counter total; missing keys read as 0.0.
        self._counters: Dict[str, float] = defaultdict(float)
        # name -> last value set via set_gauge.
        self._gauges: Dict[str, float] = {}
        # name -> raw observed samples (trimmed to 1000 in record_histogram).
        self._histograms: Dict[str, List[float]] = defaultdict(list)
        # Wall-clock start, used for uptime reporting.
        self._start_time = time.time()
        self._initialized = False
        self._running = False

        # Initialize standard metrics
        self._initialize_standard_metrics()

    def _initialize_standard_metrics(self):
        """Initialize standard system and application metrics.

        Registers the fixed set of series that collectors and summary
        endpoints refer to by name.
        """
        self._metrics.update({
            # System metrics
            "system_cpu_usage": MetricSeries(
                "system_cpu_usage", "System CPU usage percentage", "percent"
            ),
            "system_memory_usage": MetricSeries(
                "system_memory_usage", "System memory usage percentage", "percent"
            ),
            "system_disk_usage": MetricSeries(
                "system_disk_usage", "System disk usage percentage", "percent"
            ),
            "system_network_bytes_sent": MetricSeries(
                "system_network_bytes_sent", "Network bytes sent", "bytes"
            ),
            "system_network_bytes_recv": MetricSeries(
                "system_network_bytes_recv", "Network bytes received", "bytes"
            ),

            # Application metrics
            "app_requests_total": MetricSeries(
                "app_requests_total", "Total HTTP requests", "count"
            ),
            "app_request_duration": MetricSeries(
                "app_request_duration", "HTTP request duration", "seconds"
            ),
            "app_active_connections": MetricSeries(
                "app_active_connections", "Active WebSocket connections", "count"
            ),
            "app_pose_detections": MetricSeries(
                "app_pose_detections", "Pose detections performed", "count"
            ),
            "app_pose_processing_time": MetricSeries(
                "app_pose_processing_time", "Pose processing time", "seconds"
            ),
            "app_csi_data_points": MetricSeries(
                "app_csi_data_points", "CSI data points processed", "count"
            ),
            "app_stream_fps": MetricSeries(
                "app_stream_fps", "Streaming frames per second", "fps"
            ),

            # Error metrics
            "app_errors_total": MetricSeries(
                "app_errors_total", "Total application errors", "count"
            ),
            "app_http_errors": MetricSeries(
                "app_http_errors", "HTTP errors", "count"
            ),
        })

    async def initialize(self):
        """Initialize metrics service. Idempotent: repeat calls are no-ops."""
        if self._initialized:
            return

        logger.info("Initializing metrics service")
        self._initialized = True
        logger.info("Metrics service initialized")

    async def start(self):
        """Start metrics service, initializing first if necessary."""
        if not self._initialized:
            await self.initialize()

        self._running = True
        logger.info("Metrics service started")

    async def shutdown(self):
        """Shutdown metrics service (stops collection; keeps collected data)."""
        self._running = False
        logger.info("Metrics service shut down")

    async def collect_metrics(self):
        """Collect all metrics. No-op unless the service is running."""
        if not self._running:
            return

        logger.debug("Collecting metrics")

        # Collect system metrics
        await self._collect_system_metrics()

        # Collect application metrics
        await self._collect_application_metrics()

        logger.debug("Metrics collection completed")

    async def _collect_system_metrics(self):
        """Collect system-level metrics (CPU, memory, disk, network) via psutil.

        Errors are logged and swallowed so one failing probe does not
        abort the whole collection cycle.
        """
        try:
            # CPU usage
            # NOTE(review): interval=1 makes psutil sleep ~1s inside this
            # coroutine, blocking the event loop each collection cycle —
            # consider interval=None or running in an executor. Confirm intent.
            cpu_percent = psutil.cpu_percent(interval=1)
            self._metrics["system_cpu_usage"].add_point(cpu_percent)

            # Memory usage
            memory = psutil.virtual_memory()
            self._metrics["system_memory_usage"].add_point(memory.percent)

            # Disk usage
            disk = psutil.disk_usage('/')
            disk_percent = (disk.used / disk.total) * 100
            self._metrics["system_disk_usage"].add_point(disk_percent)

            # Network I/O (cumulative byte counts since boot)
            network = psutil.net_io_counters()
            self._metrics["system_network_bytes_sent"].add_point(network.bytes_sent)
            self._metrics["system_network_bytes_recv"].add_point(network.bytes_recv)

        except Exception as e:
            logger.error(f"Error collecting system metrics: {e}")

    async def _collect_application_metrics(self):
        """Collect application-specific metrics and snapshot counters/gauges."""
        try:
            # Import here to avoid circular imports
            from src.api.websocket.connection_manager import connection_manager

            # Active connections
            connection_stats = await connection_manager.get_connection_stats()
            active_connections = connection_stats.get("active_connections", 0)
            self._metrics["app_active_connections"].add_point(active_connections)

            # Update counters as metrics
            # (only counters that have a registered series of the same name)
            for name, value in self._counters.items():
                if name in self._metrics:
                    self._metrics[name].add_point(value)

            # Update gauges as metrics
            for name, value in self._gauges.items():
                if name in self._metrics:
                    self._metrics[name].add_point(value)

        except Exception as e:
            logger.error(f"Error collecting application metrics: {e}")

    def increment_counter(self, name: str, value: float = 1.0, labels: Optional[Dict[str, str]] = None):
        """Increment a counter metric.

        Also records the new cumulative total into the matching series,
        if one is registered under the same name.
        """
        self._counters[name] += value

        if name in self._metrics:
            self._metrics[name].add_point(self._counters[name], labels)

    def set_gauge(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """Set a gauge metric value (last-write-wins)."""
        self._gauges[name] = value

        if name in self._metrics:
            self._metrics[name].add_point(value, labels)

    def record_histogram(self, name: str, value: float, labels: Optional[Dict[str, str]] = None):
        """Record a histogram value, keeping a bounded sample window."""
        self._histograms[name].append(value)

        # Keep only last 1000 values
        if len(self._histograms[name]) > 1000:
            self._histograms[name] = self._histograms[name][-1000:]

        if name in self._metrics:
            self._metrics[name].add_point(value, labels)

    def time_function(self, metric_name: str):
        """Decorator to time function execution.

        Works for both sync and async callables; the elapsed wall time is
        recorded into the named histogram even when the call raises.
        """
        def decorator(func):
            import functools

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    result = await func(*args, **kwargs)
                    return result
                finally:
                    duration = time.time() - start_time
                    self.record_histogram(metric_name, duration)

            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                start_time = time.time()
                try:
                    result = func(*args, **kwargs)
                    return result
                finally:
                    duration = time.time() - start_time
                    self.record_histogram(metric_name, duration)

            return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

        return decorator

    def get_metric(self, name: str) -> Optional[MetricSeries]:
        """Get a metric series by name, or None if not registered."""
        return self._metrics.get(name)

    def get_metric_value(self, name: str) -> Optional[float]:
        """Get the latest value of a metric, or None if absent/empty."""
        metric = self._metrics.get(name)
        if metric:
            latest = metric.get_latest()
            return latest.value if latest else None
        return None

    def get_counter_value(self, name: str) -> float:
        """Get current counter value (0.0 for unknown counters)."""
        return self._counters.get(name, 0.0)

    def get_gauge_value(self, name: str) -> Optional[float]:
        """Get current gauge value, or None if never set."""
        return self._gauges.get(name)

    def get_histogram_stats(self, name: str) -> Dict[str, float]:
        """Get histogram statistics (empty dict when no samples exist).

        Percentiles use truncated-index (nearest-rank style) selection
        over the sorted samples.
        """
        values = self._histograms.get(name, [])
        if not values:
            return {}

        sorted_values = sorted(values)
        count = len(sorted_values)

        return {
            "count": count,
            "sum": sum(sorted_values),
            "min": sorted_values[0],
            "max": sorted_values[-1],
            "mean": sum(sorted_values) / count,
            "p50": sorted_values[int(count * 0.5)],
            "p90": sorted_values[int(count * 0.9)],
            "p95": sorted_values[int(count * 0.95)],
            "p99": sorted_values[int(count * 0.99)],
        }

    async def get_all_metrics(self) -> Dict[str, Any]:
        """Get all current metrics.

        Series appear under their own names; counters, gauges, and
        histograms are exported with "counter_"/"gauge_"/"histogram_"
        name prefixes respectively.
        """
        metrics = {}

        # Current metric values
        for name, metric_series in self._metrics.items():
            latest = metric_series.get_latest()
            if latest:
                metrics[name] = {
                    "value": latest.value,
                    "timestamp": latest.timestamp.isoformat(),
                    "description": metric_series.description,
                    "unit": metric_series.unit,
                    "labels": latest.labels
                }

        # Counter values
        metrics.update({
            f"counter_{name}": value
            for name, value in self._counters.items()
        })

        # Gauge values
        metrics.update({
            f"gauge_{name}": value
            for name, value in self._gauges.items()
        })

        # Histogram statistics
        for name, values in self._histograms.items():
            if values:
                stats = self.get_histogram_stats(name)
                metrics[f"histogram_{name}"] = stats

        return metrics

    async def get_system_metrics(self) -> Dict[str, Any]:
        """Get system metrics summary (latest point of each system series)."""
        return {
            "cpu_usage": self.get_metric_value("system_cpu_usage"),
            "memory_usage": self.get_metric_value("system_memory_usage"),
            "disk_usage": self.get_metric_value("system_disk_usage"),
            "network_bytes_sent": self.get_metric_value("system_network_bytes_sent"),
            "network_bytes_recv": self.get_metric_value("system_network_bytes_recv"),
        }

    async def get_application_metrics(self) -> Dict[str, Any]:
        """Get application metrics summary (counters, gauges, and timings)."""
        return {
            "requests_total": self.get_counter_value("app_requests_total"),
            "active_connections": self.get_metric_value("app_active_connections"),
            "pose_detections": self.get_counter_value("app_pose_detections"),
            "csi_data_points": self.get_counter_value("app_csi_data_points"),
            "errors_total": self.get_counter_value("app_errors_total"),
            "uptime_seconds": time.time() - self._start_time,
            "request_duration_stats": self.get_histogram_stats("app_request_duration"),
            "pose_processing_time_stats": self.get_histogram_stats("app_pose_processing_time"),
        }

    async def get_performance_summary(self) -> Dict[str, Any]:
        """Get performance metrics summary over the trailing hour."""
        one_hour = timedelta(hours=1)

        return {
            "system": {
                "cpu_avg_1h": self._metrics["system_cpu_usage"].get_average(one_hour),
                "memory_avg_1h": self._metrics["system_memory_usage"].get_average(one_hour),
                "cpu_max_1h": self._metrics["system_cpu_usage"].get_max(one_hour),
                "memory_max_1h": self._metrics["system_memory_usage"].get_max(one_hour),
            },
            "application": {
                "avg_request_duration": self.get_histogram_stats("app_request_duration").get("mean"),
                "avg_pose_processing_time": self.get_histogram_stats("app_pose_processing_time").get("mean"),
                "total_requests": self.get_counter_value("app_requests_total"),
                "total_errors": self.get_counter_value("app_errors_total"),
                # Percentage; denominator clamped to >= 1 to avoid divide-by-zero.
                "error_rate": (
                    self.get_counter_value("app_errors_total") /
                    max(self.get_counter_value("app_requests_total"), 1)
                ) * 100,
            }
        }

    async def get_status(self) -> Dict[str, Any]:
        """Get metrics service status (its own lifecycle, not metric values)."""
        return {
            "status": "healthy" if self._running else "stopped",
            "initialized": self._initialized,
            "running": self._running,
            "metrics_count": len(self._metrics),
            "counters_count": len(self._counters),
            "gauges_count": len(self._gauges),
            "histograms_count": len(self._histograms),
            "uptime": time.time() - self._start_time
        }

    def reset_metrics(self):
        """Reset all metrics.

        Series definitions (names/descriptions/units) are preserved;
        only their recorded points and all counters/gauges/histograms
        are cleared.
        """
        logger.info("Resetting all metrics")

        # Clear metric points but keep series definitions
        for metric_series in self._metrics.values():
            metric_series.points.clear()

        # Reset counters, gauges, and histograms
        self._counters.clear()
        self._gauges.clear()
        self._histograms.clear()

        logger.info("All metrics reset")
|
||||
395
v1/src/services/orchestrator.py
Normal file
395
v1/src/services/orchestrator.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Main service orchestrator for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.services.health_check import HealthCheckService
|
||||
from src.services.metrics import MetricsService
|
||||
from src.api.dependencies import (
|
||||
get_hardware_service,
|
||||
get_pose_service,
|
||||
get_stream_service
|
||||
)
|
||||
from src.api.websocket.connection_manager import connection_manager
|
||||
from src.api.websocket.pose_stream import PoseStreamHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceOrchestrator:
    """Main service orchestrator that manages all application services.

    Owns the full lifecycle: initialize -> start (incl. background tasks)
    -> shutdown (reverse order). Core services (health, metrics) are
    constructed eagerly; application services (hardware, pose, stream)
    are resolved lazily from the API dependency providers during
    ``initialize``. A name -> service registry backs the status/metrics/
    restart helpers.
    """

    def __init__(self, settings: Settings):
        self.settings = settings
        # Registry of all managed services by name; filled in initialize().
        self._services: Dict[str, Any] = {}
        # Long-running asyncio tasks (health loop, metrics loop).
        self._background_tasks: List[asyncio.Task] = []
        self._initialized = False
        self._started = False

        # Core services
        self.health_service = HealthCheckService(settings)
        self.metrics_service = MetricsService(settings)

        # Application services (will be initialized later)
        self.hardware_service = None
        self.pose_service = None
        self.stream_service = None
        self.pose_stream_handler = None

    async def initialize(self):
        """Initialize all services.

        Idempotent; on any failure it shuts everything down and re-raises
        so the caller never sees a half-initialized orchestrator.
        """
        if self._initialized:
            logger.warning("Services already initialized")
            return

        logger.info("Initializing services...")

        try:
            # Initialize core services
            await self.health_service.initialize()
            await self.metrics_service.initialize()

            # Initialize application services
            await self._initialize_application_services()

            # Store services in registry
            self._services = {
                'health': self.health_service,
                'metrics': self.metrics_service,
                'hardware': self.hardware_service,
                'pose': self.pose_service,
                'stream': self.stream_service,
                'pose_stream_handler': self.pose_stream_handler,
                'connection_manager': connection_manager
            }

            self._initialized = True
            logger.info("All services initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize services: {e}")
            await self.shutdown()
            raise

    async def _initialize_application_services(self):
        """Initialize application-specific services.

        Order matters: hardware -> pose -> stream, then the pose stream
        handler which depends on the latter two plus the connection manager.
        """
        try:
            # Initialize hardware service
            self.hardware_service = get_hardware_service()
            await self.hardware_service.initialize()
            logger.info("Hardware service initialized")

            # Initialize pose service
            self.pose_service = get_pose_service()
            await self.pose_service.initialize()
            logger.info("Pose service initialized")

            # Initialize stream service
            self.stream_service = get_stream_service()
            await self.stream_service.initialize()
            logger.info("Stream service initialized")

            # Initialize pose stream handler
            self.pose_stream_handler = PoseStreamHandler(
                connection_manager=connection_manager,
                pose_service=self.pose_service,
                stream_service=self.stream_service
            )
            logger.info("Pose stream handler initialized")

        except Exception as e:
            logger.error(f"Failed to initialize application services: {e}")
            raise

    async def start(self):
        """Start all services and background tasks.

        Initializes first if needed; idempotent once started. On failure,
        shuts down whatever was started and re-raises.
        """
        if not self._initialized:
            await self.initialize()

        if self._started:
            logger.warning("Services already started")
            return

        logger.info("Starting services...")

        try:
            # Start core services
            await self.health_service.start()
            await self.metrics_service.start()

            # Start application services
            await self._start_application_services()

            # Start background tasks
            await self._start_background_tasks()

            self._started = True
            logger.info("All services started successfully")

        except Exception as e:
            logger.error(f"Failed to start services: {e}")
            await self.shutdown()
            raise

    async def _start_application_services(self):
        """Start application-specific services.

        ``start`` is optional on these services, hence the hasattr guards.
        """
        try:
            # Start hardware service
            if hasattr(self.hardware_service, 'start'):
                await self.hardware_service.start()

            # Start pose service
            if hasattr(self.pose_service, 'start'):
                await self.pose_service.start()

            # Start stream service
            if hasattr(self.stream_service, 'start'):
                await self.stream_service.start()

            logger.info("Application services started")

        except Exception as e:
            logger.error(f"Failed to start application services: {e}")
            raise

    async def _start_background_tasks(self):
        """Start background tasks, each gated by its settings flag."""
        try:
            # Start health check monitoring
            # (interval <= 0 disables periodic health checks)
            if self.settings.health_check_interval > 0:
                task = asyncio.create_task(self._health_check_loop())
                self._background_tasks.append(task)

            # Start metrics collection
            if self.settings.metrics_enabled:
                task = asyncio.create_task(self._metrics_collection_loop())
                self._background_tasks.append(task)

            # Start pose streaming if enabled
            if self.settings.enable_real_time_processing:
                await self.pose_stream_handler.start_streaming()

            logger.info(f"Started {len(self._background_tasks)} background tasks")

        except Exception as e:
            logger.error(f"Failed to start background tasks: {e}")
            raise

    async def _health_check_loop(self):
        """Background health check loop.

        Runs until cancelled; ordinary errors are logged and the loop
        keeps going after the usual sleep interval.
        """
        logger.info("Starting health check loop")

        while True:
            try:
                await self.health_service.perform_health_checks()
                await asyncio.sleep(self.settings.health_check_interval)
            except asyncio.CancelledError:
                logger.info("Health check loop cancelled")
                break
            except Exception as e:
                logger.error(f"Error in health check loop: {e}")
                await asyncio.sleep(self.settings.health_check_interval)

    async def _metrics_collection_loop(self):
        """Background metrics collection loop (fixed 60s cadence)."""
        logger.info("Starting metrics collection loop")

        while True:
            try:
                await self.metrics_service.collect_metrics()
                await asyncio.sleep(60)  # Collect metrics every minute
            except asyncio.CancelledError:
                logger.info("Metrics collection loop cancelled")
                break
            except Exception as e:
                logger.error(f"Error in metrics collection loop: {e}")
                await asyncio.sleep(60)

    async def shutdown(self):
        """Shutdown all services and cleanup resources.

        Teardown order: background tasks -> pose streaming -> connection
        manager -> application services -> core services. Errors are
        logged rather than raised so shutdown always completes.
        """
        logger.info("Shutting down services...")

        try:
            # Cancel background tasks
            for task in self._background_tasks:
                if not task.done():
                    task.cancel()

            if self._background_tasks:
                # return_exceptions=True: collect CancelledError/others
                # instead of aborting the gather.
                await asyncio.gather(*self._background_tasks, return_exceptions=True)
                self._background_tasks.clear()

            # Stop pose streaming
            if self.pose_stream_handler:
                await self.pose_stream_handler.shutdown()

            # Shutdown connection manager
            await connection_manager.shutdown()

            # Shutdown application services
            await self._shutdown_application_services()

            # Shutdown core services
            await self.health_service.shutdown()
            await self.metrics_service.shutdown()

            self._started = False
            self._initialized = False

            logger.info("All services shut down successfully")

        except Exception as e:
            logger.error(f"Error during shutdown: {e}")

    async def _shutdown_application_services(self):
        """Shutdown application-specific services (reverse of startup order)."""
        try:
            # Shutdown services in reverse order
            if self.stream_service and hasattr(self.stream_service, 'shutdown'):
                await self.stream_service.shutdown()

            if self.pose_service and hasattr(self.pose_service, 'shutdown'):
                await self.pose_service.shutdown()

            if self.hardware_service and hasattr(self.hardware_service, 'shutdown'):
                await self.hardware_service.shutdown()

            logger.info("Application services shut down")

        except Exception as e:
            logger.error(f"Error shutting down application services: {e}")

    async def restart_service(self, service_name: str):
        """Restart a specific service by registry name.

        Raises ValueError for unknown names. Uses stop/shutdown, then
        initialize, then start — each step only if the service exposes it.
        """
        logger.info(f"Restarting service: {service_name}")

        service = self._services.get(service_name)
        if not service:
            raise ValueError(f"Service not found: {service_name}")

        try:
            # Stop service
            if hasattr(service, 'stop'):
                await service.stop()
            elif hasattr(service, 'shutdown'):
                await service.shutdown()

            # Reinitialize service
            if hasattr(service, 'initialize'):
                await service.initialize()

            # Start service
            if hasattr(service, 'start'):
                await service.start()

            logger.info(f"Service restarted successfully: {service_name}")

        except Exception as e:
            logger.error(f"Failed to restart service {service_name}: {e}")
            raise

    async def reset_services(self):
        """Reset all services to initial state (those exposing ``reset``)."""
        logger.info("Resetting all services")

        try:
            # Reset application services
            if self.hardware_service and hasattr(self.hardware_service, 'reset'):
                await self.hardware_service.reset()

            if self.pose_service and hasattr(self.pose_service, 'reset'):
                await self.pose_service.reset()

            if self.stream_service and hasattr(self.stream_service, 'reset'):
                await self.stream_service.reset()

            # Reset connection manager
            await connection_manager.reset()

            logger.info("All services reset successfully")

        except Exception as e:
            logger.error(f"Failed to reset services: {e}")
            raise

    async def get_service_status(self) -> Dict[str, Any]:
        """Get status of all services; per-service failures become
        {"status": "error", ...} entries rather than raising."""
        status = {}

        for name, service in self._services.items():
            try:
                if hasattr(service, 'get_status'):
                    status[name] = await service.get_status()
                else:
                    status[name] = {"status": "unknown"}
            except Exception as e:
                status[name] = {"status": "error", "error": str(e)}

        return status

    async def get_service_metrics(self) -> Dict[str, Any]:
        """Get metrics from all services that expose a metrics accessor."""
        metrics = {}

        for name, service in self._services.items():
            try:
                if hasattr(service, 'get_metrics'):
                    metrics[name] = await service.get_metrics()
                elif hasattr(service, 'get_performance_metrics'):
                    metrics[name] = await service.get_performance_metrics()
            except Exception as e:
                logger.error(f"Failed to get metrics from {name}: {e}")
                metrics[name] = {"error": str(e)}

        return metrics

    async def get_service_info(self) -> Dict[str, Any]:
        """Get information about all services (type, module, optional
        service-provided extras via ``get_info``)."""
        info = {
            "total_services": len(self._services),
            "initialized": self._initialized,
            "started": self._started,
            "background_tasks": len(self._background_tasks),
            "services": {}
        }

        for name, service in self._services.items():
            service_info = {
                "type": type(service).__name__,
                "module": type(service).__module__
            }

            # Add service-specific info if available
            if hasattr(service, 'get_info'):
                try:
                    service_info.update(await service.get_info())
                except Exception as e:
                    service_info["error"] = str(e)

            info["services"][name] = service_info

        return info

    def get_service(self, name: str) -> Optional[Any]:
        """Get a specific service by name, or None if not registered."""
        return self._services.get(name)

    @property
    def is_healthy(self) -> bool:
        """Check if all services are healthy.

        NOTE(review): this only reflects the orchestrator's own lifecycle
        flags, not the monitored services' health states.
        """
        return self._initialized and self._started

    @asynccontextmanager
    async def service_context(self):
        """Context manager for service lifecycle: initialize + start on
        entry, guaranteed shutdown on exit (even on error)."""
        try:
            await self.initialize()
            await self.start()
            yield self
        finally:
            await self.shutdown()
||||
757
v1/src/services/pose_service.py
Normal file
757
v1/src/services/pose_service.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Pose estimation service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
from src.core.csi_processor import CSIProcessor
|
||||
from src.core.phase_sanitizer import PhaseSanitizer
|
||||
from src.models.densepose_head import DensePoseHead
|
||||
from src.models.modality_translation import ModalityTranslationNetwork
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PoseService:
|
||||
"""Service for pose estimation operations."""
|
||||
|
||||
def __init__(self, settings: Settings, domain_config: DomainConfig):
|
||||
"""Initialize pose service."""
|
||||
self.settings = settings
|
||||
self.domain_config = domain_config
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize components
|
||||
self.csi_processor = None
|
||||
self.phase_sanitizer = None
|
||||
self.densepose_model = None
|
||||
self.modality_translator = None
|
||||
|
||||
# Service state
|
||||
self.is_initialized = False
|
||||
self.is_running = False
|
||||
self.last_error = None
|
||||
|
||||
# Processing statistics
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the pose service."""
|
||||
try:
|
||||
self.logger.info("Initializing pose service...")
|
||||
|
||||
# Initialize CSI processor
|
||||
csi_config = {
|
||||
'buffer_size': self.settings.csi_buffer_size,
|
||||
'sampling_rate': getattr(self.settings, 'csi_sampling_rate', 1000),
|
||||
'window_size': getattr(self.settings, 'csi_window_size', 512),
|
||||
'overlap': getattr(self.settings, 'csi_overlap', 0.5),
|
||||
'noise_threshold': getattr(self.settings, 'csi_noise_threshold', 0.1),
|
||||
'human_detection_threshold': getattr(self.settings, 'csi_human_detection_threshold', 0.8),
|
||||
'smoothing_factor': getattr(self.settings, 'csi_smoothing_factor', 0.9),
|
||||
'max_history_size': getattr(self.settings, 'csi_max_history_size', 500),
|
||||
'num_subcarriers': 56,
|
||||
'num_antennas': 3
|
||||
}
|
||||
self.csi_processor = CSIProcessor(config=csi_config)
|
||||
|
||||
# Initialize phase sanitizer
|
||||
phase_config = {
|
||||
'unwrapping_method': 'numpy',
|
||||
'outlier_threshold': 3.0,
|
||||
'smoothing_window': 5,
|
||||
'enable_outlier_removal': True,
|
||||
'enable_smoothing': True,
|
||||
'enable_noise_filtering': True,
|
||||
'noise_threshold': getattr(self.settings, 'csi_noise_threshold', 0.1)
|
||||
}
|
||||
self.phase_sanitizer = PhaseSanitizer(config=phase_config)
|
||||
|
||||
# Initialize models if not mocking
|
||||
if not self.settings.mock_pose_data:
|
||||
await self._initialize_models()
|
||||
else:
|
||||
self.logger.info("Using mock pose data for development")
|
||||
|
||||
self.is_initialized = True
|
||||
self.logger.info("Pose service initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.logger.error(f"Failed to initialize pose service: {e}")
|
||||
raise
|
||||
|
||||
async def _initialize_models(self):
|
||||
"""Initialize neural network models."""
|
||||
try:
|
||||
# Initialize DensePose model
|
||||
if self.settings.pose_model_path:
|
||||
self.densepose_model = DensePoseHead()
|
||||
# Load model weights if path is provided
|
||||
# model_state = torch.load(self.settings.pose_model_path)
|
||||
# self.densepose_model.load_state_dict(model_state)
|
||||
self.logger.info("DensePose model loaded")
|
||||
else:
|
||||
self.logger.warning("No pose model path provided, using default model")
|
||||
self.densepose_model = DensePoseHead()
|
||||
|
||||
# Initialize modality translation
|
||||
config = {
|
||||
'input_channels': 64, # CSI data channels
|
||||
'hidden_channels': [128, 256, 512],
|
||||
'output_channels': 256, # Visual feature channels
|
||||
'use_attention': True
|
||||
}
|
||||
self.modality_translator = ModalityTranslationNetwork(config)
|
||||
|
||||
# Set models to evaluation mode
|
||||
self.densepose_model.eval()
|
||||
self.modality_translator.eval()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize models: {e}")
|
||||
raise
|
||||
|
||||
async def start(self):
|
||||
"""Start the pose service."""
|
||||
if not self.is_initialized:
|
||||
await self.initialize()
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Pose service started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the pose service."""
|
||||
self.is_running = False
|
||||
self.logger.info("Pose service stopped")
|
||||
|
||||
async def process_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process CSI data and estimate poses."""
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Pose service is not running")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# Process CSI data
|
||||
processed_csi = await self._process_csi(csi_data, metadata)
|
||||
|
||||
# Estimate poses
|
||||
poses = await self._estimate_poses(processed_csi, metadata)
|
||||
|
||||
# Update statistics
|
||||
processing_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
self._update_stats(poses, processing_time)
|
||||
|
||||
return {
|
||||
"timestamp": start_time.isoformat(),
|
||||
"poses": poses,
|
||||
"metadata": metadata,
|
||||
"processing_time_ms": processing_time,
|
||||
"confidence_scores": [pose.get("confidence", 0.0) for pose in poses]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.last_error = str(e)
|
||||
self.stats["failed_detections"] += 1
|
||||
self.logger.error(f"Error processing CSI data: {e}")
|
||||
raise
|
||||
|
||||
async def _process_csi(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> np.ndarray:
|
||||
"""Process raw CSI data."""
|
||||
# Convert raw data to CSIData format
|
||||
from src.hardware.csi_extractor import CSIData
|
||||
|
||||
# Create CSIData object with proper fields
|
||||
# For mock data, create amplitude and phase from input
|
||||
if csi_data.ndim == 1:
|
||||
amplitude = np.abs(csi_data)
|
||||
phase = np.angle(csi_data) if np.iscomplexobj(csi_data) else np.zeros_like(csi_data)
|
||||
else:
|
||||
amplitude = csi_data
|
||||
phase = np.zeros_like(csi_data)
|
||||
|
||||
csi_data_obj = CSIData(
|
||||
timestamp=metadata.get("timestamp", datetime.now()),
|
||||
amplitude=amplitude,
|
||||
phase=phase,
|
||||
frequency=metadata.get("frequency", 5.0), # 5 GHz default
|
||||
bandwidth=metadata.get("bandwidth", 20.0), # 20 MHz default
|
||||
num_subcarriers=metadata.get("num_subcarriers", 56),
|
||||
num_antennas=metadata.get("num_antennas", 3),
|
||||
snr=metadata.get("snr", 20.0), # 20 dB default
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Process CSI data
|
||||
try:
|
||||
detection_result = await self.csi_processor.process_csi_data(csi_data_obj)
|
||||
|
||||
# Add to history for temporal analysis
|
||||
self.csi_processor.add_to_history(csi_data_obj)
|
||||
|
||||
# Extract amplitude data for pose estimation
|
||||
if detection_result and detection_result.features:
|
||||
amplitude_data = detection_result.features.amplitude_mean
|
||||
|
||||
# Apply phase sanitization if we have phase data
|
||||
if hasattr(detection_result.features, 'phase_difference'):
|
||||
phase_data = detection_result.features.phase_difference
|
||||
sanitized_phase = self.phase_sanitizer.sanitize(phase_data)
|
||||
# Combine amplitude and phase data
|
||||
return np.concatenate([amplitude_data, sanitized_phase])
|
||||
|
||||
return amplitude_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"CSI processing failed, using raw data: {e}")
|
||||
|
||||
return csi_data
|
||||
|
||||
async def _estimate_poses(self, csi_data: np.ndarray, metadata: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Estimate poses from processed CSI data."""
|
||||
if self.settings.mock_pose_data:
|
||||
return self._generate_mock_poses()
|
||||
|
||||
try:
|
||||
# Convert CSI data to tensor
|
||||
csi_tensor = torch.from_numpy(csi_data).float()
|
||||
|
||||
# Add batch dimension if needed
|
||||
if len(csi_tensor.shape) == 2:
|
||||
csi_tensor = csi_tensor.unsqueeze(0)
|
||||
|
||||
# Translate modality (CSI to visual-like features)
|
||||
with torch.no_grad():
|
||||
visual_features = self.modality_translator(csi_tensor)
|
||||
|
||||
# Estimate poses using DensePose
|
||||
pose_outputs = self.densepose_model(visual_features)
|
||||
|
||||
# Convert outputs to pose detections
|
||||
poses = self._parse_pose_outputs(pose_outputs)
|
||||
|
||||
# Filter by confidence threshold
|
||||
filtered_poses = [
|
||||
pose for pose in poses
|
||||
if pose.get("confidence", 0.0) >= self.settings.pose_confidence_threshold
|
||||
]
|
||||
|
||||
# Limit number of persons
|
||||
if len(filtered_poses) > self.settings.pose_max_persons:
|
||||
filtered_poses = sorted(
|
||||
filtered_poses,
|
||||
key=lambda x: x.get("confidence", 0.0),
|
||||
reverse=True
|
||||
)[:self.settings.pose_max_persons]
|
||||
|
||||
return filtered_poses
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in pose estimation: {e}")
|
||||
return []
|
||||
|
||||
def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Parse neural network outputs into pose detections."""
|
||||
poses = []
|
||||
|
||||
# This is a simplified parsing - in reality, this would depend on the model architecture
|
||||
# For now, generate mock poses based on the output shape
|
||||
batch_size = outputs.shape[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
# Extract pose information (mock implementation)
|
||||
confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_mock_poses(self) -> List[Dict[str, Any]]:
|
||||
"""Generate mock pose data for development."""
|
||||
import random
|
||||
|
||||
num_persons = random.randint(1, min(3, self.settings.pose_max_persons))
|
||||
poses = []
|
||||
|
||||
for i in range(num_persons):
|
||||
confidence = random.uniform(0.3, 0.95)
|
||||
|
||||
pose = {
|
||||
"person_id": i,
|
||||
"confidence": confidence,
|
||||
"keypoints": self._generate_keypoints(),
|
||||
"bounding_box": self._generate_bounding_box(),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
poses.append(pose)
|
||||
|
||||
return poses
|
||||
|
||||
def _generate_keypoints(self) -> List[Dict[str, Any]]:
|
||||
"""Generate keypoints for a person."""
|
||||
import random
|
||||
|
||||
keypoint_names = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
keypoints = []
|
||||
for name in keypoint_names:
|
||||
keypoints.append({
|
||||
"name": name,
|
||||
"x": random.uniform(0.1, 0.9),
|
||||
"y": random.uniform(0.1, 0.9),
|
||||
"confidence": random.uniform(0.5, 0.95)
|
||||
})
|
||||
|
||||
return keypoints
|
||||
|
||||
def _generate_bounding_box(self) -> Dict[str, float]:
|
||||
"""Generate bounding box for a person."""
|
||||
import random
|
||||
|
||||
x = random.uniform(0.1, 0.6)
|
||||
y = random.uniform(0.1, 0.6)
|
||||
width = random.uniform(0.2, 0.4)
|
||||
height = random.uniform(0.3, 0.5)
|
||||
|
||||
return {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height
|
||||
}
|
||||
|
||||
def _classify_activity(self, features: torch.Tensor) -> str:
|
||||
"""Classify activity from features."""
|
||||
# Simple mock classification
|
||||
import random
|
||||
activities = ["standing", "sitting", "walking", "lying", "unknown"]
|
||||
return random.choice(activities)
|
||||
|
||||
def _update_stats(self, poses: List[Dict[str, Any]], processing_time: float):
|
||||
"""Update processing statistics."""
|
||||
self.stats["total_processed"] += 1
|
||||
|
||||
if poses:
|
||||
self.stats["successful_detections"] += 1
|
||||
confidences = [pose.get("confidence", 0.0) for pose in poses]
|
||||
avg_confidence = sum(confidences) / len(confidences)
|
||||
|
||||
# Update running average
|
||||
total = self.stats["successful_detections"]
|
||||
current_avg = self.stats["average_confidence"]
|
||||
self.stats["average_confidence"] = (current_avg * (total - 1) + avg_confidence) / total
|
||||
else:
|
||||
self.stats["failed_detections"] += 1
|
||||
|
||||
# Update processing time (running average)
|
||||
total = self.stats["total_processed"]
|
||||
current_avg = self.stats["processing_time_ms"]
|
||||
self.stats["processing_time_ms"] = (current_avg * (total - 1) + processing_time) / total
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get service status."""
|
||||
return {
|
||||
"status": "healthy" if self.is_running and not self.last_error else "unhealthy",
|
||||
"initialized": self.is_initialized,
|
||||
"running": self.is_running,
|
||||
"last_error": self.last_error,
|
||||
"statistics": self.stats.copy(),
|
||||
"configuration": {
|
||||
"mock_data": self.settings.mock_pose_data,
|
||||
"confidence_threshold": self.settings.pose_confidence_threshold,
|
||||
"max_persons": self.settings.pose_max_persons,
|
||||
"batch_size": self.settings.pose_processing_batch_size
|
||||
}
|
||||
}
|
||||
|
||||
async def get_metrics(self) -> Dict[str, Any]:
|
||||
"""Get service metrics."""
|
||||
return {
|
||||
"pose_service": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"successful_detections": self.stats["successful_detections"],
|
||||
"failed_detections": self.stats["failed_detections"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_confidence": self.stats["average_confidence"],
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
async def reset(self):
|
||||
"""Reset service state."""
|
||||
self.stats = {
|
||||
"total_processed": 0,
|
||||
"successful_detections": 0,
|
||||
"failed_detections": 0,
|
||||
"average_confidence": 0.0,
|
||||
"processing_time_ms": 0.0
|
||||
}
|
||||
self.last_error = None
|
||||
self.logger.info("Pose service reset")
|
||||
|
||||
# API endpoint methods
|
||||
async def estimate_poses(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Estimate poses with API parameters."""
|
||||
try:
|
||||
# Generate mock CSI data for estimation
|
||||
mock_csi = np.random.randn(64, 56, 3) # Mock CSI data
|
||||
metadata = {
|
||||
"timestamp": datetime.now(),
|
||||
"zone_ids": zone_ids or ["zone_1"],
|
||||
"confidence_threshold": confidence_threshold or self.settings.pose_confidence_threshold,
|
||||
"max_persons": max_persons or self.settings.pose_max_persons
|
||||
}
|
||||
|
||||
# Process the data
|
||||
result = await self.process_csi_data(mock_csi, metadata)
|
||||
|
||||
# Format for API response
|
||||
persons = []
|
||||
for i, pose in enumerate(result["poses"]):
|
||||
person = {
|
||||
"person_id": str(pose["person_id"]),
|
||||
"confidence": pose["confidence"],
|
||||
"bounding_box": pose["bounding_box"],
|
||||
"zone_id": zone_ids[0] if zone_ids else "zone_1",
|
||||
"activity": pose["activity"],
|
||||
"timestamp": datetime.fromisoformat(pose["timestamp"])
|
||||
}
|
||||
|
||||
if include_keypoints:
|
||||
person["keypoints"] = pose["keypoints"]
|
||||
|
||||
if include_segmentation:
|
||||
person["segmentation"] = {"mask": "mock_segmentation_data"}
|
||||
|
||||
persons.append(person)
|
||||
|
||||
# Zone summary
|
||||
zone_summary = {}
|
||||
for zone_id in (zone_ids or ["zone_1"]):
|
||||
zone_summary[zone_id] = len([p for p in persons if p.get("zone_id") == zone_id])
|
||||
|
||||
return {
|
||||
"timestamp": datetime.now(),
|
||||
"frame_id": f"frame_{int(datetime.now().timestamp())}",
|
||||
"persons": persons,
|
||||
"zone_summary": zone_summary,
|
||||
"processing_time_ms": result["processing_time_ms"],
|
||||
"metadata": {"mock_data": self.settings.mock_pose_data}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in estimate_poses: {e}")
|
||||
raise
|
||||
|
||||
async def analyze_with_params(self, zone_ids=None, confidence_threshold=None, max_persons=None,
|
||||
include_keypoints=True, include_segmentation=False):
|
||||
"""Analyze pose data with custom parameters."""
|
||||
return await self.estimate_poses(zone_ids, confidence_threshold, max_persons,
|
||||
include_keypoints, include_segmentation)
|
||||
|
||||
async def get_zone_occupancy(self, zone_id: str):
|
||||
"""Get current occupancy for a specific zone."""
|
||||
try:
|
||||
# Mock occupancy data
|
||||
import random
|
||||
count = random.randint(0, 5)
|
||||
persons = []
|
||||
|
||||
for i in range(count):
|
||||
persons.append({
|
||||
"person_id": f"person_{i}",
|
||||
"confidence": random.uniform(0.7, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"max_occupancy": 10,
|
||||
"persons": persons,
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zone occupancy: {e}")
|
||||
return None
|
||||
|
||||
async def get_zones_summary(self):
|
||||
"""Get occupancy summary for all zones."""
|
||||
try:
|
||||
import random
|
||||
zones = ["zone_1", "zone_2", "zone_3", "zone_4"]
|
||||
zone_data = {}
|
||||
total_persons = 0
|
||||
active_zones = 0
|
||||
|
||||
for zone_id in zones:
|
||||
count = random.randint(0, 3)
|
||||
zone_data[zone_id] = {
|
||||
"occupancy": count,
|
||||
"max_occupancy": 10,
|
||||
"status": "active" if count > 0 else "inactive"
|
||||
}
|
||||
total_persons += count
|
||||
if count > 0:
|
||||
active_zones += 1
|
||||
|
||||
return {
|
||||
"total_persons": total_persons,
|
||||
"zones": zone_data,
|
||||
"active_zones": active_zones
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting zones summary: {e}")
|
||||
raise
|
||||
|
||||
async def get_historical_data(self, start_time, end_time, zone_ids=None,
|
||||
aggregation_interval=300, include_raw_data=False):
|
||||
"""Get historical pose estimation data."""
|
||||
try:
|
||||
# Mock historical data
|
||||
import random
|
||||
from datetime import timedelta
|
||||
|
||||
current_time = start_time
|
||||
aggregated_data = []
|
||||
raw_data = [] if include_raw_data else None
|
||||
|
||||
while current_time < end_time:
|
||||
# Generate aggregated data point
|
||||
data_point = {
|
||||
"timestamp": current_time,
|
||||
"total_persons": random.randint(0, 8),
|
||||
"zones": {}
|
||||
}
|
||||
|
||||
for zone_id in (zone_ids or ["zone_1", "zone_2", "zone_3"]):
|
||||
data_point["zones"][zone_id] = {
|
||||
"occupancy": random.randint(0, 3),
|
||||
"avg_confidence": random.uniform(0.7, 0.95)
|
||||
}
|
||||
|
||||
aggregated_data.append(data_point)
|
||||
|
||||
# Generate raw data if requested
|
||||
if include_raw_data:
|
||||
for _ in range(random.randint(0, 5)):
|
||||
raw_data.append({
|
||||
"timestamp": current_time + timedelta(seconds=random.randint(0, aggregation_interval)),
|
||||
"person_id": f"person_{random.randint(1, 10)}",
|
||||
"zone_id": random.choice(zone_ids or ["zone_1", "zone_2", "zone_3"]),
|
||||
"confidence": random.uniform(0.5, 0.95),
|
||||
"activity": random.choice(["standing", "sitting", "walking"])
|
||||
})
|
||||
|
||||
current_time += timedelta(seconds=aggregation_interval)
|
||||
|
||||
return {
|
||||
"aggregated_data": aggregated_data,
|
||||
"raw_data": raw_data,
|
||||
"total_records": len(aggregated_data)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting historical data: {e}")
|
||||
raise
|
||||
|
||||
async def get_recent_activities(self, zone_id=None, limit=10):
|
||||
"""Get recently detected activities."""
|
||||
try:
|
||||
import random
|
||||
activities = []
|
||||
|
||||
for i in range(limit):
|
||||
activity = {
|
||||
"activity_id": f"activity_{i}",
|
||||
"person_id": f"person_{random.randint(1, 5)}",
|
||||
"zone_id": zone_id or random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity": random.choice(["standing", "sitting", "walking", "lying"]),
|
||||
"confidence": random.uniform(0.6, 0.95),
|
||||
"timestamp": datetime.now() - timedelta(minutes=random.randint(0, 60)),
|
||||
"duration_seconds": random.randint(10, 300)
|
||||
}
|
||||
activities.append(activity)
|
||||
|
||||
return activities
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting recent activities: {e}")
|
||||
raise
|
||||
|
||||
async def is_calibrating(self):
|
||||
"""Check if calibration is in progress."""
|
||||
return False # Mock implementation
|
||||
|
||||
async def start_calibration(self):
|
||||
"""Start calibration process."""
|
||||
import uuid
|
||||
calibration_id = str(uuid.uuid4())
|
||||
self.logger.info(f"Started calibration: {calibration_id}")
|
||||
return calibration_id
|
||||
|
||||
async def run_calibration(self, calibration_id):
|
||||
"""Run calibration process."""
|
||||
self.logger.info(f"Running calibration: {calibration_id}")
|
||||
# Mock calibration process
|
||||
await asyncio.sleep(5)
|
||||
self.logger.info(f"Calibration completed: {calibration_id}")
|
||||
|
||||
async def get_calibration_status(self):
|
||||
"""Get current calibration status."""
|
||||
return {
|
||||
"is_calibrating": False,
|
||||
"calibration_id": None,
|
||||
"progress_percent": 100,
|
||||
"current_step": "completed",
|
||||
"estimated_remaining_minutes": 0,
|
||||
"last_calibration": datetime.now() - timedelta(hours=1)
|
||||
}
|
||||
|
||||
async def get_statistics(self, start_time, end_time):
|
||||
"""Get pose estimation statistics."""
|
||||
try:
|
||||
import random
|
||||
|
||||
# Mock statistics
|
||||
total_detections = random.randint(100, 1000)
|
||||
successful_detections = int(total_detections * random.uniform(0.8, 0.95))
|
||||
|
||||
return {
|
||||
"total_detections": total_detections,
|
||||
"successful_detections": successful_detections,
|
||||
"failed_detections": total_detections - successful_detections,
|
||||
"success_rate": successful_detections / total_detections,
|
||||
"average_confidence": random.uniform(0.75, 0.90),
|
||||
"average_processing_time_ms": random.uniform(50, 200),
|
||||
"unique_persons": random.randint(5, 20),
|
||||
"most_active_zone": random.choice(["zone_1", "zone_2", "zone_3"]),
|
||||
"activity_distribution": {
|
||||
"standing": random.uniform(0.3, 0.5),
|
||||
"sitting": random.uniform(0.2, 0.4),
|
||||
"walking": random.uniform(0.1, 0.3),
|
||||
"lying": random.uniform(0.0, 0.1)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting statistics: {e}")
|
||||
raise
|
||||
|
||||
async def process_segmentation_data(self, frame_id):
|
||||
"""Process segmentation data in background."""
|
||||
self.logger.info(f"Processing segmentation data for frame: {frame_id}")
|
||||
# Mock background processing
|
||||
await asyncio.sleep(2)
|
||||
self.logger.info(f"Segmentation processing completed for frame: {frame_id}")
|
||||
|
||||
# WebSocket streaming methods
|
||||
async def get_current_pose_data(self):
|
||||
"""Get current pose data for streaming."""
|
||||
try:
|
||||
# Generate current pose data
|
||||
result = await self.estimate_poses()
|
||||
|
||||
# Format data by zones for WebSocket streaming
|
||||
zone_data = {}
|
||||
|
||||
# Group persons by zone
|
||||
for person in result["persons"]:
|
||||
zone_id = person.get("zone_id", "zone_1")
|
||||
|
||||
if zone_id not in zone_data:
|
||||
zone_data[zone_id] = {
|
||||
"pose": {
|
||||
"persons": [],
|
||||
"count": 0
|
||||
},
|
||||
"confidence": 0.0,
|
||||
"activity": None,
|
||||
"metadata": {
|
||||
"frame_id": result["frame_id"],
|
||||
"processing_time_ms": result["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
zone_data[zone_id]["pose"]["persons"].append(person)
|
||||
zone_data[zone_id]["pose"]["count"] += 1
|
||||
|
||||
# Update zone confidence (average)
|
||||
current_confidence = zone_data[zone_id]["confidence"]
|
||||
person_confidence = person.get("confidence", 0.0)
|
||||
zone_data[zone_id]["confidence"] = (current_confidence + person_confidence) / 2
|
||||
|
||||
# Set activity if not already set
|
||||
if not zone_data[zone_id]["activity"] and person.get("activity"):
|
||||
zone_data[zone_id]["activity"] = person["activity"]
|
||||
|
||||
return zone_data
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error getting current pose data: {e}")
|
||||
# Return empty zone data on error
|
||||
return {}
|
||||
|
||||
# Health check methods
|
||||
async def health_check(self):
|
||||
"""Perform health check."""
|
||||
try:
|
||||
status = "healthy" if self.is_running and not self.last_error else "unhealthy"
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"message": self.last_error if self.last_error else "Service is running normally",
|
||||
"uptime_seconds": 0.0, # TODO: Implement actual uptime tracking
|
||||
"metrics": {
|
||||
"total_processed": self.stats["total_processed"],
|
||||
"success_rate": (
|
||||
self.stats["successful_detections"] / max(1, self.stats["total_processed"])
|
||||
),
|
||||
"average_processing_time_ms": self.stats["processing_time_ms"]
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"message": f"Health check failed: {str(e)}"
|
||||
}
|
||||
|
||||
async def is_ready(self):
|
||||
"""Check if service is ready."""
|
||||
return self.is_initialized and self.is_running
|
||||
397
v1/src/services/stream_service.py
Normal file
397
v1/src/services/stream_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Real-time streaming service for WiFi-DensePose API
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional, Any, Set
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from fastapi import WebSocket
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.config.domains import DomainConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamService:
    """Service for real-time data streaming."""

    def __init__(self, settings: Settings, domain_config: DomainConfig):
        """Initialize stream service.

        Args:
            settings: Application settings (buffer sizes, ping interval).
            domain_config: Domain-level configuration.
        """
        self.settings = settings
        self.domain_config = domain_config
        self.logger = logging.getLogger(__name__)

        # Live WebSocket clients and their per-connection metadata.
        self.connections: Set[WebSocket] = set()
        self.connection_metadata: Dict[WebSocket, Dict[str, Any]] = {}

        # Bounded replay buffers used to bring new clients up to date.
        self.pose_buffer = deque(maxlen=self.settings.stream_buffer_size)
        self.csi_buffer = deque(maxlen=self.settings.stream_buffer_size)

        # Lifecycle state.
        self.is_running = False
        self.last_error = None

        # Streaming counters.
        self.stats = {
            "active_connections": 0,
            "total_connections": 0,
            "messages_sent": 0,
            "messages_failed": 0,
            "data_points_streamed": 0,
            "average_latency_ms": 0.0,
        }

        # Periodic heartbeat task, created in start().
        self.streaming_task = None
async def initialize(self):
|
||||
"""Initialize the stream service."""
|
||||
self.logger.info("Stream service initialized")
|
||||
|
||||
async def start(self):
|
||||
"""Start the stream service."""
|
||||
if self.is_running:
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.logger.info("Stream service started")
|
||||
|
||||
# Start background streaming task
|
||||
if self.settings.enable_real_time_processing:
|
||||
self.streaming_task = asyncio.create_task(self._streaming_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the stream service."""
|
||||
self.is_running = False
|
||||
|
||||
# Cancel background task
|
||||
if self.streaming_task:
|
||||
self.streaming_task.cancel()
|
||||
try:
|
||||
await self.streaming_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close all connections
|
||||
await self._close_all_connections()
|
||||
|
||||
self.logger.info("Stream service stopped")
|
||||
|
||||
async def add_connection(self, websocket: WebSocket, metadata: Dict[str, Any] = None):
|
||||
"""Add a new WebSocket connection."""
|
||||
try:
|
||||
await websocket.accept()
|
||||
self.connections.add(websocket)
|
||||
self.connection_metadata[websocket] = metadata or {}
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
self.stats["total_connections"] += 1
|
||||
|
||||
self.logger.info(f"New WebSocket connection added. Total: {len(self.connections)}")
|
||||
|
||||
# Send initial data if available
|
||||
await self._send_initial_data(websocket)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error adding WebSocket connection: {e}")
|
||||
raise
|
||||
|
||||
async def remove_connection(self, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
self.connections.remove(websocket)
|
||||
self.connection_metadata.pop(websocket, None)
|
||||
|
||||
self.stats["active_connections"] = len(self.connections)
|
||||
|
||||
self.logger.info(f"WebSocket connection removed. Total: {len(self.connections)}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error removing WebSocket connection: {e}")
|
||||
|
||||
async def broadcast_pose_data(self, pose_data: Dict[str, Any]):
|
||||
"""Broadcast pose data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Add to buffer
|
||||
self.pose_buffer.append({
|
||||
"type": "pose_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "pose_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": pose_data
|
||||
})
|
||||
|
||||
async def broadcast_csi_data(self, csi_data: np.ndarray, metadata: Dict[str, Any]):
|
||||
"""Broadcast CSI data to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
# Convert numpy array to list for JSON serialization
|
||||
csi_list = csi_data.tolist() if isinstance(csi_data, np.ndarray) else csi_data
|
||||
|
||||
# Add to buffer
|
||||
self.csi_buffer.append({
|
||||
"type": "csi_data",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
# Broadcast to all connections
|
||||
await self._broadcast_message({
|
||||
"type": "csi_update",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": csi_list,
|
||||
"metadata": metadata
|
||||
})
|
||||
|
||||
async def broadcast_system_status(self, status_data: Dict[str, Any]):
|
||||
"""Broadcast system status to all connected clients."""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
await self._broadcast_message({
|
||||
"type": "system_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status_data
|
||||
})
|
||||
|
||||
async def send_to_connection(self, websocket: WebSocket, message: Dict[str, Any]):
|
||||
"""Send message to a specific connection."""
|
||||
try:
|
||||
if websocket in self.connections:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def _broadcast_message(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients."""
|
||||
if not self.connections:
|
||||
return
|
||||
|
||||
disconnected = set()
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
self.stats["messages_sent"] += 1
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to send message to connection: {e}")
|
||||
self.stats["messages_failed"] += 1
|
||||
disconnected.add(websocket)
|
||||
|
||||
# Remove disconnected clients
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
if message.get("type") in ["pose_update", "csi_update"]:
|
||||
self.stats["data_points_streamed"] += 1
|
||||
|
||||
async def _send_initial_data(self, websocket: WebSocket):
|
||||
"""Send initial data to a new connection."""
|
||||
try:
|
||||
# Send recent pose data
|
||||
if self.pose_buffer:
|
||||
recent_poses = list(self.pose_buffer)[-10:] # Last 10 poses
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_poses",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_poses
|
||||
})
|
||||
|
||||
# Send recent CSI data
|
||||
if self.csi_buffer:
|
||||
recent_csi = list(self.csi_buffer)[-5:] # Last 5 CSI readings
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "initial_csi",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": recent_csi
|
||||
})
|
||||
|
||||
# Send service status
|
||||
status = await self.get_status()
|
||||
await self.send_to_connection(websocket, {
|
||||
"type": "service_status",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": status
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error sending initial data: {e}")
|
||||
|
||||
async def _streaming_loop(self):
|
||||
"""Background streaming loop for periodic updates."""
|
||||
try:
|
||||
while self.is_running:
|
||||
# Send periodic heartbeat
|
||||
if self.connections:
|
||||
await self._broadcast_message({
|
||||
"type": "heartbeat",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"active_connections": len(self.connections)
|
||||
})
|
||||
|
||||
# Wait for next iteration
|
||||
await asyncio.sleep(self.settings.websocket_ping_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info("Streaming loop cancelled")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in streaming loop: {e}")
|
||||
self.last_error = str(e)
|
||||
|
||||
async def _close_all_connections(self):
|
||||
"""Close all WebSocket connections."""
|
||||
disconnected = []
|
||||
|
||||
for websocket in self.connections.copy():
|
||||
try:
|
||||
await websocket.close()
|
||||
disconnected.append(websocket)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing connection: {e}")
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clear all connections
|
||||
for websocket in disconnected:
|
||||
await self.remove_connection(websocket)
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
    """Return a structured snapshot of the service's current state.

    Includes health, connection counts, buffer occupancy, a copy of the
    running statistics, and the streaming configuration in effect.
    """
    healthy = self.is_running and not self.last_error

    connection_summary = {
        "active": len(self.connections),
        "total": self.stats["total_connections"],
    }
    buffer_summary = {
        "pose_buffer_size": len(self.pose_buffer),
        "csi_buffer_size": len(self.csi_buffer),
        "max_buffer_size": self.settings.stream_buffer_size,
    }
    config_summary = {
        "stream_fps": self.settings.stream_fps,
        "buffer_size": self.settings.stream_buffer_size,
        "ping_interval": self.settings.websocket_ping_interval,
        "timeout": self.settings.websocket_timeout,
    }

    return {
        "status": "healthy" if healthy else "unhealthy",
        "running": self.is_running,
        "last_error": self.last_error,
        "connections": connection_summary,
        "buffers": buffer_summary,
        "statistics": self.stats.copy(),  # copy: callers must not mutate live stats
        "configuration": config_summary,
    }
||||
async def get_metrics(self) -> Dict[str, Any]:
    """Return delivery and throughput metrics for the stream service."""
    attempted = self.stats["messages_sent"] + self.stats["messages_failed"]
    # max(1, ...) guards the division when nothing has been attempted yet.
    success_rate = self.stats["messages_sent"] / max(1, attempted)

    metrics = {
        "active_connections": self.stats["active_connections"],
        "total_connections": self.stats["total_connections"],
        "messages_sent": self.stats["messages_sent"],
        "messages_failed": self.stats["messages_failed"],
        "message_success_rate": success_rate,
        "data_points_streamed": self.stats["data_points_streamed"],
        "average_latency_ms": self.stats["average_latency_ms"],
    }
    return {"stream_service": metrics}
async def get_connection_info(self) -> List[Dict[str, Any]]:
    """Describe each active connection from its stored metadata.

    Missing metadata fields fall back to "unknown" (or [] for the
    subscription list); the id is the socket object's ``id()``.
    """
    def describe(client):
        meta = self.connection_metadata.get(client, {})
        return {
            "id": id(client),
            "connected_at": meta.get("connected_at", "unknown"),
            "user_agent": meta.get("user_agent", "unknown"),
            "ip_address": meta.get("ip_address", "unknown"),
            "subscription_types": meta.get("subscription_types", []),
        }

    return [describe(client) for client in self.connections]
async def reset(self):
    """Drop buffered data and zero out statistics.

    Open connections are kept; ``active_connections`` is re-seeded from
    the live connection set while the cumulative counters restart at 0.
    """
    self.pose_buffer.clear()
    self.csi_buffer.clear()

    # NOTE(review): total_connections is zeroed even though active ones
    # survive — presumably counters reflect the post-reset epoch; confirm.
    self.stats = {
        "active_connections": len(self.connections),
        "total_connections": 0,
        "messages_sent": 0,
        "messages_failed": 0,
        "data_points_streamed": 0,
        "average_latency_ms": 0.0,
    }

    self.last_error = None
    self.logger.info("Stream service reset")
def get_buffer_data(self, buffer_type: str, limit: int = 100) -> List[Dict[str, Any]]:
    """Return up to *limit* most-recent entries from the named buffer.

    Args:
        buffer_type: Either "pose" or "csi"; anything else yields [].
        limit: Maximum number of trailing entries to return.

    Returns:
        A list of buffered entries (newest last), or [] for an unknown
        buffer type or a non-positive limit.
    """
    # Bug fix: the old `[-limit:]` slice returned the ENTIRE buffer for
    # limit == 0, because -0 == 0; treat non-positive limits as "none".
    if limit <= 0:
        return []
    if buffer_type == "pose":
        return list(self.pose_buffer)[-limit:]
    if buffer_type == "csi":
        return list(self.csi_buffer)[-limit:]
    return []
@property
def is_active(self) -> bool:
    """True while the streaming service is running (mirrors ``is_running``)."""
    return self.is_running
async def health_check(self) -> Dict[str, Any]:
    """Report overall health plus key delivery counters.

    Never raises: any failure while assembling the report is folded into
    an "unhealthy" response.
    """
    try:
        healthy = self.is_running and not self.last_error
        counters = {
            "messages_sent": self.stats["messages_sent"],
            "messages_failed": self.stats["messages_failed"],
            "data_points_streamed": self.stats["data_points_streamed"],
        }
        return {
            "status": "healthy" if healthy else "unhealthy",
            "message": self.last_error if self.last_error else "Stream service is running normally",
            "active_connections": len(self.connections),
            "metrics": counters,
        }

    except Exception as e:
        return {
            "status": "unhealthy",
            "message": f"Health check failed: {str(e)}",
        }
async def is_ready(self) -> bool:
    """Readiness probe: the service is ready iff it is running."""
    return self.is_running
||||
Reference in New Issue
Block a user