feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
@@ -0,0 +1,506 @@
|
||||
//! Error types for the WiFi-DensePose system.
|
||||
//!
|
||||
//! This module provides comprehensive error handling using [`thiserror`] for
|
||||
//! automatic `Display` and `Error` trait implementations.
|
||||
//!
|
||||
//! # Error Hierarchy
|
||||
//!
|
||||
//! - [`CoreError`]: Top-level error type that encompasses all subsystem errors
|
||||
//! - [`SignalError`]: Errors related to CSI signal processing
|
||||
//! - [`InferenceError`]: Errors from neural network inference
|
||||
//! - [`StorageError`]: Errors from data persistence operations
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust
|
||||
//! use wifi_densepose_core::error::{CoreError, SignalError};
|
||||
//!
|
||||
//! fn process_signal() -> Result<(), CoreError> {
|
||||
//! // Signal processing that might fail
|
||||
//! Err(SignalError::InvalidSubcarrierCount { expected: 256, actual: 128 }.into())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// A specialized `Result` type for core operations.
|
||||
pub type CoreResult<T> = Result<T, CoreError>;
|
||||
|
||||
/// Top-level error type for the WiFi-DensePose system.
|
||||
///
|
||||
/// This enum encompasses all possible errors that can occur within the core
|
||||
/// system, providing a unified error type for the entire crate.
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum CoreError {
|
||||
/// Signal processing error
|
||||
#[error("Signal processing error: {0}")]
|
||||
Signal(#[from] SignalError),
|
||||
|
||||
/// Neural network inference error
|
||||
#[error("Inference error: {0}")]
|
||||
Inference(#[from] InferenceError),
|
||||
|
||||
/// Data storage error
|
||||
#[error("Storage error: {0}")]
|
||||
Storage(#[from] StorageError),
|
||||
|
||||
/// Configuration error
|
||||
#[error("Configuration error: {message}")]
|
||||
Configuration {
|
||||
/// Description of the configuration error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Validation error for input data
|
||||
#[error("Validation error: {message}")]
|
||||
Validation {
|
||||
/// Description of what validation failed
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Resource not found
|
||||
#[error("Resource not found: {resource_type} with id '{id}'")]
|
||||
NotFound {
|
||||
/// Type of resource that was not found
|
||||
resource_type: &'static str,
|
||||
/// Identifier of the missing resource
|
||||
id: String,
|
||||
},
|
||||
|
||||
/// Operation timed out
|
||||
#[error("Operation timed out after {duration_ms}ms: {operation}")]
|
||||
Timeout {
|
||||
/// The operation that timed out
|
||||
operation: String,
|
||||
/// Duration in milliseconds before timeout
|
||||
duration_ms: u64,
|
||||
},
|
||||
|
||||
/// Invalid state for the requested operation
|
||||
#[error("Invalid state: expected {expected}, found {actual}")]
|
||||
InvalidState {
|
||||
/// Expected state
|
||||
expected: String,
|
||||
/// Actual state
|
||||
actual: String,
|
||||
},
|
||||
|
||||
/// Internal error (should not happen in normal operation)
|
||||
#[error("Internal error: {message}")]
|
||||
Internal {
|
||||
/// Description of the internal error
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl CoreError {
|
||||
/// Creates a new configuration error.
|
||||
#[must_use]
|
||||
pub fn configuration(message: impl Into<String>) -> Self {
|
||||
Self::Configuration {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new validation error.
|
||||
#[must_use]
|
||||
pub fn validation(message: impl Into<String>) -> Self {
|
||||
Self::Validation {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new not found error.
|
||||
#[must_use]
|
||||
pub fn not_found(resource_type: &'static str, id: impl Into<String>) -> Self {
|
||||
Self::NotFound {
|
||||
resource_type,
|
||||
id: id.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new timeout error.
|
||||
#[must_use]
|
||||
pub fn timeout(operation: impl Into<String>, duration_ms: u64) -> Self {
|
||||
Self::Timeout {
|
||||
operation: operation.into(),
|
||||
duration_ms,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new invalid state error.
|
||||
#[must_use]
|
||||
pub fn invalid_state(expected: impl Into<String>, actual: impl Into<String>) -> Self {
|
||||
Self::InvalidState {
|
||||
expected: expected.into(),
|
||||
actual: actual.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new internal error.
|
||||
#[must_use]
|
||||
pub fn internal(message: impl Into<String>) -> Self {
|
||||
Self::Internal {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this error is recoverable.
|
||||
#[must_use]
|
||||
pub fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::Signal(e) => e.is_recoverable(),
|
||||
Self::Inference(e) => e.is_recoverable(),
|
||||
Self::Storage(e) => e.is_recoverable(),
|
||||
Self::Timeout { .. } => true,
|
||||
Self::NotFound { .. }
|
||||
| Self::Configuration { .. }
|
||||
| Self::Validation { .. }
|
||||
| Self::InvalidState { .. }
|
||||
| Self::Internal { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors related to CSI signal processing.
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum SignalError {
|
||||
/// Invalid number of subcarriers in CSI data
|
||||
#[error("Invalid subcarrier count: expected {expected}, got {actual}")]
|
||||
InvalidSubcarrierCount {
|
||||
/// Expected number of subcarriers
|
||||
expected: usize,
|
||||
/// Actual number of subcarriers received
|
||||
actual: usize,
|
||||
},
|
||||
|
||||
/// Invalid antenna configuration
|
||||
#[error("Invalid antenna configuration: {message}")]
|
||||
InvalidAntennaConfig {
|
||||
/// Description of the configuration error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Signal amplitude out of valid range
|
||||
#[error("Signal amplitude {value} out of range [{min}, {max}]")]
|
||||
AmplitudeOutOfRange {
|
||||
/// The invalid amplitude value
|
||||
value: f64,
|
||||
/// Minimum valid amplitude
|
||||
min: f64,
|
||||
/// Maximum valid amplitude
|
||||
max: f64,
|
||||
},
|
||||
|
||||
/// Phase unwrapping failed
|
||||
#[error("Phase unwrapping failed: {reason}")]
|
||||
PhaseUnwrapFailed {
|
||||
/// Reason for the failure
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// FFT operation failed
|
||||
#[error("FFT operation failed: {message}")]
|
||||
FftFailed {
|
||||
/// Description of the FFT error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Filter design or application error
|
||||
#[error("Filter error: {message}")]
|
||||
FilterError {
|
||||
/// Description of the filter error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Insufficient samples for processing
|
||||
#[error("Insufficient samples: need at least {required}, got {available}")]
|
||||
InsufficientSamples {
|
||||
/// Minimum required samples
|
||||
required: usize,
|
||||
/// Available samples
|
||||
available: usize,
|
||||
},
|
||||
|
||||
/// Signal quality too low for reliable processing
|
||||
#[error("Signal quality too low: SNR {snr_db:.2} dB below threshold {threshold_db:.2} dB")]
|
||||
LowSignalQuality {
|
||||
/// Measured SNR in dB
|
||||
snr_db: f64,
|
||||
/// Required minimum SNR in dB
|
||||
threshold_db: f64,
|
||||
},
|
||||
|
||||
/// Timestamp synchronization error
|
||||
#[error("Timestamp synchronization error: {message}")]
|
||||
TimestampSync {
|
||||
/// Description of the sync error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Invalid frequency band
|
||||
#[error("Invalid frequency band: {band}")]
|
||||
InvalidFrequencyBand {
|
||||
/// The invalid band identifier
|
||||
band: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl SignalError {
|
||||
/// Returns `true` if this error is recoverable.
|
||||
#[must_use]
|
||||
pub const fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::LowSignalQuality { .. }
|
||||
| Self::InsufficientSamples { .. }
|
||||
| Self::TimestampSync { .. }
|
||||
| Self::PhaseUnwrapFailed { .. }
|
||||
| Self::FftFailed { .. } => true,
|
||||
Self::InvalidSubcarrierCount { .. }
|
||||
| Self::InvalidAntennaConfig { .. }
|
||||
| Self::AmplitudeOutOfRange { .. }
|
||||
| Self::FilterError { .. }
|
||||
| Self::InvalidFrequencyBand { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors related to neural network inference.
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum InferenceError {
|
||||
/// Model file not found or could not be loaded
|
||||
#[error("Failed to load model from '{path}': {reason}")]
|
||||
ModelLoadFailed {
|
||||
/// Path to the model file
|
||||
path: String,
|
||||
/// Reason for the failure
|
||||
reason: String,
|
||||
},
|
||||
|
||||
/// Input tensor shape mismatch
|
||||
#[error("Input shape mismatch: expected {expected:?}, got {actual:?}")]
|
||||
InputShapeMismatch {
|
||||
/// Expected tensor shape
|
||||
expected: Vec<usize>,
|
||||
/// Actual tensor shape
|
||||
actual: Vec<usize>,
|
||||
},
|
||||
|
||||
/// Output tensor shape mismatch
|
||||
#[error("Output shape mismatch: expected {expected:?}, got {actual:?}")]
|
||||
OutputShapeMismatch {
|
||||
/// Expected tensor shape
|
||||
expected: Vec<usize>,
|
||||
/// Actual tensor shape
|
||||
actual: Vec<usize>,
|
||||
},
|
||||
|
||||
/// CUDA/GPU error
|
||||
#[error("GPU error: {message}")]
|
||||
GpuError {
|
||||
/// Description of the GPU error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Model inference failed
|
||||
#[error("Inference failed: {message}")]
|
||||
InferenceFailed {
|
||||
/// Description of the failure
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Model not initialized
|
||||
#[error("Model not initialized: {name}")]
|
||||
ModelNotInitialized {
|
||||
/// Name of the uninitialized model
|
||||
name: String,
|
||||
},
|
||||
|
||||
/// Unsupported model format
|
||||
#[error("Unsupported model format: {format}")]
|
||||
UnsupportedFormat {
|
||||
/// The unsupported format
|
||||
format: String,
|
||||
},
|
||||
|
||||
/// Quantization error
|
||||
#[error("Quantization error: {message}")]
|
||||
QuantizationError {
|
||||
/// Description of the quantization error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Batch size error
|
||||
#[error("Invalid batch size: {size}, maximum is {max_size}")]
|
||||
InvalidBatchSize {
|
||||
/// The invalid batch size
|
||||
size: usize,
|
||||
/// Maximum allowed batch size
|
||||
max_size: usize,
|
||||
},
|
||||
}
|
||||
|
||||
impl InferenceError {
|
||||
/// Returns `true` if this error is recoverable.
|
||||
#[must_use]
|
||||
pub const fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::GpuError { .. } | Self::InferenceFailed { .. } => true,
|
||||
Self::ModelLoadFailed { .. }
|
||||
| Self::InputShapeMismatch { .. }
|
||||
| Self::OutputShapeMismatch { .. }
|
||||
| Self::ModelNotInitialized { .. }
|
||||
| Self::UnsupportedFormat { .. }
|
||||
| Self::QuantizationError { .. }
|
||||
| Self::InvalidBatchSize { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors related to data storage and persistence.
|
||||
#[derive(Error, Debug)]
|
||||
#[non_exhaustive]
|
||||
pub enum StorageError {
|
||||
/// Database connection failed
|
||||
#[error("Database connection failed: {message}")]
|
||||
ConnectionFailed {
|
||||
/// Description of the connection error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Query execution failed
|
||||
#[error("Query failed: {query_type} - {message}")]
|
||||
QueryFailed {
|
||||
/// Type of query that failed
|
||||
query_type: String,
|
||||
/// Error message
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Record not found
|
||||
#[error("Record not found: {table}.{id}")]
|
||||
RecordNotFound {
|
||||
/// Table name
|
||||
table: String,
|
||||
/// Record identifier
|
||||
id: String,
|
||||
},
|
||||
|
||||
/// Duplicate key violation
|
||||
#[error("Duplicate key in {table}: {key}")]
|
||||
DuplicateKey {
|
||||
/// Table name
|
||||
table: String,
|
||||
/// The duplicate key
|
||||
key: String,
|
||||
},
|
||||
|
||||
/// Transaction error
|
||||
#[error("Transaction error: {message}")]
|
||||
TransactionError {
|
||||
/// Description of the transaction error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Serialization/deserialization error
|
||||
#[error("Serialization error: {message}")]
|
||||
SerializationError {
|
||||
/// Description of the serialization error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Cache error
|
||||
#[error("Cache error: {message}")]
|
||||
CacheError {
|
||||
/// Description of the cache error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Migration error
|
||||
#[error("Migration error: {message}")]
|
||||
MigrationError {
|
||||
/// Description of the migration error
|
||||
message: String,
|
||||
},
|
||||
|
||||
/// Storage capacity exceeded
|
||||
#[error("Storage capacity exceeded: {current} / {limit} bytes")]
|
||||
CapacityExceeded {
|
||||
/// Current storage usage
|
||||
current: u64,
|
||||
/// Storage limit
|
||||
limit: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl StorageError {
|
||||
/// Returns `true` if this error is recoverable.
|
||||
#[must_use]
|
||||
pub const fn is_recoverable(&self) -> bool {
|
||||
match self {
|
||||
Self::ConnectionFailed { .. }
|
||||
| Self::QueryFailed { .. }
|
||||
| Self::TransactionError { .. }
|
||||
| Self::CacheError { .. } => true,
|
||||
Self::RecordNotFound { .. }
|
||||
| Self::DuplicateKey { .. }
|
||||
| Self::SerializationError { .. }
|
||||
| Self::MigrationError { .. }
|
||||
| Self::CapacityExceeded { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_core_error_display() {
|
||||
let err = CoreError::configuration("Invalid threshold value");
|
||||
assert!(err.to_string().contains("Configuration error"));
|
||||
assert!(err.to_string().contains("Invalid threshold"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signal_error_recoverable() {
|
||||
let recoverable = SignalError::LowSignalQuality {
|
||||
snr_db: 5.0,
|
||||
threshold_db: 10.0,
|
||||
};
|
||||
assert!(recoverable.is_recoverable());
|
||||
|
||||
let non_recoverable = SignalError::InvalidSubcarrierCount {
|
||||
expected: 256,
|
||||
actual: 128,
|
||||
};
|
||||
assert!(!non_recoverable.is_recoverable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_conversion() {
|
||||
let signal_err = SignalError::InvalidSubcarrierCount {
|
||||
expected: 256,
|
||||
actual: 128,
|
||||
};
|
||||
let core_err: CoreError = signal_err.into();
|
||||
assert!(matches!(core_err, CoreError::Signal(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_found_error() {
|
||||
let err = CoreError::not_found("CsiFrame", "frame_123");
|
||||
assert!(err.to_string().contains("CsiFrame"));
|
||||
assert!(err.to_string().contains("frame_123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeout_error() {
|
||||
let err = CoreError::timeout("inference", 5000);
|
||||
assert!(err.to_string().contains("5000ms"));
|
||||
assert!(err.to_string().contains("inference"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
//! # WiFi-DensePose Core
|
||||
//!
|
||||
//! Core types, traits, and utilities for the WiFi-DensePose pose estimation system.
|
||||
//!
|
||||
//! This crate provides the foundational building blocks used throughout the
|
||||
//! WiFi-DensePose ecosystem, including:
|
||||
//!
|
||||
//! - **Core Data Types**: [`CsiFrame`], [`ProcessedSignal`], [`PoseEstimate`],
|
||||
//! [`PersonPose`], and [`Keypoint`] for representing `WiFi` CSI data and pose
|
||||
//! estimation results.
|
||||
//!
|
||||
//! - **Error Types**: Comprehensive error handling via the [`error`] module,
|
||||
//! with specific error types for different subsystems.
|
||||
//!
|
||||
//! - **Traits**: Core abstractions like [`SignalProcessor`], [`NeuralInference`],
|
||||
//! and [`DataStore`] that define the contracts for signal processing, neural
|
||||
//! network inference, and data persistence.
|
||||
//!
|
||||
//! - **Utilities**: Common helper functions and types used across the codebase.
|
||||
//!
|
||||
//! ## Feature Flags
|
||||
//!
|
||||
//! - `std` (default): Enable standard library support
|
||||
//! - `serde`: Enable serialization/deserialization via serde
|
||||
//! - `async`: Enable async trait definitions
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust
|
||||
//! use wifi_densepose_core::{CsiFrame, Keypoint, KeypointType, Confidence};
|
||||
//!
|
||||
//! // Create a keypoint with high confidence
|
||||
//! let keypoint = Keypoint::new(
|
||||
//! KeypointType::Nose,
|
||||
//! 0.5,
|
||||
//! 0.3,
|
||||
//! Confidence::new(0.95).unwrap(),
|
||||
//! );
|
||||
//!
|
||||
//! assert!(keypoint.is_visible());
|
||||
//! ```
|
||||
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
extern crate alloc;
|
||||
|
||||
pub mod error;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
pub mod utils;
|
||||
|
||||
// Re-export commonly used types at the crate root
|
||||
pub use error::{CoreError, CoreResult, SignalError, InferenceError, StorageError};
|
||||
pub use traits::{SignalProcessor, NeuralInference, DataStore};
|
||||
pub use types::{
|
||||
// CSI types
|
||||
CsiFrame, CsiMetadata, AntennaConfig,
|
||||
// Signal types
|
||||
ProcessedSignal, SignalFeatures, FrequencyBand,
|
||||
// Pose types
|
||||
PoseEstimate, PersonPose, Keypoint, KeypointType,
|
||||
// Common types
|
||||
Confidence, Timestamp, FrameId, DeviceId,
|
||||
// Bounding box
|
||||
BoundingBox,
|
||||
};
|
||||
|
||||
/// Crate version
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Maximum number of keypoints per person (COCO format)
|
||||
pub const MAX_KEYPOINTS: usize = 17;
|
||||
|
||||
/// Maximum number of subcarriers typically used in `WiFi` CSI
|
||||
pub const MAX_SUBCARRIERS: usize = 256;
|
||||
|
||||
/// Default confidence threshold for keypoint visibility
|
||||
pub const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.5;
|
||||
|
||||
/// Prelude module for convenient imports.
|
||||
///
|
||||
/// Convenient re-exports of commonly used types and traits.
|
||||
///
|
||||
/// ```rust
|
||||
/// use wifi_densepose_core::prelude::*;
|
||||
/// ```
|
||||
pub mod prelude {
|
||||
|
||||
pub use crate::error::{CoreError, CoreResult};
|
||||
pub use crate::traits::{DataStore, NeuralInference, SignalProcessor};
|
||||
pub use crate::types::{
|
||||
AntennaConfig, BoundingBox, Confidence, CsiFrame, CsiMetadata, DeviceId, FrameId,
|
||||
FrequencyBand, Keypoint, KeypointType, PersonPose, PoseEstimate, ProcessedSignal,
|
||||
SignalFeatures, Timestamp,
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_version_is_valid() {
|
||||
assert!(!VERSION.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constants() {
|
||||
assert_eq!(MAX_KEYPOINTS, 17);
|
||||
assert!(MAX_SUBCARRIERS > 0);
|
||||
assert!(DEFAULT_CONFIDENCE_THRESHOLD > 0.0);
|
||||
assert!(DEFAULT_CONFIDENCE_THRESHOLD < 1.0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,626 @@
|
||||
//! Core trait definitions for the WiFi-DensePose system.
|
||||
//!
|
||||
//! This module defines the fundamental abstractions used throughout the system,
|
||||
//! enabling a modular and testable architecture.
|
||||
//!
|
||||
//! # Traits
|
||||
//!
|
||||
//! - [`SignalProcessor`]: Process raw CSI frames into neural network-ready tensors
|
||||
//! - [`NeuralInference`]: Run pose estimation inference on processed signals
|
||||
//! - [`DataStore`]: Persist and retrieve CSI data and pose estimates
|
||||
//!
|
||||
//! # Design Philosophy
|
||||
//!
|
||||
//! These traits are designed with the following principles:
|
||||
//!
|
||||
//! 1. **Single Responsibility**: Each trait handles one concern
|
||||
//! 2. **Testability**: All traits can be easily mocked for unit testing
|
||||
//! 3. **Async-Ready**: Async versions available with the `async` feature
|
||||
//! 4. **Error Handling**: Consistent use of `Result` types with domain errors
|
||||
|
||||
use crate::error::{CoreResult, InferenceError, SignalError, StorageError};
|
||||
use crate::types::{CsiFrame, FrameId, PoseEstimate, ProcessedSignal, Timestamp};
|
||||
|
||||
/// Configuration for signal processing.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct SignalProcessorConfig {
|
||||
/// Number of frames to buffer before processing
|
||||
pub buffer_size: usize,
|
||||
/// Sampling rate in Hz
|
||||
pub sample_rate_hz: f64,
|
||||
/// Whether to apply noise filtering
|
||||
pub apply_noise_filter: bool,
|
||||
/// Noise filter cutoff frequency in Hz
|
||||
pub filter_cutoff_hz: f64,
|
||||
/// Whether to normalize amplitudes
|
||||
pub normalize_amplitude: bool,
|
||||
/// Whether to unwrap phases
|
||||
pub unwrap_phase: bool,
|
||||
/// Window function for spectral analysis
|
||||
pub window_function: WindowFunction,
|
||||
}
|
||||
|
||||
impl Default for SignalProcessorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
buffer_size: 64,
|
||||
sample_rate_hz: 1000.0,
|
||||
apply_noise_filter: true,
|
||||
filter_cutoff_hz: 50.0,
|
||||
normalize_amplitude: true,
|
||||
unwrap_phase: true,
|
||||
window_function: WindowFunction::Hann,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Window functions for spectral analysis.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub enum WindowFunction {
|
||||
/// Rectangular window (no windowing)
|
||||
Rectangular,
|
||||
/// Hann window
|
||||
#[default]
|
||||
Hann,
|
||||
/// Hamming window
|
||||
Hamming,
|
||||
/// Blackman window
|
||||
Blackman,
|
||||
/// Kaiser window
|
||||
Kaiser,
|
||||
}
|
||||
|
||||
/// Signal processor for converting raw CSI frames into processed signals.
|
||||
///
|
||||
/// Implementations of this trait handle:
|
||||
/// - Buffering and aggregating CSI frames
|
||||
/// - Noise filtering and signal conditioning
|
||||
/// - Phase unwrapping and amplitude normalization
|
||||
/// - Feature extraction
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use wifi_densepose_core::{SignalProcessor, CsiFrame};
|
||||
///
|
||||
/// fn process_frames(processor: &mut impl SignalProcessor, frames: Vec<CsiFrame>) {
|
||||
/// for frame in frames {
|
||||
/// if let Err(e) = processor.push_frame(frame) {
|
||||
/// eprintln!("Failed to push frame: {}", e);
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// if let Some(signal) = processor.try_process() {
|
||||
/// println!("Processed signal with {} time steps", signal.num_time_steps());
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait SignalProcessor: Send + Sync {
|
||||
/// Returns the current configuration.
|
||||
fn config(&self) -> &SignalProcessorConfig;
|
||||
|
||||
/// Updates the configuration.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the configuration is invalid.
|
||||
fn set_config(&mut self, config: SignalProcessorConfig) -> Result<(), SignalError>;
|
||||
|
||||
/// Pushes a new CSI frame into the processing buffer.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the frame is invalid or the buffer is full.
|
||||
fn push_frame(&mut self, frame: CsiFrame) -> Result<(), SignalError>;
|
||||
|
||||
/// Attempts to process the buffered frames.
|
||||
///
|
||||
/// Returns `None` if insufficient frames are buffered.
|
||||
/// Returns `Some(ProcessedSignal)` on successful processing.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if processing fails.
|
||||
fn try_process(&mut self) -> Result<Option<ProcessedSignal>, SignalError>;
|
||||
|
||||
/// Forces processing of whatever frames are buffered.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if no frames are buffered or processing fails.
|
||||
fn force_process(&mut self) -> Result<ProcessedSignal, SignalError>;
|
||||
|
||||
/// Returns the number of frames currently buffered.
|
||||
fn buffered_frame_count(&self) -> usize;
|
||||
|
||||
/// Clears the frame buffer.
|
||||
fn clear_buffer(&mut self);
|
||||
|
||||
/// Resets the processor to its initial state.
|
||||
fn reset(&mut self);
|
||||
}
|
||||
|
||||
/// Configuration for neural network inference.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct InferenceConfig {
|
||||
/// Path to the model file
|
||||
pub model_path: String,
|
||||
/// Device to run inference on
|
||||
pub device: InferenceDevice,
|
||||
/// Maximum batch size
|
||||
pub max_batch_size: usize,
|
||||
/// Number of threads for CPU inference
|
||||
pub num_threads: usize,
|
||||
/// Confidence threshold for detections
|
||||
pub confidence_threshold: f32,
|
||||
/// Non-maximum suppression threshold
|
||||
pub nms_threshold: f32,
|
||||
/// Whether to use half precision (FP16)
|
||||
pub use_fp16: bool,
|
||||
}
|
||||
|
||||
impl Default for InferenceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_path: String::new(),
|
||||
device: InferenceDevice::Cpu,
|
||||
max_batch_size: 8,
|
||||
num_threads: 4,
|
||||
confidence_threshold: 0.5,
|
||||
nms_threshold: 0.45,
|
||||
use_fp16: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Device for running neural network inference.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub enum InferenceDevice {
|
||||
/// CPU inference
|
||||
#[default]
|
||||
Cpu,
|
||||
/// CUDA GPU inference
|
||||
Cuda {
|
||||
/// GPU device index
|
||||
device_id: usize,
|
||||
},
|
||||
/// TensorRT accelerated inference
|
||||
TensorRt {
|
||||
/// GPU device index
|
||||
device_id: usize,
|
||||
},
|
||||
/// CoreML (Apple Silicon)
|
||||
CoreMl,
|
||||
/// WebGPU for browser environments
|
||||
WebGpu,
|
||||
}
|
||||
|
||||
/// Neural network inference engine for pose estimation.
|
||||
///
|
||||
/// Implementations of this trait handle:
|
||||
/// - Loading and managing neural network models
|
||||
/// - Running inference on processed signals
|
||||
/// - Post-processing outputs into pose estimates
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use wifi_densepose_core::{NeuralInference, ProcessedSignal};
|
||||
///
|
||||
/// async fn estimate_pose(
|
||||
/// engine: &impl NeuralInference,
|
||||
/// signal: ProcessedSignal,
|
||||
/// ) -> Result<PoseEstimate, InferenceError> {
|
||||
/// engine.infer(signal).await
|
||||
/// }
|
||||
/// ```
|
||||
pub trait NeuralInference: Send + Sync {
|
||||
/// Returns the current configuration.
|
||||
fn config(&self) -> &InferenceConfig;
|
||||
|
||||
/// Returns `true` if the model is loaded and ready.
|
||||
fn is_ready(&self) -> bool;
|
||||
|
||||
/// Returns the model version string.
|
||||
fn model_version(&self) -> &str;
|
||||
|
||||
/// Loads the model from the configured path.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the model cannot be loaded.
|
||||
fn load_model(&mut self) -> Result<(), InferenceError>;
|
||||
|
||||
/// Unloads the current model to free resources.
|
||||
fn unload_model(&mut self);
|
||||
|
||||
/// Runs inference on a single processed signal.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if inference fails.
|
||||
fn infer(&self, signal: &ProcessedSignal) -> Result<PoseEstimate, InferenceError>;
|
||||
|
||||
/// Runs inference on a batch of processed signals.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if inference fails.
|
||||
fn infer_batch(&self, signals: &[ProcessedSignal])
|
||||
-> Result<Vec<PoseEstimate>, InferenceError>;
|
||||
|
||||
/// Warms up the model by running a dummy inference.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if warmup fails.
|
||||
fn warmup(&mut self) -> Result<(), InferenceError>;
|
||||
|
||||
/// Returns performance statistics.
|
||||
fn stats(&self) -> InferenceStats;
|
||||
}
|
||||
|
||||
/// Performance statistics for neural network inference.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct InferenceStats {
|
||||
/// Total number of inferences performed
|
||||
pub total_inferences: u64,
|
||||
/// Average inference latency in milliseconds
|
||||
pub avg_latency_ms: f64,
|
||||
/// 95th percentile latency in milliseconds
|
||||
pub p95_latency_ms: f64,
|
||||
/// Maximum latency in milliseconds
|
||||
pub max_latency_ms: f64,
|
||||
/// Inferences per second throughput
|
||||
pub throughput: f64,
|
||||
/// GPU memory usage in bytes (if applicable)
|
||||
pub gpu_memory_bytes: Option<u64>,
|
||||
}
|
||||
|
||||
/// Query options for data store operations.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct QueryOptions {
|
||||
/// Maximum number of results to return
|
||||
pub limit: Option<usize>,
|
||||
/// Number of results to skip
|
||||
pub offset: Option<usize>,
|
||||
/// Start time filter (inclusive)
|
||||
pub start_time: Option<Timestamp>,
|
||||
/// End time filter (inclusive)
|
||||
pub end_time: Option<Timestamp>,
|
||||
/// Device ID filter
|
||||
pub device_id: Option<String>,
|
||||
/// Sort order
|
||||
pub sort_order: SortOrder,
|
||||
}
|
||||
|
||||
/// Sort order for query results.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum SortOrder {
|
||||
/// Ascending order (oldest first)
|
||||
#[default]
|
||||
Ascending,
|
||||
/// Descending order (newest first)
|
||||
Descending,
|
||||
}
|
||||
|
||||
/// Data storage trait for persisting and retrieving CSI data and pose estimates.
|
||||
///
|
||||
/// Implementations can use various backends:
|
||||
/// - PostgreSQL/SQLite for relational storage
|
||||
/// - Redis for caching
|
||||
/// - Time-series databases for efficient temporal queries
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use wifi_densepose_core::{DataStore, CsiFrame, PoseEstimate};
|
||||
///
|
||||
/// async fn save_and_query(
|
||||
/// store: &impl DataStore,
|
||||
/// frame: CsiFrame,
|
||||
/// estimate: PoseEstimate,
|
||||
/// ) {
|
||||
/// store.store_csi_frame(&frame).await?;
|
||||
/// store.store_pose_estimate(&estimate).await?;
|
||||
///
|
||||
/// let recent = store.get_recent_estimates(10).await?;
|
||||
/// println!("Found {} recent estimates", recent.len());
|
||||
/// }
|
||||
/// ```
|
||||
pub trait DataStore: Send + Sync {
|
||||
/// Returns `true` if the store is connected and ready.
|
||||
fn is_connected(&self) -> bool;
|
||||
|
||||
/// Stores a CSI frame.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the store operation fails.
|
||||
fn store_csi_frame(&self, frame: &CsiFrame) -> Result<(), StorageError>;
|
||||
|
||||
/// Retrieves a CSI frame by ID.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the frame is not found or retrieval fails.
|
||||
fn get_csi_frame(&self, id: &FrameId) -> Result<CsiFrame, StorageError>;
|
||||
|
||||
/// Retrieves CSI frames matching the query options.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the query fails.
|
||||
fn query_csi_frames(&self, options: &QueryOptions) -> Result<Vec<CsiFrame>, StorageError>;
|
||||
|
||||
/// Stores a pose estimate.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the store operation fails.
|
||||
fn store_pose_estimate(&self, estimate: &PoseEstimate) -> Result<(), StorageError>;
|
||||
|
||||
/// Retrieves a pose estimate by ID.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the estimate is not found or retrieval fails.
|
||||
fn get_pose_estimate(&self, id: &FrameId) -> Result<PoseEstimate, StorageError>;
|
||||
|
||||
/// Retrieves pose estimates matching the query options.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the query fails.
|
||||
fn query_pose_estimates(
|
||||
&self,
|
||||
options: &QueryOptions,
|
||||
) -> Result<Vec<PoseEstimate>, StorageError>;
|
||||
|
||||
/// Retrieves the N most recent pose estimates.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the query fails.
|
||||
fn get_recent_estimates(&self, count: usize) -> Result<Vec<PoseEstimate>, StorageError>;
|
||||
|
||||
/// Deletes CSI frames older than the given timestamp.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the deletion fails.
|
||||
fn delete_csi_frames_before(&self, timestamp: &Timestamp) -> Result<u64, StorageError>;
|
||||
|
||||
/// Deletes pose estimates older than the given timestamp.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the deletion fails.
|
||||
fn delete_pose_estimates_before(&self, timestamp: &Timestamp) -> Result<u64, StorageError>;
|
||||
|
||||
/// Returns storage statistics.
|
||||
fn stats(&self) -> StorageStats;
|
||||
}
|
||||
|
||||
/// Storage statistics.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StorageStats {
|
||||
/// Total number of CSI frames stored
|
||||
pub csi_frame_count: u64,
|
||||
/// Total number of pose estimates stored
|
||||
pub pose_estimate_count: u64,
|
||||
/// Total storage size in bytes
|
||||
pub total_size_bytes: u64,
|
||||
/// Oldest record timestamp
|
||||
pub oldest_record: Option<Timestamp>,
|
||||
/// Newest record timestamp
|
||||
pub newest_record: Option<Timestamp>,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Async Trait Definitions (with `async` feature)
|
||||
// =============================================================================
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Async version of [`SignalProcessor`].
|
||||
#[cfg(feature = "async")]
|
||||
#[async_trait]
|
||||
pub trait AsyncSignalProcessor: Send + Sync {
|
||||
/// Returns the current configuration.
|
||||
fn config(&self) -> &SignalProcessorConfig;
|
||||
|
||||
/// Updates the configuration.
|
||||
async fn set_config(&mut self, config: SignalProcessorConfig) -> Result<(), SignalError>;
|
||||
|
||||
/// Pushes a new CSI frame into the processing buffer.
|
||||
async fn push_frame(&mut self, frame: CsiFrame) -> Result<(), SignalError>;
|
||||
|
||||
/// Attempts to process the buffered frames.
|
||||
async fn try_process(&mut self) -> Result<Option<ProcessedSignal>, SignalError>;
|
||||
|
||||
/// Forces processing of whatever frames are buffered.
|
||||
async fn force_process(&mut self) -> Result<ProcessedSignal, SignalError>;
|
||||
|
||||
/// Returns the number of frames currently buffered.
|
||||
fn buffered_frame_count(&self) -> usize;
|
||||
|
||||
/// Clears the frame buffer.
|
||||
async fn clear_buffer(&mut self);
|
||||
|
||||
/// Resets the processor to its initial state.
|
||||
async fn reset(&mut self);
|
||||
}
|
||||
|
||||
/// Async version of [`NeuralInference`].
|
||||
#[cfg(feature = "async")]
|
||||
#[async_trait]
|
||||
pub trait AsyncNeuralInference: Send + Sync {
|
||||
/// Returns the current configuration.
|
||||
fn config(&self) -> &InferenceConfig;
|
||||
|
||||
/// Returns `true` if the model is loaded and ready.
|
||||
fn is_ready(&self) -> bool;
|
||||
|
||||
/// Returns the model version string.
|
||||
fn model_version(&self) -> &str;
|
||||
|
||||
/// Loads the model from the configured path.
|
||||
async fn load_model(&mut self) -> Result<(), InferenceError>;
|
||||
|
||||
/// Unloads the current model to free resources.
|
||||
async fn unload_model(&mut self);
|
||||
|
||||
/// Runs inference on a single processed signal.
|
||||
async fn infer(&self, signal: &ProcessedSignal) -> Result<PoseEstimate, InferenceError>;
|
||||
|
||||
/// Runs inference on a batch of processed signals.
|
||||
async fn infer_batch(
|
||||
&self,
|
||||
signals: &[ProcessedSignal],
|
||||
) -> Result<Vec<PoseEstimate>, InferenceError>;
|
||||
|
||||
/// Warms up the model by running a dummy inference.
|
||||
async fn warmup(&mut self) -> Result<(), InferenceError>;
|
||||
|
||||
/// Returns performance statistics.
|
||||
fn stats(&self) -> InferenceStats;
|
||||
}
|
||||
|
||||
/// Async version of [`DataStore`].
|
||||
#[cfg(feature = "async")]
|
||||
#[async_trait]
|
||||
pub trait AsyncDataStore: Send + Sync {
|
||||
/// Returns `true` if the store is connected and ready.
|
||||
fn is_connected(&self) -> bool;
|
||||
|
||||
/// Stores a CSI frame.
|
||||
async fn store_csi_frame(&self, frame: &CsiFrame) -> Result<(), StorageError>;
|
||||
|
||||
/// Retrieves a CSI frame by ID.
|
||||
async fn get_csi_frame(&self, id: &FrameId) -> Result<CsiFrame, StorageError>;
|
||||
|
||||
/// Retrieves CSI frames matching the query options.
|
||||
async fn query_csi_frames(&self, options: &QueryOptions) -> Result<Vec<CsiFrame>, StorageError>;
|
||||
|
||||
/// Stores a pose estimate.
|
||||
async fn store_pose_estimate(&self, estimate: &PoseEstimate) -> Result<(), StorageError>;
|
||||
|
||||
/// Retrieves a pose estimate by ID.
|
||||
async fn get_pose_estimate(&self, id: &FrameId) -> Result<PoseEstimate, StorageError>;
|
||||
|
||||
/// Retrieves pose estimates matching the query options.
|
||||
async fn query_pose_estimates(
|
||||
&self,
|
||||
options: &QueryOptions,
|
||||
) -> Result<Vec<PoseEstimate>, StorageError>;
|
||||
|
||||
/// Retrieves the N most recent pose estimates.
|
||||
async fn get_recent_estimates(&self, count: usize) -> Result<Vec<PoseEstimate>, StorageError>;
|
||||
|
||||
/// Deletes CSI frames older than the given timestamp.
|
||||
async fn delete_csi_frames_before(&self, timestamp: &Timestamp) -> Result<u64, StorageError>;
|
||||
|
||||
/// Deletes pose estimates older than the given timestamp.
|
||||
async fn delete_pose_estimates_before(
|
||||
&self,
|
||||
timestamp: &Timestamp,
|
||||
) -> Result<u64, StorageError>;
|
||||
|
||||
/// Returns storage statistics.
|
||||
fn stats(&self) -> StorageStats;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Extension Traits
|
||||
// =============================================================================
|
||||
|
||||
/// Extension trait for pipeline composition.
|
||||
pub trait Pipeline: Send + Sync {
|
||||
/// The input type for this pipeline stage.
|
||||
type Input;
|
||||
/// The output type for this pipeline stage.
|
||||
type Output;
|
||||
/// The error type for this pipeline stage.
|
||||
type Error;
|
||||
|
||||
/// Processes input and produces output.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if processing fails.
|
||||
fn process(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
|
||||
}
|
||||
|
||||
/// Trait for types that can validate themselves.
|
||||
pub trait Validate {
|
||||
/// Validates the instance.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error describing validation failures.
|
||||
fn validate(&self) -> CoreResult<()>;
|
||||
}
|
||||
|
||||
/// Trait for types that can be reset to a default state.
|
||||
pub trait Resettable {
|
||||
/// Resets the instance to its initial state.
|
||||
fn reset(&mut self);
|
||||
}
|
||||
|
||||
/// Trait for types that track health status.
|
||||
pub trait HealthCheck {
|
||||
/// Health status of the component.
|
||||
type Status;
|
||||
|
||||
/// Performs a health check and returns the current status.
|
||||
fn health_check(&self) -> Self::Status;
|
||||
|
||||
/// Returns `true` if the component is healthy.
|
||||
fn is_healthy(&self) -> bool;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_signal_processor_config_default() {
|
||||
let config = SignalProcessorConfig::default();
|
||||
assert_eq!(config.buffer_size, 64);
|
||||
assert!(config.apply_noise_filter);
|
||||
assert!(config.sample_rate_hz > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inference_config_default() {
|
||||
let config = InferenceConfig::default();
|
||||
assert_eq!(config.device, InferenceDevice::Cpu);
|
||||
assert!(config.confidence_threshold > 0.0);
|
||||
assert!(config.max_batch_size > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_options_default() {
|
||||
let options = QueryOptions::default();
|
||||
assert!(options.limit.is_none());
|
||||
assert!(options.offset.is_none());
|
||||
assert_eq!(options.sort_order, SortOrder::Ascending);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inference_device_variants() {
|
||||
let cpu = InferenceDevice::Cpu;
|
||||
let cuda = InferenceDevice::Cuda { device_id: 0 };
|
||||
let tensorrt = InferenceDevice::TensorRt { device_id: 1 };
|
||||
|
||||
assert_eq!(cpu, InferenceDevice::Cpu);
|
||||
assert!(matches!(cuda, InferenceDevice::Cuda { device_id: 0 }));
|
||||
assert!(matches!(tensorrt, InferenceDevice::TensorRt { device_id: 1 }));
|
||||
}
|
||||
}
|
||||
1100
rust-port/wifi-densepose-rs/crates/wifi-densepose-core/src/types.rs
Normal file
1100
rust-port/wifi-densepose-rs/crates/wifi-densepose-core/src/types.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,250 @@
|
||||
//! Common utility functions for the WiFi-DensePose system.
|
||||
//!
|
||||
//! This module provides helper functions used throughout the crate.
|
||||
|
||||
use ndarray::{Array1, Array2};
|
||||
use num_complex::Complex64;
|
||||
|
||||
/// Computes the magnitude (absolute value) of complex numbers.
|
||||
#[must_use]
|
||||
pub fn complex_magnitude(data: &Array2<Complex64>) -> Array2<f64> {
|
||||
data.mapv(num_complex::Complex::norm)
|
||||
}
|
||||
|
||||
/// Computes the phase (argument) of complex numbers in radians.
|
||||
#[must_use]
|
||||
pub fn complex_phase(data: &Array2<Complex64>) -> Array2<f64> {
|
||||
data.mapv(num_complex::Complex::arg)
|
||||
}
|
||||
|
||||
/// Unwraps phase values to remove discontinuities.
|
||||
///
|
||||
/// Phase unwrapping corrects for the 2*pi jumps that occur when phase
|
||||
/// values wrap around from pi to -pi.
|
||||
#[must_use]
|
||||
pub fn unwrap_phase(phase: &Array1<f64>) -> Array1<f64> {
|
||||
let mut unwrapped = phase.clone();
|
||||
let pi = std::f64::consts::PI;
|
||||
let two_pi = 2.0 * pi;
|
||||
|
||||
for i in 1..unwrapped.len() {
|
||||
let diff = unwrapped[i] - unwrapped[i - 1];
|
||||
if diff > pi {
|
||||
for j in i..unwrapped.len() {
|
||||
unwrapped[j] -= two_pi;
|
||||
}
|
||||
} else if diff < -pi {
|
||||
for j in i..unwrapped.len() {
|
||||
unwrapped[j] += two_pi;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unwrapped
|
||||
}
|
||||
|
||||
/// Normalizes values to the range [0, 1].
|
||||
#[must_use]
|
||||
pub fn normalize_min_max(data: &Array1<f64>) -> Array1<f64> {
|
||||
let min = data.iter().copied().fold(f64::INFINITY, f64::min);
|
||||
let max = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
|
||||
|
||||
if (max - min).abs() < f64::EPSILON {
|
||||
return Array1::zeros(data.len());
|
||||
}
|
||||
|
||||
data.mapv(|x| (x - min) / (max - min))
|
||||
}
|
||||
|
||||
/// Normalizes values using z-score normalization.
|
||||
#[must_use]
|
||||
pub fn normalize_zscore(data: &Array1<f64>) -> Array1<f64> {
|
||||
let mean = data.mean().unwrap_or(0.0);
|
||||
let std = data.std(0.0);
|
||||
|
||||
if std.abs() < f64::EPSILON {
|
||||
return Array1::zeros(data.len());
|
||||
}
|
||||
|
||||
data.mapv(|x| (x - mean) / std)
|
||||
}
|
||||
|
||||
/// Calculates the Signal-to-Noise Ratio in dB.
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub fn calculate_snr_db(signal: &Array1<f64>, noise: &Array1<f64>) -> f64 {
|
||||
let signal_power: f64 = signal.iter().map(|x| x * x).sum::<f64>() / signal.len() as f64;
|
||||
let noise_power: f64 = noise.iter().map(|x| x * x).sum::<f64>() / noise.len() as f64;
|
||||
|
||||
if noise_power.abs() < f64::EPSILON {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
|
||||
10.0 * (signal_power / noise_power).log10()
|
||||
}
|
||||
|
||||
/// Applies a moving average filter.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the data array is not contiguous in memory.
|
||||
#[must_use]
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
pub fn moving_average(data: &Array1<f64>, window_size: usize) -> Array1<f64> {
|
||||
if window_size == 0 || window_size > data.len() {
|
||||
return data.clone();
|
||||
}
|
||||
|
||||
let mut result = Array1::zeros(data.len());
|
||||
let half_window = window_size / 2;
|
||||
|
||||
// Safe unwrap: ndarray Array1 is always contiguous
|
||||
let slice = data.as_slice().expect("Array1 should be contiguous");
|
||||
|
||||
for i in 0..data.len() {
|
||||
let start = i.saturating_sub(half_window);
|
||||
let end = (i + half_window + 1).min(data.len());
|
||||
let window = &slice[start..end];
|
||||
result[i] = window.iter().sum::<f64>() / window.len() as f64;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Clamps a value to a range.
|
||||
#[must_use]
|
||||
pub fn clamp<T: PartialOrd>(value: T, min: T, max: T) -> T {
|
||||
if value < min {
|
||||
min
|
||||
} else if value > max {
|
||||
max
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
/// Linearly interpolates between two values.
|
||||
#[must_use]
|
||||
pub fn lerp(a: f64, b: f64, t: f64) -> f64 {
|
||||
(b - a).mul_add(t, a)
|
||||
}
|
||||
|
||||
/// Converts degrees to radians.
|
||||
#[must_use]
|
||||
pub fn deg_to_rad(degrees: f64) -> f64 {
|
||||
degrees.to_radians()
|
||||
}
|
||||
|
||||
/// Converts radians to degrees.
|
||||
#[must_use]
|
||||
pub fn rad_to_deg(radians: f64) -> f64 {
|
||||
radians.to_degrees()
|
||||
}
|
||||
|
||||
/// Calculates the Euclidean distance between two points.
|
||||
#[must_use]
|
||||
pub fn euclidean_distance(p1: (f64, f64), p2: (f64, f64)) -> f64 {
|
||||
let dx = p2.0 - p1.0;
|
||||
let dy = p2.1 - p1.1;
|
||||
dx.hypot(dy)
|
||||
}
|
||||
|
||||
/// Calculates the Euclidean distance in 3D.
|
||||
#[must_use]
|
||||
pub fn euclidean_distance_3d(p1: (f64, f64, f64), p2: (f64, f64, f64)) -> f64 {
|
||||
let dx = p2.0 - p1.0;
|
||||
let dy = p2.1 - p1.1;
|
||||
let dz = p2.2 - p1.2;
|
||||
(dx.mul_add(dx, dy.mul_add(dy, dz * dz))).sqrt()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::array;
|
||||
|
||||
#[test]
|
||||
fn test_normalize_min_max() {
|
||||
let data = array![0.0, 5.0, 10.0];
|
||||
let normalized = normalize_min_max(&data);
|
||||
|
||||
assert!((normalized[0] - 0.0).abs() < 1e-10);
|
||||
assert!((normalized[1] - 0.5).abs() < 1e-10);
|
||||
assert!((normalized[2] - 1.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_zscore() {
|
||||
let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let normalized = normalize_zscore(&data);
|
||||
|
||||
// Mean should be approximately 0
|
||||
assert!(normalized.mean().unwrap().abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_moving_average() {
|
||||
let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let smoothed = moving_average(&data, 3);
|
||||
|
||||
// Middle value should be average of 2, 3, 4
|
||||
assert!((smoothed[2] - 3.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clamp() {
|
||||
assert_eq!(clamp(5, 0, 10), 5);
|
||||
assert_eq!(clamp(-5, 0, 10), 0);
|
||||
assert_eq!(clamp(15, 0, 10), 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lerp() {
|
||||
assert!((lerp(0.0, 10.0, 0.5) - 5.0).abs() < 1e-10);
|
||||
assert!((lerp(0.0, 10.0, 0.0) - 0.0).abs() < 1e-10);
|
||||
assert!((lerp(0.0, 10.0, 1.0) - 10.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deg_rad_conversion() {
|
||||
let degrees = 180.0;
|
||||
let radians = deg_to_rad(degrees);
|
||||
assert!((radians - std::f64::consts::PI).abs() < 1e-10);
|
||||
|
||||
let back = rad_to_deg(radians);
|
||||
assert!((back - degrees).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_euclidean_distance() {
|
||||
let dist = euclidean_distance((0.0, 0.0), (3.0, 4.0));
|
||||
assert!((dist - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unwrap_phase() {
|
||||
let pi = std::f64::consts::PI;
|
||||
// Simulate a phase wrap
|
||||
let phase = array![0.0, pi / 2.0, pi, -pi + 0.1, -pi / 2.0];
|
||||
let unwrapped = unwrap_phase(&phase);
|
||||
|
||||
// After unwrapping, the phase should be monotonically increasing
|
||||
for i in 1..unwrapped.len() {
|
||||
// Allow some tolerance for the discontinuity correction
|
||||
assert!(
|
||||
unwrapped[i] >= unwrapped[i - 1] - 0.5,
|
||||
"Phase should be mostly increasing after unwrapping"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_snr_calculation() {
|
||||
let signal = array![1.0, 1.0, 1.0, 1.0];
|
||||
let noise = array![0.1, 0.1, 0.1, 0.1];
|
||||
|
||||
let snr = calculate_snr_db(&signal, &noise);
|
||||
// SNR should be 20 dB (10 * log10(1/0.01) = 10 * log10(100) = 20)
|
||||
assert!((snr - 20.0).abs() < 1e-10);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user