Add WiFi DensePose implementation and results

- Implemented the WiFi DensePose model in PyTorch, including CSI phase processing, modality translation, and DensePose prediction heads.
- Added a comprehensive training utility for the model, including loss functions and training steps.
- Created a CSV file to document hardware specifications, architecture details, training parameters, performance metrics, and advantages of the model.
Author: rUv
Date: 2025-06-07 05:23:07 +00:00
parent 8227a70c31
commit f3c77b1750
66 changed files with 19849 additions and 2 deletions


@@ -0,0 +1,489 @@
# WiFi DensePose Implementation in PyTorch
# Based on "DensePose From WiFi" by Carnegie Mellon University
# Paper: https://arxiv.org/pdf/2301.00250
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Dict, List, Tuple, Optional
from collections import OrderedDict
class CSIPhaseProcessor:
"""
Processes raw CSI phase data through unwrapping, filtering, and linear fitting
Based on the phase sanitization methodology from the paper
"""
def __init__(self, num_subcarriers: int = 30):
self.num_subcarriers = num_subcarriers
def unwrap_phase(self, phase_data: torch.Tensor) -> torch.Tensor:
"""
Unwrap phase values to handle discontinuities
Args:
phase_data: Raw phase data of shape (batch, freq_samples, tx, rx)
Returns:
Unwrapped phase data
"""
unwrapped = phase_data.clone()
        # Unwrap along the frequency dimension, independently for each group
        # of `num_subcarriers` frequencies (150 samples = 5 groups of 30)
        num_groups = phase_data.shape[1] // self.num_subcarriers
        for sample_group in range(num_groups):
            start_idx = sample_group * self.num_subcarriers
            end_idx = start_idx + self.num_subcarriers
            for i in range(start_idx + 1, end_idx):
                diff = unwrapped[:, i] - unwrapped[:, i-1]
                # Remove single 2*pi discontinuities between adjacent subcarriers
                # (note: unwrapped[:, i-1] + diff == unwrapped[:, i])
                unwrapped[:, i] = torch.where(diff > math.pi,
                                              unwrapped[:, i] - 2*math.pi,
                                              unwrapped[:, i])
                unwrapped[:, i] = torch.where(diff < -math.pi,
                                              unwrapped[:, i] + 2*math.pi,
                                              unwrapped[:, i])
return unwrapped
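    # Worked example: adjacent subcarriers with raw phases 3.0 and -3.0 rad
    # give diff = -6.0 < -pi, so the second value is corrected to
    # -3.0 + 2*pi ~= 3.28 rad, removing the artificial 2*pi jump.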
def apply_filters(self, phase_data: torch.Tensor) -> torch.Tensor:
"""
Apply median and uniform filters to eliminate outliers
"""
# Simple smoothing in frequency dimension
filtered = phase_data.clone()
for i in range(1, phase_data.shape[1]-1):
filtered[:, i] = (phase_data[:, i-1] + phase_data[:, i] + phase_data[:, i+1]) / 3
return filtered
def linear_fitting(self, phase_data: torch.Tensor) -> torch.Tensor:
"""
Apply linear fitting to remove systematic phase drift
"""
fitted_data = phase_data.clone()
        num_sub = self.num_subcarriers  # renamed to avoid shadowing torch.nn.functional (imported as F)
        num_groups = phase_data.shape[1] // num_sub
        # Process each group of subcarriers (5 groups of 30 for 150 frequency samples).
        # The nested Python loops favour clarity over speed; vectorize if this becomes a bottleneck.
        for sample_group in range(num_groups):
            start_idx = sample_group * num_sub
            end_idx = start_idx + num_sub
for batch_idx in range(phase_data.shape[0]):
for tx in range(phase_data.shape[2]):
for rx in range(phase_data.shape[3]):
phase_seq = phase_data[batch_idx, start_idx:end_idx, tx, rx]
if len(phase_seq) > 1:
                            # Linear coefficients: slope from the endpoints, offset from the mean
                            alpha1 = (phase_seq[-1] - phase_seq[0]) / (2 * math.pi * num_sub)
alpha0 = torch.mean(phase_seq)
# Apply linear fitting
frequencies = torch.arange(1, len(phase_seq) + 1, dtype=phase_seq.dtype, device=phase_seq.device)
linear_trend = alpha1 * frequencies + alpha0
fitted_data[batch_idx, start_idx:end_idx, tx, rx] = phase_seq - linear_trend
return fitted_data
def sanitize_phase(self, raw_phase: torch.Tensor) -> torch.Tensor:
"""
Complete phase sanitization pipeline
"""
# Step 1: Unwrap phase
unwrapped = self.unwrap_phase(raw_phase)
# Step 2: Apply filters
filtered = self.apply_filters(unwrapped)
# Step 3: Linear fitting
sanitized = self.linear_fitting(filtered)
return sanitized
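# Usage sketch (illustrative): sanitizing a dummy CSI phase tensor shaped
# (batch, 150, 3, 3), i.e. 5 groups of 30 subcarriers over a 3x3 antenna array:
#
#   processor = CSIPhaseProcessor(num_subcarriers=30)
#   clean_phase = processor.sanitize_phase(torch.randn(2, 150, 3, 3))  # same shape out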
class ModalityTranslationNetwork(nn.Module):
"""
Translates CSI domain features to spatial domain features
Input: 150x3x3 amplitude and phase tensors
Output: 3x720x1280 feature map
"""
def __init__(self, input_dim: int = 1350, hidden_dim: int = 512, output_height: int = 720, output_width: int = 1280):
super(ModalityTranslationNetwork, self).__init__()
self.input_dim = input_dim
self.output_height = output_height
self.output_width = output_width
# Amplitude encoder
self.amplitude_encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim//2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim//2, hidden_dim//4),
nn.ReLU()
)
# Phase encoder
self.phase_encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim, hidden_dim//2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim//2, hidden_dim//4),
nn.ReLU()
)
# Feature fusion
self.fusion_mlp = nn.Sequential(
nn.Linear(hidden_dim//2, hidden_dim//4),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_dim//4, 24*24), # Reshape to 24x24
nn.ReLU()
)
# Spatial processing
self.spatial_conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.AdaptiveAvgPool2d((6, 6)) # Compress to 6x6
)
# Upsampling to target resolution
self.upsample = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 12x12
nn.BatchNorm2d(64),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 24x24
nn.BatchNorm2d(32),
nn.ReLU(),
nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), # 48x48
nn.BatchNorm2d(16),
nn.ReLU(),
nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1), # 96x96
nn.BatchNorm2d(8),
nn.ReLU(),
)
        # Project to 3 channels; the final resize to the target resolution
        # happens via bilinear interpolation in forward()
        self.final_conv = nn.Conv2d(8, 3, kernel_size=1)
def forward(self, amplitude_tensor: torch.Tensor, phase_tensor: torch.Tensor) -> torch.Tensor:
batch_size = amplitude_tensor.shape[0]
# Flatten input tensors
amplitude_flat = amplitude_tensor.view(batch_size, -1) # [B, 1350]
phase_flat = phase_tensor.view(batch_size, -1) # [B, 1350]
# Encode features
amp_features = self.amplitude_encoder(amplitude_flat) # [B, 128]
phase_features = self.phase_encoder(phase_flat) # [B, 128]
# Fuse features
fused_features = torch.cat([amp_features, phase_features], dim=1) # [B, 256]
spatial_features = self.fusion_mlp(fused_features) # [B, 576]
# Reshape to 2D feature map
spatial_map = spatial_features.view(batch_size, 1, 24, 24) # [B, 1, 24, 24]
# Apply spatial convolutions
conv_features = self.spatial_conv(spatial_map) # [B, 128, 6, 6]
# Upsample
upsampled = self.upsample(conv_features) # [B, 8, 96, 96]
# Final convolution
final_features = self.final_conv(upsampled) # [B, 3, 96, 96]
# Interpolate to target resolution
output = F.interpolate(final_features, size=(self.output_height, self.output_width),
mode='bilinear', align_corners=False)
return output
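# Shape walkthrough (sketch): each flattened 1350-dim CSI tensor is encoded to
# 128 features, fused to 256, mapped to a 24x24 grid, convolved down to 6x6,
# transposed-convolved up to 96x96, then bilinearly resized to 720x1280:
#
#   net = ModalityTranslationNetwork()
#   amp, ph = torch.randn(1, 150, 3, 3), torch.randn(1, 150, 3, 3)
#   out = net(amp, ph)  # -> torch.Size([1, 3, 720, 1280])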
class DensePoseHead(nn.Module):
"""
    DensePose prediction head for estimating body-part labels and UV coordinates
"""
def __init__(self, input_channels=256, num_parts=24, output_size=(112, 112)):
super(DensePoseHead, self).__init__()
self.num_parts = num_parts
self.output_size = output_size
# Shared convolutional layers
self.shared_conv = nn.Sequential(
nn.Conv2d(input_channels, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
)
# Part classification branch
self.part_classifier = nn.Conv2d(512, num_parts + 1, kernel_size=1) # +1 for background
# UV coordinate regression branches
self.u_regressor = nn.Conv2d(512, num_parts, kernel_size=1)
self.v_regressor = nn.Conv2d(512, num_parts, kernel_size=1)
def forward(self, x):
# Shared feature extraction
features = self.shared_conv(x)
# Upsample features to target size
features = F.interpolate(features, size=self.output_size, mode='bilinear', align_corners=False)
# Predict part labels
part_logits = self.part_classifier(features)
# Predict UV coordinates
u_coords = torch.sigmoid(self.u_regressor(features)) # Sigmoid to ensure [0,1] range
v_coords = torch.sigmoid(self.v_regressor(features))
return {
'part_logits': part_logits,
'u_coords': u_coords,
'v_coords': v_coords
}
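# Usage sketch (illustrative): with the defaults (24 parts, 7x7 input features)
# the head yields per-pixel part logits over 25 classes (24 parts + background)
# and sigmoid-bounded UV coordinates at 112x112:
#
#   head = DensePoseHead(input_channels=256)
#   out = head(torch.randn(1, 256, 7, 7))
#   out['part_logits'].shape  # torch.Size([1, 25, 112, 112])
#   out['u_coords'].shape     # torch.Size([1, 24, 112, 112])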
class KeypointHead(nn.Module):
"""
Keypoint prediction head for estimating body keypoints
"""
def __init__(self, input_channels=256, num_keypoints=17, output_size=(56, 56)):
super(KeypointHead, self).__init__()
self.num_keypoints = num_keypoints
self.output_size = output_size
# Convolutional layers for keypoint detection
self.conv_layers = nn.Sequential(
nn.Conv2d(input_channels, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(512, num_keypoints, kernel_size=1)
)
def forward(self, x):
# Extract keypoint heatmaps
heatmaps = self.conv_layers(x)
# Upsample to target size
heatmaps = F.interpolate(heatmaps, size=self.output_size, mode='bilinear', align_corners=False)
return heatmaps
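# Usage sketch (illustrative): one heatmap per COCO-style keypoint (17 by
# default), upsampled to 56x56; the peak of each map locates that keypoint:
#
#   kp_head = KeypointHead(input_channels=256)
#   heatmaps = kp_head(torch.randn(1, 256, 7, 7))  # -> torch.Size([1, 17, 56, 56])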
class WiFiDensePoseRCNN(nn.Module):
"""
Complete WiFi-DensePose RCNN architecture
"""
def __init__(self):
super(WiFiDensePoseRCNN, self).__init__()
# CSI processing
self.phase_processor = CSIPhaseProcessor()
# Modality translation
self.modality_translation = ModalityTranslationNetwork()
# Simplified backbone (in practice, use ResNet-FPN)
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
# Simplified ResNet blocks
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
)
# Prediction heads
self.densepose_head = DensePoseHead(input_channels=256)
self.keypoint_head = KeypointHead(input_channels=256)
# Global average pooling for simplified processing
self.global_pool = nn.AdaptiveAvgPool2d((7, 7))
def forward(self, amplitude_data, phase_data):
batch_size = amplitude_data.shape[0]
# Process CSI phase data
sanitized_phase = self.phase_processor.sanitize_phase(phase_data)
# Translate to spatial domain
spatial_features = self.modality_translation(amplitude_data, sanitized_phase)
# Extract backbone features
backbone_features = self.backbone(spatial_features)
# Global pooling to get fixed-size features
pooled_features = self.global_pool(backbone_features)
# Predict DensePose
densepose_output = self.densepose_head(pooled_features)
# Predict keypoints
keypoint_heatmaps = self.keypoint_head(pooled_features)
        return {
            'spatial_features': spatial_features,
            'backbone_features': pooled_features,  # exposed so the transfer-learning loss can use them
            'densepose': densepose_output,
            'keypoints': keypoint_heatmaps
        }
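# Note: phase sanitization runs inside forward(), so the model consumes raw CSI
# directly and the whole pipeline (sanitization included) stays on one device;
# no separate offline preprocessing pass is required.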
class WiFiDensePoseLoss(nn.Module):
"""
Combined loss function for WiFi DensePose training
"""
def __init__(self, lambda_dp=0.6, lambda_kp=0.3, lambda_tr=0.1):
super(WiFiDensePoseLoss, self).__init__()
self.lambda_dp = lambda_dp
self.lambda_kp = lambda_kp
self.lambda_tr = lambda_tr
# Loss functions
self.cross_entropy = nn.CrossEntropyLoss()
self.mse_loss = nn.MSELoss()
self.smooth_l1 = nn.SmoothL1Loss()
def forward(self, predictions, targets, teacher_features=None):
        total_loss = 0.0  # promoted to a tensor by the first weighted loss term below
loss_dict = {}
# DensePose losses
if 'densepose' in predictions and 'densepose' in targets:
# Part classification loss
part_loss = self.cross_entropy(
predictions['densepose']['part_logits'],
targets['densepose']['part_labels']
)
# UV coordinate regression loss
uv_loss = (self.smooth_l1(predictions['densepose']['u_coords'], targets['densepose']['u_coords']) +
self.smooth_l1(predictions['densepose']['v_coords'], targets['densepose']['v_coords'])) / 2
dp_loss = part_loss + uv_loss
total_loss += self.lambda_dp * dp_loss
loss_dict['densepose'] = dp_loss
# Keypoint loss
if 'keypoints' in predictions and 'keypoints' in targets:
kp_loss = self.mse_loss(predictions['keypoints'], targets['keypoints'])
total_loss += self.lambda_kp * kp_loss
loss_dict['keypoint'] = kp_loss
# Transfer learning loss
if teacher_features is not None and 'backbone_features' in predictions:
tr_loss = self.mse_loss(predictions['backbone_features'], teacher_features)
total_loss += self.lambda_tr * tr_loss
loss_dict['transfer'] = tr_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
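# Weighted objective (sketch): with the defaults above,
#   L_total = 0.6 * (L_part + L_uv) + 0.3 * L_keypoint + 0.1 * L_transfer
# where the transfer term is only active when features from an image-based
# teacher network are supplied.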
# Training utilities
class WiFiDensePoseTrainer:
"""
Training utilities for WiFi DensePose
"""
def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
self.model = model.to(device)
self.device = device
self.criterion = WiFiDensePoseLoss()
self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # Iteration-based schedule: train_step() calls scheduler.step() once per
        # batch, so these milestones count training iterations, not epochs
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[48000, 96000], gamma=0.1
        )
def train_step(self, amplitude_data, phase_data, targets):
self.model.train()
self.optimizer.zero_grad()
# Forward pass
outputs = self.model(amplitude_data, phase_data)
# Compute loss
loss, loss_dict = self.criterion(outputs, targets)
# Backward pass
loss.backward()
self.optimizer.step()
self.scheduler.step()
return loss.item(), loss_dict
def save_model(self, path):
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
}, path)
    def load_model(self, path):
        # map_location keeps CPU-only machines able to load GPU-trained checkpoints
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
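# Training-loop sketch (illustrative; `loader` is a hypothetical DataLoader
# yielding (amplitude, phase, targets) batches already on the trainer's device):
#
#   trainer = WiFiDensePoseTrainer(WiFiDensePoseRCNN())
#   for amplitude, phase, targets in loader:
#       loss, parts = trainer.train_step(amplitude, phase, targets)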
# Example usage
def create_sample_data(batch_size=1, device='cpu'):
"""
Create sample CSI data for testing
"""
amplitude = torch.randn(batch_size, 150, 3, 3).to(device)
phase = torch.randn(batch_size, 150, 3, 3).to(device)
# Sample targets
targets = {
'densepose': {
            'part_labels': torch.randint(0, 25, (batch_size, 112, 112)).to(device),  # labels in [0, 24]: 24 parts + background
'u_coords': torch.rand(batch_size, 24, 112, 112).to(device),
'v_coords': torch.rand(batch_size, 24, 112, 112).to(device)
},
'keypoints': torch.rand(batch_size, 17, 56, 56).to(device)
}
return amplitude, phase, targets
if __name__ == "__main__":
# Initialize model
model = WiFiDensePoseRCNN()
trainer = WiFiDensePoseTrainer(model)
print("WiFi DensePose model initialized!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
# Create sample data
amplitude, phase, targets = create_sample_data()
# Run inference
with torch.no_grad():
outputs = model(amplitude, phase)
print(f"Spatial features shape: {outputs['spatial_features'].shape}")
print(f"DensePose part logits shape: {outputs['densepose']['part_logits'].shape}")
print(f"Keypoint heatmaps shape: {outputs['keypoints'].shape}")
# Training step
loss, loss_dict = trainer.train_step(amplitude, phase, targets)
print(f"Training loss: {loss:.4f}")
print(f"Loss breakdown: {loss_dict}")