- Implemented the WiFi DensePose model in PyTorch, including CSI phase processing, modality translation, and DensePose prediction heads.
- Added a comprehensive training utility for the model, including loss functions and training steps.
- Created a CSV file to document hardware specifications, architecture details, training parameters, performance metrics, and advantages of the model.
489 lines
18 KiB
Python
489 lines
18 KiB
Python
# WiFi DensePose Implementation in PyTorch
|
|
# Based on "DensePose From WiFi" by Carnegie Mellon University
|
|
# Paper: https://arxiv.org/pdf/2301.00250
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import math
|
|
from typing import Dict, List, Tuple, Optional
|
|
from collections import OrderedDict
|
|
|
|
class CSIPhaseProcessor:
    """
    Sanitizes raw CSI phase data through unwrapping, smoothing and linear fitting.

    Based on the phase sanitization methodology from the paper.

    The frequency axis (dim 1) is treated as consecutive groups of
    ``num_subcarriers`` entries (with the defaults: 5 samples x 30
    sub-carriers = 150 entries). Unwrapping and linear fitting are applied
    independently inside each group; a shorter trailing group, if any, is
    processed as its own group.
    """

    def __init__(self, num_subcarriers: int = 30):
        """
        Args:
            num_subcarriers: Number of frequency entries that form one group.

        Raises:
            ValueError: If ``num_subcarriers`` is not positive.
        """
        if num_subcarriers < 1:
            raise ValueError("num_subcarriers must be a positive integer")
        self.num_subcarriers = num_subcarriers

    def _group_bounds(self, num_freq: int) -> List[Tuple[int, int]]:
        """Return the (start, end) index pairs that split dim 1 into groups."""
        step = self.num_subcarriers
        return [(start, min(start + step, num_freq))
                for start in range(0, num_freq, step)]

    def unwrap_phase(self, phase_data: torch.Tensor) -> torch.Tensor:
        """
        Unwrap phase values to remove 2*pi discontinuities.

        Consecutive entries of a group that differ by more than pi in
        magnitude are shifted by the multiple of 2*pi that brings the step
        back into [-pi, pi]. The shift accumulates along the group, so ramps
        spanning several full cycles are recovered as well.

        Args:
            phase_data: Raw phase data of shape (batch, freq_samples, ...),
                e.g. (batch, freq_samples, tx, rx).

        Returns:
            Unwrapped phase data with the same shape as the input. The input
            tensor is not modified.
        """
        unwrapped = phase_data.clone()
        two_pi = 2 * math.pi

        for start, end in self._group_bounds(phase_data.shape[1]):
            if end - start < 2:
                continue  # nothing to unwrap in a single-entry group

            segment = phase_data[:, start:end]
            steps = segment[:, 1:] - segment[:, :-1]

            # Whole number of 2*pi turns contained in each step; this is zero
            # whenever |step| <= pi, so untouched entries stay bit-identical.
            turns = torch.round(steps / two_pi)

            # A jump at position i offsets every later entry of the group.
            correction = torch.cumsum(-two_pi * turns, dim=1)
            unwrapped[:, start + 1:end] = segment[:, 1:] + correction

        return unwrapped

    def apply_filters(self, phase_data: torch.Tensor) -> torch.Tensor:
        """
        Smooth the phase with a 3-tap moving average along the frequency axis.

        The first and last entries of dim 1 are returned unchanged. The
        window runs over the whole axis, i.e. it also spans the boundaries
        between sub-carrier groups.

        Args:
            phase_data: Phase data of shape (batch, freq_samples, ...).

        Returns:
            Smoothed copy of ``phase_data``.
        """
        filtered = phase_data.clone()
        if phase_data.shape[1] >= 3:
            filtered[:, 1:-1] = (phase_data[:, :-2]
                                 + phase_data[:, 1:-1]
                                 + phase_data[:, 2:]) / 3
        return filtered

    def linear_fitting(self, phase_data: torch.Tensor) -> torch.Tensor:
        """
        Apply linear fitting to remove systematic phase drift.

        For every group the trend ``alpha1 * f + alpha0`` is subtracted,
        where ``f`` runs from 1 to the group length,
        ``alpha1 = (last - first) / (2 * pi * num_subcarriers)`` and
        ``alpha0`` is the mean of the group.

        Args:
            phase_data: Phase data of shape (batch, freq_samples, ...).

        Returns:
            Detrended copy of ``phase_data``.
        """
        fitted_data = phase_data.clone()
        slope_scale = 2 * math.pi * self.num_subcarriers
        # Broadcast shape suffix so the frequency index lines up with dim 1.
        trailing_ones = (1,) * (phase_data.dim() - 2)

        for start, end in self._group_bounds(phase_data.shape[1]):
            length = end - start
            if length < 2:
                continue  # a single entry has no slope to remove

            segment = phase_data[:, start:end]

            # Linear coefficients, kept with a singleton dim 1 for broadcasting
            alpha1 = (segment[:, -1:] - segment[:, :1]) / slope_scale
            alpha0 = segment.mean(dim=1, keepdim=True)

            frequencies = torch.arange(
                1, length + 1, dtype=phase_data.dtype, device=phase_data.device
            ).view(1, length, *trailing_ones)

            fitted_data[:, start:end] = segment - (alpha1 * frequencies + alpha0)

        return fitted_data

    def sanitize_phase(self, raw_phase: torch.Tensor) -> torch.Tensor:
        """
        Complete phase sanitization pipeline: unwrap, smooth, then detrend.

        Args:
            raw_phase: Raw phase data of shape (batch, freq_samples, ...).

        Returns:
            Sanitized phase data with the same shape as the input.
        """
        unwrapped = self.unwrap_phase(raw_phase)
        filtered = self.apply_filters(unwrapped)
        return self.linear_fitting(filtered)
