Files
wifi-densepose/src/models/densepose_head.py
rUv cbebdd648f Implement WiFi-DensePose system with CSI data extraction and router interface
- Added CSIExtractor class for extracting CSI data from WiFi routers.
- Implemented RouterInterface class for SSH communication with routers.
- Developed DensePoseHead class for body part segmentation and UV coordinate regression.
- Created unit tests for CSIExtractor and RouterInterface to ensure functionality and error handling.
- Integrated paramiko for SSH connections and command execution.
- Established configuration validation for both extractor and router interface.
- Added context manager support for resource management in both classes.
2025-06-07 05:55:27 +00:00

279 lines
10 KiB
Python

"""DensePose head for WiFi-DensePose system."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, Tuple, List
class DensePoseError(Exception):
"""Exception raised for DensePose head errors."""
pass
class DensePoseHead(nn.Module):
"""DensePose head for body part segmentation and UV coordinate regression."""
def __init__(self, config: Dict[str, Any]):
"""Initialize DensePose head.
Args:
config: Configuration dictionary with head parameters
"""
super().__init__()
self._validate_config(config)
self.config = config
self.input_channels = config['input_channels']
self.num_body_parts = config['num_body_parts']
self.num_uv_coordinates = config['num_uv_coordinates']
self.hidden_channels = config.get('hidden_channels', [128, 64])
self.kernel_size = config.get('kernel_size', 3)
self.padding = config.get('padding', 1)
self.dropout_rate = config.get('dropout_rate', 0.1)
self.use_deformable_conv = config.get('use_deformable_conv', False)
self.use_fpn = config.get('use_fpn', False)
self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5])
self.output_stride = config.get('output_stride', 4)
# Feature Pyramid Network (optional)
if self.use_fpn:
self.fpn = self._build_fpn()
# Shared feature processing
self.shared_conv = self._build_shared_layers()
# Segmentation head for body part classification
self.segmentation_head = self._build_segmentation_head()
# UV regression head for coordinate prediction
self.uv_regression_head = self._build_uv_regression_head()
# Initialize weights
self._initialize_weights()
def _validate_config(self, config: Dict[str, Any]):
"""Validate configuration parameters."""
required_fields = ['input_channels', 'num_body_parts', 'num_uv_coordinates']
for field in required_fields:
if field not in config:
raise ValueError(f"Missing required field: {field}")
if config['input_channels'] <= 0:
raise ValueError("input_channels must be positive")
if config['num_body_parts'] <= 0:
raise ValueError("num_body_parts must be positive")
if config['num_uv_coordinates'] <= 0:
raise ValueError("num_uv_coordinates must be positive")
def _build_fpn(self) -> nn.Module:
"""Build Feature Pyramid Network."""
return nn.ModuleDict({
f'level_{level}': nn.Conv2d(self.input_channels, self.input_channels, 1)
for level in self.fpn_levels
})
def _build_shared_layers(self) -> nn.Module:
"""Build shared feature processing layers."""
layers = []
in_channels = self.input_channels
for hidden_dim in self.hidden_channels:
layers.extend([
nn.Conv2d(in_channels, hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate)
])
in_channels = hidden_dim
return nn.Sequential(*layers)
def _build_segmentation_head(self) -> nn.Module:
"""Build segmentation head for body part classification."""
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
return nn.Sequential(
nn.Conv2d(final_hidden, final_hidden // 2,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(final_hidden // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
# Upsampling to increase resolution
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(final_hidden // 4),
nn.ReLU(inplace=True),
nn.Conv2d(final_hidden // 4, self.num_body_parts + 1, kernel_size=1),
# +1 for background class
)
def _build_uv_regression_head(self) -> nn.Module:
"""Build UV regression head for coordinate prediction."""
final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels
return nn.Sequential(
nn.Conv2d(final_hidden, final_hidden // 2,
kernel_size=self.kernel_size,
padding=self.padding),
nn.BatchNorm2d(final_hidden // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(self.dropout_rate),
# Upsampling to increase resolution
nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(final_hidden // 4),
nn.ReLU(inplace=True),
nn.Conv2d(final_hidden // 4, self.num_uv_coordinates, kernel_size=1),
)
def _initialize_weights(self):
"""Initialize network weights."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Forward pass through the DensePose head.
Args:
x: Input feature tensor of shape (batch_size, channels, height, width)
Returns:
Dictionary containing:
- segmentation: Body part logits (batch_size, num_parts+1, height, width)
- uv_coordinates: UV coordinates (batch_size, 2, height, width)
"""
# Validate input shape
if x.shape[1] != self.input_channels:
raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")
# Apply FPN if enabled
if self.use_fpn:
# Simple FPN processing - in practice this would be more sophisticated
x = self.fpn['level_2'](x)
# Shared feature processing
shared_features = self.shared_conv(x)
# Segmentation branch
segmentation_logits = self.segmentation_head(shared_features)
# UV regression branch
uv_coordinates = self.uv_regression_head(shared_features)
uv_coordinates = torch.sigmoid(uv_coordinates) # Normalize to [0, 1]
return {
'segmentation': segmentation_logits,
'uv_coordinates': uv_coordinates
}
def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute segmentation loss.
Args:
pred_logits: Predicted segmentation logits
target: Target segmentation masks
Returns:
Computed cross-entropy loss
"""
return F.cross_entropy(pred_logits, target, ignore_index=-1)
def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor:
"""Compute UV coordinate regression loss.
Args:
pred_uv: Predicted UV coordinates
target_uv: Target UV coordinates
Returns:
Computed L1 loss
"""
return F.l1_loss(pred_uv, target_uv)
def compute_total_loss(self, predictions: Dict[str, torch.Tensor],
seg_target: torch.Tensor,
uv_target: torch.Tensor,
seg_weight: float = 1.0,
uv_weight: float = 1.0) -> torch.Tensor:
"""Compute total loss combining segmentation and UV losses.
Args:
predictions: Dictionary of predictions
seg_target: Target segmentation masks
uv_target: Target UV coordinates
seg_weight: Weight for segmentation loss
uv_weight: Weight for UV loss
Returns:
Combined loss
"""
seg_loss = self.compute_segmentation_loss(predictions['segmentation'], seg_target)
uv_loss = self.compute_uv_loss(predictions['uv_coordinates'], uv_target)
return seg_weight * seg_loss + uv_weight * uv_loss
def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Get prediction confidence scores.
Args:
predictions: Dictionary of predictions
Returns:
Dictionary of confidence scores
"""
seg_logits = predictions['segmentation']
uv_coords = predictions['uv_coordinates']
# Segmentation confidence: max probability
seg_probs = F.softmax(seg_logits, dim=1)
seg_confidence = torch.max(seg_probs, dim=1)[0]
# UV confidence: inverse of prediction variance
uv_variance = torch.var(uv_coords, dim=1, keepdim=True)
uv_confidence = 1.0 / (1.0 + uv_variance)
return {
'segmentation_confidence': seg_confidence,
'uv_confidence': uv_confidence.squeeze(1)
}
def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Post-process predictions for final output.
Args:
predictions: Raw predictions from forward pass
Returns:
Post-processed predictions
"""
seg_logits = predictions['segmentation']
uv_coords = predictions['uv_coordinates']
# Convert logits to class predictions
body_parts = torch.argmax(seg_logits, dim=1)
# Get confidence scores
confidence = self.get_prediction_confidence(predictions)
return {
'body_parts': body_parts,
'uv_coordinates': uv_coords,
'confidence_scores': confidence
}