Add batch processing methods for CSI data in CSIProcessor and PhaseSanitizer
This commit is contained in:
@@ -44,4 +44,36 @@ class CSIProcessor:
|
|||||||
if processed.std() > 0:
|
if processed.std() > 0:
|
||||||
processed = (processed - processed.mean()) / processed.std()
|
processed = (processed - processed.mean()) / processed.std()
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor:
|
||||||
|
"""Process a batch of CSI data for neural network input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed CSI tensor ready for neural network input
|
||||||
|
"""
|
||||||
|
if csi_data.ndim != 4:
|
||||||
|
raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D")
|
||||||
|
|
||||||
|
batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape
|
||||||
|
|
||||||
|
# Extract amplitude and phase
|
||||||
|
amplitude = np.abs(csi_data)
|
||||||
|
phase = np.angle(csi_data)
|
||||||
|
|
||||||
|
# Process each component
|
||||||
|
processed_amplitude = self.process_raw_csi(amplitude)
|
||||||
|
processed_phase = self.process_raw_csi(phase)
|
||||||
|
|
||||||
|
# Stack amplitude and phase as separate channels
|
||||||
|
processed_data = np.stack([processed_amplitude, processed_phase], axis=1)
|
||||||
|
|
||||||
|
# Reshape to (batch, channels, antennas, subcarriers, time)
|
||||||
|
# Then flatten spatial dimensions for CNN input
|
||||||
|
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
|
||||||
|
|
||||||
|
# Convert to tensor
|
||||||
|
return torch.from_numpy(processed_data).float()
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Phase sanitizer for WiFi-DensePose CSI phase data processing."""
|
"""Phase sanitizer for WiFi-DensePose CSI phase data processing."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
|
||||||
@@ -60,6 +61,35 @@ class PhaseSanitizer:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Sanitize phase information in a batch of processed CSI data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processed_csi: Processed CSI tensor from CSI processor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSI tensor with sanitized phase information
|
||||||
|
"""
|
||||||
|
if not isinstance(processed_csi, torch.Tensor):
|
||||||
|
raise ValueError("Input must be a torch.Tensor")
|
||||||
|
|
||||||
|
# Convert to numpy for processing
|
||||||
|
csi_numpy = processed_csi.detach().cpu().numpy()
|
||||||
|
|
||||||
|
# The processed CSI has shape (batch, channels, subcarriers, time)
|
||||||
|
# where channels = 2 * antennas (amplitude and phase interleaved)
|
||||||
|
batch_size, channels, subcarriers, time_samples = csi_numpy.shape
|
||||||
|
|
||||||
|
# Process phase channels (odd indices contain phase information)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
for ch_idx in range(1, channels, 2): # Phase channels are at odd indices
|
||||||
|
phase_data = csi_numpy[batch_idx, ch_idx, :, :]
|
||||||
|
sanitized_phase = self.sanitize(phase_data)
|
||||||
|
csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase
|
||||||
|
|
||||||
|
# Convert back to tensor
|
||||||
|
return torch.from_numpy(csi_numpy).float()
|
||||||
|
|
||||||
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
|
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
|
||||||
"""Apply smoothing filter to reduce noise in phase data.
|
"""Apply smoothing filter to reduce noise in phase data.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user