Add batch processing methods for CSI data in CSIProcessor and PhaseSanitizer

This commit is contained in:
rUv
2025-06-07 06:01:40 +00:00
parent cbebdd648f
commit 43e92c5494
2 changed files with 63 additions and 1 deletions

View File

@@ -45,3 +45,35 @@ class CSIProcessor:
processed = (processed - processed.mean()) / processed.std() processed = (processed - processed.mean()) / processed.std()
return processed return processed
def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor:
"""Process a batch of CSI data for neural network input.
Args:
csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time)
Returns:
Processed CSI tensor ready for neural network input
"""
if csi_data.ndim != 4:
raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D")
batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape
# Extract amplitude and phase
amplitude = np.abs(csi_data)
phase = np.angle(csi_data)
# Process each component
processed_amplitude = self.process_raw_csi(amplitude)
processed_phase = self.process_raw_csi(phase)
# Stack amplitude and phase as separate channels
processed_data = np.stack([processed_amplitude, processed_phase], axis=1)
# Reshape to (batch, channels, antennas, subcarriers, time)
# Then flatten spatial dimensions for CNN input
processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples)
# Convert to tensor
return torch.from_numpy(processed_data).float()

View File

@@ -1,6 +1,7 @@
"""Phase sanitizer for WiFi-DensePose CSI phase data processing.""" """Phase sanitizer for WiFi-DensePose CSI phase data processing."""
import numpy as np import numpy as np
import torch
from typing import Optional from typing import Optional
from scipy import signal from scipy import signal
@@ -60,6 +61,35 @@ class PhaseSanitizer:
return result return result
def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor:
"""Sanitize phase information in a batch of processed CSI data.
Args:
processed_csi: Processed CSI tensor from CSI processor
Returns:
CSI tensor with sanitized phase information
"""
if not isinstance(processed_csi, torch.Tensor):
raise ValueError("Input must be a torch.Tensor")
# Convert to numpy for processing
csi_numpy = processed_csi.detach().cpu().numpy()
# The processed CSI has shape (batch, channels, subcarriers, time)
# where channels = 2 * antennas (amplitude and phase interleaved)
batch_size, channels, subcarriers, time_samples = csi_numpy.shape
# Process phase channels (odd indices contain phase information)
for batch_idx in range(batch_size):
for ch_idx in range(1, channels, 2): # Phase channels are at odd indices
phase_data = csi_numpy[batch_idx, ch_idx, :, :]
sanitized_phase = self.sanitize(phase_data)
csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase
# Convert back to tensor
return torch.from_numpy(csi_numpy).float()
def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray: def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Apply smoothing filter to reduce noise in phase data. """Apply smoothing filter to reduce noise in phase data.