|
|
|
|
class ModalityTranslationNetwork(nn.Module):
    """
    Translates CSI domain features to spatial domain features.

    Input: amplitude and phase tensors whose per-sample size is ``input_dim``
    (150x3x3 = 1350 with the defaults).
    Output: feature map of shape (batch, 3, output_height, output_width)
    (3x720x1280 with the defaults).
    """

    def __init__(self, input_dim: int = 1350, hidden_dim: int = 512, output_height: int = 720, output_width: int = 1280):
        """
        Args:
            input_dim: Number of values per sample after flattening.
            hidden_dim: Width of the first encoder layer; later layers use
                ``hidden_dim // 2`` and ``hidden_dim // 4``.
            output_height: Height of the produced feature map.
            output_width: Width of the produced feature map.
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_height = output_height
        self.output_width = output_width

        encoded_dim = hidden_dim // 4

        # Amplitude and phase use identical, independently-weighted encoders
        self.amplitude_encoder = self._make_encoder(input_dim, hidden_dim)
        self.phase_encoder = self._make_encoder(input_dim, hidden_dim)

        # Feature fusion. The input width is the size of the concatenated
        # encoder outputs; this equals hidden_dim // 2 only when hidden_dim
        # is a multiple of 4, so it is derived from encoded_dim instead.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(2 * encoded_dim, encoded_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(encoded_dim, 24 * 24),  # Reshaped to 24x24 in forward()
            nn.ReLU()
        )

        # Spatial processing
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((6, 6))  # Compress to 6x6
        )

        # Learned upsampling; each transposed convolution doubles the size
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 12x12
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 24x24
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 48x48
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),  # 96x96
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )

        # 1x1 projection to 3 channels; resizing to the target resolution
        # happens by interpolation in forward()
        self.final_conv = nn.Conv2d(8, 3, kernel_size=1)

    @staticmethod
    def _make_encoder(input_dim: int, hidden_dim: int) -> nn.Sequential:
        """Build the three-layer MLP encoder shared by both input modalities."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU()
        )

    def forward(self, amplitude_tensor: torch.Tensor, phase_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            amplitude_tensor: Amplitude data of shape (batch, ...) with
                ``input_dim`` values per sample.
            phase_tensor: Phase data with the same layout.

        Returns:
            Tensor of shape (batch, 3, output_height, output_width).
        """
        batch_size = amplitude_tensor.shape[0]

        # Flatten input tensors; flatten() also accepts non-contiguous input,
        # which view() would reject
        amplitude_flat = amplitude_tensor.flatten(start_dim=1)  # [B, input_dim]
        phase_flat = phase_tensor.flatten(start_dim=1)  # [B, input_dim]

        # Encode features
        amp_features = self.amplitude_encoder(amplitude_flat)  # [B, hidden_dim // 4]
        phase_features = self.phase_encoder(phase_flat)  # [B, hidden_dim // 4]

        # Fuse features
        fused_features = torch.cat([amp_features, phase_features], dim=1)
        spatial_features = self.fusion_mlp(fused_features)  # [B, 576]

        # Reshape to 2D feature map
        spatial_map = spatial_features.view(batch_size, 1, 24, 24)  # [B, 1, 24, 24]

        # Apply spatial convolutions
        conv_features = self.spatial_conv(spatial_map)  # [B, 128, 6, 6]

        # Upsample
        upsampled = self.upsample(conv_features)  # [B, 8, 96, 96]

        # Final convolution
        final_features = self.final_conv(upsampled)  # [B, 3, 96, 96]

        # Interpolate to target resolution
        output = F.interpolate(final_features, size=(self.output_height, self.output_width),
                               mode='bilinear', align_corners=False)

        return output
|
|
|
|
class DensePoseHead(nn.Module):
    """
    Prediction head that outputs body-part logits and per-part UV coordinates.
    """

    def __init__(self, input_channels=256, num_parts=24, output_size=(112, 112)):
        super().__init__()

        self.num_parts = num_parts
        self.output_size = output_size

        # Shared trunk: three 3x3 conv + ReLU stages, all 512 channels wide
        trunk_width = 512
        stages = []
        in_channels = input_channels
        for _ in range(3):
            stages.append(nn.Conv2d(in_channels, trunk_width, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            in_channels = trunk_width
        self.shared_conv = nn.Sequential(*stages)

        # Part classification branch; the extra channel is the background class
        self.part_classifier = nn.Conv2d(trunk_width, num_parts + 1, kernel_size=1)

        # One U and one V regression channel per body part
        self.u_regressor = nn.Conv2d(trunk_width, num_parts, kernel_size=1)
        self.v_regressor = nn.Conv2d(trunk_width, num_parts, kernel_size=1)

    def forward(self, x):
        """
        Run the trunk, resize to ``output_size`` and apply the three branches.

        Returns a dict with ``part_logits`` (num_parts + 1 channels) and the
        sigmoid-activated ``u_coords`` / ``v_coords`` (num_parts channels each).
        """
        # Shared features, bilinearly resized to the head's output resolution
        resized = F.interpolate(
            self.shared_conv(x),
            size=self.output_size,
            mode='bilinear',
            align_corners=False,
        )

        outputs = {}
        outputs['part_logits'] = self.part_classifier(resized)
        # Sigmoid keeps the UV predictions inside [0, 1]
        outputs['u_coords'] = torch.sigmoid(self.u_regressor(resized))
        outputs['v_coords'] = torch.sigmoid(self.v_regressor(resized))
        return outputs
|
|
|
|
class KeypointHead(nn.Module):
    """
    Prediction head that produces one heatmap per body keypoint.
    """

    def __init__(self, input_channels=256, num_keypoints=17, output_size=(56, 56)):
        super().__init__()

        self.num_keypoints = num_keypoints
        self.output_size = output_size

        # Three 3x3 conv + ReLU stages followed by a 1x1 projection that
        # yields one channel per keypoint
        width = 512
        stack = []
        in_channels = input_channels
        for _ in range(3):
            stack.extend([
                nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
                nn.ReLU(),
            ])
            in_channels = width
        stack.append(nn.Conv2d(width, num_keypoints, kernel_size=1))
        self.conv_layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return keypoint heatmaps bilinearly resized to ``output_size``."""
        raw_heatmaps = self.conv_layers(x)
        return F.interpolate(
            raw_heatmaps,
            size=self.output_size,
            mode='bilinear',
            align_corners=False,
        )
|
|
|
|
class WiFiDensePoseRCNN(nn.Module):
    """
    Complete WiFi-DensePose RCNN architecture.

    Pipeline: CSI phase sanitization -> modality translation -> convolutional
    backbone -> adaptive pooling -> DensePose and keypoint heads.
    """

    def __init__(self):
        super().__init__()

        # CSI processing (plain helper object, holds no parameters)
        self.phase_processor = CSIPhaseProcessor()

        # Modality translation
        self.modality_translation = ModalityTranslationNetwork()

        # Simplified backbone (in practice, use ResNet-FPN)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Simplified ResNet blocks
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        # Prediction heads
        self.densepose_head = DensePoseHead(input_channels=256)
        self.keypoint_head = KeypointHead(input_channels=256)

        # Adaptive average pooling to a fixed 7x7 grid so both heads always
        # receive the same spatial size
        self.global_pool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, amplitude_data, phase_data):
        """
        Args:
            amplitude_data: CSI amplitude tensor; its first dimension is the batch.
            phase_data: Raw CSI phase tensor with the same layout.

        Returns:
            Dict with:
                'spatial_features': output of the modality translation network,
                'backbone_features': un-pooled backbone activations (the
                    tensor the transfer-learning loss term looks up),
                'densepose': dict produced by the DensePose head,
                'keypoints': keypoint heatmaps.
        """
        # Process CSI phase data
        sanitized_phase = self.phase_processor.sanitize_phase(phase_data)

        # Translate to spatial domain
        spatial_features = self.modality_translation(amplitude_data, sanitized_phase)

        # Extract backbone features
        backbone_features = self.backbone(spatial_features)

        # Pool to a fixed-size grid for the heads
        pooled_features = self.global_pool(backbone_features)

        # Predict DensePose
        densepose_output = self.densepose_head(pooled_features)

        # Predict keypoints
        keypoint_heatmaps = self.keypoint_head(pooled_features)

        return {
            'spatial_features': spatial_features,
            # Exposed so a loss can compare them against teacher features;
            # without this key that comparison is silently skipped.
            'backbone_features': backbone_features,
            'densepose': densepose_output,
            'keypoints': keypoint_heatmaps
        }
