diff --git a/src/core/csi_processor.py b/src/core/csi_processor.py index b610481..6c21b65 100644 --- a/src/core/csi_processor.py +++ b/src/core/csi_processor.py @@ -44,4 +44,36 @@ class CSIProcessor: if processed.std() > 0: processed = (processed - processed.mean()) / processed.std() - return processed \ No newline at end of file + return processed + + def process_csi_batch(self, csi_data: np.ndarray) -> torch.Tensor: + """Process a batch of CSI data for neural network input. + + Args: + csi_data: Complex CSI data array of shape (batch, antennas, subcarriers, time) + + Returns: + Processed CSI tensor ready for neural network input + """ + if csi_data.ndim != 4: + raise ValueError(f"Expected 4D input (batch, antennas, subcarriers, time), got {csi_data.ndim}D") + + batch_size, num_antennas, num_subcarriers, time_samples = csi_data.shape + + # Extract amplitude and phase + amplitude = np.abs(csi_data) + phase = np.angle(csi_data) + + # Process each component + processed_amplitude = self.process_raw_csi(amplitude) + processed_phase = self.process_raw_csi(phase) + + # Stack amplitude and phase as separate channels + processed_data = np.stack([processed_amplitude, processed_phase], axis=1) + + # Reshape to (batch, channels, antennas, subcarriers, time) + # Then flatten spatial dimensions for CNN input + processed_data = processed_data.reshape(batch_size, 2 * num_antennas, num_subcarriers, time_samples) + + # Convert to tensor + return torch.from_numpy(processed_data).float() \ No newline at end of file diff --git a/src/core/phase_sanitizer.py b/src/core/phase_sanitizer.py index cfdbc28..fd5371f 100644 --- a/src/core/phase_sanitizer.py +++ b/src/core/phase_sanitizer.py @@ -1,6 +1,7 @@ """Phase sanitizer for WiFi-DensePose CSI phase data processing.""" import numpy as np +import torch from typing import Optional from scipy import signal @@ -60,6 +61,35 @@ class PhaseSanitizer: return result + def sanitize_phase_batch(self, processed_csi: torch.Tensor) -> torch.Tensor: + """Sanitize phase information in a batch of processed CSI data. + + Args: + processed_csi: Processed CSI tensor from CSI processor + + Returns: + CSI tensor with sanitized phase information + """ + if not isinstance(processed_csi, torch.Tensor): + raise ValueError("Input must be a torch.Tensor") + + # Convert to numpy for processing + csi_numpy = processed_csi.detach().cpu().numpy() + + # The processed CSI has shape (batch, channels, subcarriers, time) + # where channels = 2 * antennas (amplitude and phase interleaved) + batch_size, channels, subcarriers, time_samples = csi_numpy.shape + + # Process phase channels (odd indices contain phase information) + for batch_idx in range(batch_size): + for ch_idx in range(1, channels, 2): # Phase channels are at odd indices + phase_data = csi_numpy[batch_idx, ch_idx, :, :] + sanitized_phase = self.sanitize(phase_data) + csi_numpy[batch_idx, ch_idx, :, :] = sanitized_phase + + # Convert back to tensor + return torch.from_numpy(csi_numpy).float() + def smooth_phase(self, phase_data: np.ndarray) -> np.ndarray: """Apply smoothing filter to reduce noise in phase data.