Add batch processing methods for CSI data in CSIProcessor and PhaseSanitizer

This commit is contained in:
rUv
2025-06-07 06:01:40 +00:00
parent cbebdd648f
commit 43e92c5494
2 changed files with 63 additions and 1 deletions

View File

@@ -44,4 +44,36 @@ class CSIProcessor:
if processed.std() > 0:
processed = (processed - processed.mean()) / processed.std()
return processed
return processed
def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor:
"""Process a batch of CSI data for neural network input.
Args:
csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time)
Returns:
Processed CSI tensor ready for neural network input
"""
if csi_data.ndim != 4:
raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D")
batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape
# Extract amplitude and phase
amplitude = np.abs(csi_data)
phase = np.angle(csi_data)
# Process each component
processed_amplitude = self.process_raw_csi(amplitude)
processed_phase = self.process_raw_csi(phase)
# Stack amplitude and phase as separate channels
processed_data = np.stack([processed_amplitude, processed_phase], axis=1)
# Reshape to (batch, channels, antennas, subcarriers, time)
# Then flatten spatial dimensions for CNN input
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
# Convert to tensor
return torch.from_numpy(processed_data).float()

View File

@@ -1,6 +1,7 @@
"""Phase sanitizer for WiFi-DensePose CSI phase data processing."""
import numpy as np
import torch
from typing import Optional
from scipy import signal
@@ -60,6 +61,35 @@ class PhaseSanitizer:
return result
def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor:
"""Sanitize phase information in a batch of processed CSI data.
Args:
processed_csi: Processed CSI tensor from CSI processor
Returns:
CSI tensor with sanitized phase information
"""
if not isinstance(processed_csi, torch.Tensor):
raise ValueError("Input must be a torch.Tensor")
# Convert to numpy for processing
csi_numpy = processed_csi.detach().cpu().numpy()
# The processed CSI has shape (batch, channels, subcarriers, time)
# where channels = 2 * antennas (amplitude and phase interleaved)
batch_size, channels, subcarriers, time_samples = csi_numpy.shape
# Process phase channels (odd indices contain phase information)
for batch_idx in range(batch_size):
for ch_idx in range(1, channels, 2): # Phase channels are at odd indices
phase_data = csi_numpy[batch_idx, ch_idx, :, :]
sanitized_phase = self.sanitize(phase_data)
csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase
# Convert back to tensor
return torch.from_numpy(csi_numpy).float()
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Apply smoothing filter to reduce noise in phase data.