|
|
|
|
class WiFiDensePoseLoss(nn.Module):
    """
    Combined loss function for WiFi DensePose training.

    total = lambda_dp * densepose + lambda_kp * keypoint + lambda_tr * transfer,
    where each term is only included when its inputs are present.
    """

    def __init__(self, lambda_dp=0.6, lambda_kp=0.3, lambda_tr=0.1):
        """
        Args:
            lambda_dp: Weight of the DensePose (part + UV) term.
            lambda_kp: Weight of the keypoint heatmap term.
            lambda_tr: Weight of the transfer-learning term.
        """
        super().__init__()

        self.lambda_dp = lambda_dp
        self.lambda_kp = lambda_kp
        self.lambda_tr = lambda_tr

        # Loss functions
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        self.smooth_l1 = nn.SmoothL1Loss()

    @staticmethod
    def _first_device(container):
        """Return the device of the first tensor found in a (nested) dict, else None."""
        if not isinstance(container, dict):
            return None
        for value in container.values():
            if torch.is_tensor(value):
                return value.device
            nested = WiFiDensePoseLoss._first_device(value)
            if nested is not None:
                return nested
        return None

    def forward(self, predictions, targets, teacher_features=None):
        """
        Args:
            predictions: Model output dict; may contain 'densepose' (with
                'part_logits', 'u_coords', 'v_coords'), 'keypoints' and
                'backbone_features'.
            targets: Ground-truth dict; may contain 'densepose' (with
                'part_labels', 'u_coords', 'v_coords') and 'keypoints'.
            teacher_features: Optional tensor compared against
                predictions['backbone_features'] with MSE.

        Returns:
            (total_loss, loss_dict). ``total_loss`` is always a tensor (a
            zero scalar when no term applies). ``loss_dict`` holds the
            unweighted terms under 'densepose', 'keypoint', 'transfer' plus
            the weighted sum under 'total'.
        """
        weighted_terms = []
        loss_dict = {}

        # DensePose losses
        if 'densepose' in predictions and 'densepose' in targets:
            dp_pred = predictions['densepose']
            dp_true = targets['densepose']

            # Part classification loss
            part_loss = self.cross_entropy(dp_pred['part_logits'], dp_true['part_labels'])

            # UV coordinate regression loss (mean of the U and V terms)
            uv_loss = (self.smooth_l1(dp_pred['u_coords'], dp_true['u_coords']) +
                       self.smooth_l1(dp_pred['v_coords'], dp_true['v_coords'])) / 2

            dp_loss = part_loss + uv_loss
            weighted_terms.append(self.lambda_dp * dp_loss)
            loss_dict['densepose'] = dp_loss

        # Keypoint loss
        if 'keypoints' in predictions and 'keypoints' in targets:
            kp_loss = self.mse_loss(predictions['keypoints'], targets['keypoints'])
            weighted_terms.append(self.lambda_kp * kp_loss)
            loss_dict['keypoint'] = kp_loss

        # Transfer learning loss
        if teacher_features is not None and 'backbone_features' in predictions:
            tr_loss = self.mse_loss(predictions['backbone_features'], teacher_features)
            weighted_terms.append(self.lambda_tr * tr_loss)
            loss_dict['transfer'] = tr_loss

        if weighted_terms:
            total_loss = weighted_terms[0]
            for term in weighted_terms[1:]:
                total_loss = total_loss + term
        else:
            # Keep the return type a tensor so callers can always use
            # .item() / tensor arithmetic, even when nothing was computed.
            device = self._first_device(predictions)
            if device is None:
                device = torch.device('cpu')
            total_loss = torch.zeros((), device=device)

        loss_dict['total'] = total_loss
        return total_loss, loss_dict
