Add batch processing methods for CSI data in CSIProcessor and PhaseSanitizer

This commit is contained in:
rUv
2025-06-07 06:01:40 +00:00
parent cbebdd648f
commit 43e92c5494
2 changed files with 63 additions and 1 deletions

View File

@@ -44,4 +44,36 @@ class CSIProcessor:
if processed.std() > 0:
processed = (processed - processed.mean()) / processed.std()
return processed
return processed
def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor:
"""Process a batch of CSI data for neural network input.
Args:
csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time)
Returns:
Processed CSI tensor ready for neural network input
"""
if csi_data.ndim != 4:
raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D")
batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape
# Extract amplitude and phase
amplitude = np.abs(csi_data)
phase = np.angle(csi_data)
# Process each component
processed_amplitude = self.process_raw_csi(amplitude)
processed_phase = self.process_raw_csi(phase)
# Stack amplitude and phase as separate channels
processed_data = np.stack([processed_amplitude, processed_phase], axis=1)
# Reshape to (batch, channels, antennas, subcarriers, time)
# Then flatten spatial dimensions for CNN input
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
# Convert to tensor
return torch.from_numpy(processed_data).float()