Add batch processing methods for CSI data in CSIProcessor and PhaseSanitizer
This commit is contained in:
@@ -44,4 +44,36 @@ class CSIProcessor:
|
||||
if processed.std() > 0:
|
||||
processed = (processed - processed.mean()) / processed.std()
|
||||
|
||||
return processed
|
||||
return processed
|
||||
|
||||
def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor:
|
||||
"""Process a batch of CSI data for neural network input.
|
||||
|
||||
Args:
|
||||
csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time)
|
||||
|
||||
Returns:
|
||||
Processed CSI tensor ready for neural network input
|
||||
"""
|
||||
if csi_data.ndim != 4:
|
||||
raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D")
|
||||
|
||||
batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape
|
||||
|
||||
# Extract amplitude and phase
|
||||
amplitude = np.abs(csi_data)
|
||||
phase = np.angle(csi_data)
|
||||
|
||||
# Process each component
|
||||
processed_amplitude = self.process_raw_csi(amplitude)
|
||||
processed_phase = self.process_raw_csi(phase)
|
||||
|
||||
# Stack amplitude and phase as separate channels
|
||||
processed_data = np.stack([processed_amplitude, processed_phase], axis=1)
|
||||
|
||||
# Reshape to (batch, channels, antennas, subcarriers, time)
|
||||
# Then flatten spatial dimensions for CNN input
|
||||
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
|
||||
|
||||
# Convert to tensor
|
||||
return torch.from_numpy(processed_data).float()
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Phase sanitizer for WiFi-DensePose CSI phase data processing."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Optional
|
||||
from scipy import signal
|
||||
|
||||
@@ -60,6 +61,35 @@ class PhaseSanitizer:
|
||||
|
||||
return result
|
||||
|
||||
def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor:
|
||||
"""Sanitize phase information in a batch of processed CSI data.
|
||||
|
||||
Args:
|
||||
processed_csi: Processed CSI tensor from CSI processor
|
||||
|
||||
Returns:
|
||||
CSI tensor with sanitized phase information
|
||||
"""
|
||||
if not isinstance(processed_csi, torch.Tensor):
|
||||
raise ValueError("Input must be a torch.Tensor")
|
||||
|
||||
# Convert to numpy for processing
|
||||
csi_numpy = processed_csi.detach().cpu().numpy()
|
||||
|
||||
# The processed CSI has shape (batch, channels, subcarriers, time)
|
||||
# where channels = 2 * antennas (amplitude and phase interleaved)
|
||||
batch_size, channels, subcarriers, time_samples = csi_numpy.shape
|
||||
|
||||
# Process phase channels (odd indices contain phase information)
|
||||
for batch_idx in range(batch_size):
|
||||
for ch_idx in range(1, channels, 2): # Phase channels are at odd indices
|
||||
phase_data = csi_numpy[batch_idx, ch_idx, :, :]
|
||||
sanitized_phase = self.sanitize(phase_data)
|
||||
csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase
|
||||
|
||||
# Convert back to tensor
|
||||
return torch.from_numpy(csi_numpy).float()
|
||||
|
||||
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
|
||||
"""Apply smoothing filter to reduce noise in phase data.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user