|
|
|
|
# Training utilities
|
|
class WiFiDensePoseTrainer:
    """
    Training utilities for WiFi DensePose: one optimisation step plus
    checkpoint save / load.
    """

    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            model: Module called as ``model(amplitude_data, phase_data)``.
            device: Device the model (and every training batch) is moved to.
        """
        self.model = model.to(device)
        self.device = device
        self.criterion = WiFiDensePoseLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # The scheduler is stepped once per train_step, so the milestones
        # count training iterations.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[48000, 96000], gamma=0.1
        )

    def _to_device(self, item):
        """Recursively move tensors inside dicts / lists / tuples to self.device."""
        if torch.is_tensor(item):
            return item.to(self.device)
        if isinstance(item, dict):
            return {key: self._to_device(value) for key, value in item.items()}
        if isinstance(item, (list, tuple)):
            return type(item)(self._to_device(value) for value in item)
        return item

    def train_step(self, amplitude_data, phase_data, targets):
        """
        Run one optimisation step.

        Inputs and targets are moved to ``self.device`` first, so batches
        created on the CPU also work when the model lives on a GPU.

        Returns:
            (loss_value, loss_dict): the total loss as a Python float and the
            per-term dict returned by the criterion.
        """
        self.model.train()
        self.optimizer.zero_grad()

        amplitude_data = self._to_device(amplitude_data)
        phase_data = self._to_device(phase_data)
        targets = self._to_device(targets)

        # Forward pass
        outputs = self.model(amplitude_data, phase_data)

        # Compute loss
        loss, loss_dict = self.criterion(outputs, targets)

        # Backward pass
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        return loss.item(), loss_dict

    def save_model(self, path):
        """Write model, optimizer and scheduler state to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Saved so a resumed run continues the learning-rate schedule
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, path)

    def load_model(self, path):
        """
        Restore the state written by ``save_model``.

        Checkpoints without a scheduler entry (written by earlier versions)
        are accepted; the scheduler is then left as it is.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        # map_location lets a checkpoint written on a GPU be restored on a
        # machine without one.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
|
|
|
# Example usage
|
|
def create_sample_data(batch_size=1, device='cpu'):
    """
    Build one random CSI batch together with matching random targets.

    Returns ``(amplitude, phase, targets)``. Both CSI tensors have shape
    (batch_size, 150, 3, 3); ``targets`` holds DensePose part labels and UV
    maps at 112x112 and keypoint heatmaps at 56x56.
    """
    csi_shape = (batch_size, 150, 3, 3)
    uv_shape = (batch_size, 24, 112, 112)

    # Tensors are sampled first and moved to the requested device afterwards
    amplitude = torch.randn(*csi_shape).to(device)
    phase = torch.randn(*csi_shape).to(device)

    # Labels 0..24 (25 classes)
    part_labels = torch.randint(0, 25, (batch_size, 112, 112)).to(device)
    u_coords = torch.rand(*uv_shape).to(device)
    v_coords = torch.rand(*uv_shape).to(device)
    keypoints = torch.rand(batch_size, 17, 56, 56).to(device)

    densepose_targets = {
        'part_labels': part_labels,
        'u_coords': u_coords,
        'v_coords': v_coords,
    }
    targets = {'densepose': densepose_targets, 'keypoints': keypoints}

    return amplitude, phase, targets
|
|
|
|
if __name__ == "__main__":
    # Initialize model; the trainer moves it to its device (CUDA if available)
    model = WiFiDensePoseRCNN()
    trainer = WiFiDensePoseTrainer(model)

    print("WiFi DensePose model initialized!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create sample data on the same device as the model; CPU tensors would
    # otherwise be fed to a CUDA model on GPU machines
    amplitude, phase, targets = create_sample_data(device=trainer.device)

    # Run inference in eval mode so dropout is disabled and the batch-norm
    # running statistics are not updated by the demo batch
    model.eval()
    with torch.no_grad():
        outputs = model(amplitude, phase)
        print(f"Spatial features shape: {outputs['spatial_features'].shape}")
        print(f"DensePose part logits shape: {outputs['densepose']['part_logits'].shape}")
        print(f"Keypoint heatmaps shape: {outputs['keypoints'].shape}")

    # Training step (train_step switches the model back to train mode)
    loss, loss_dict = trainer.train_step(amplitude, phase, targets)
    print(f"Training loss: {loss:.4f}")
    print(f"Loss breakdown: {loss_dict}")