feat: Complete Rust port of WiFi-DensePose with modular crates
Major changes: - Organized Python v1 implementation into v1/ subdirectory - Created Rust workspace with 9 modular crates: - wifi-densepose-core: Core types, traits, errors - wifi-densepose-signal: CSI processing, phase sanitization, FFT - wifi-densepose-nn: Neural network inference (ONNX/Candle/tch) - wifi-densepose-api: Axum-based REST/WebSocket API - wifi-densepose-db: SQLx database layer - wifi-densepose-config: Configuration management - wifi-densepose-hardware: Hardware abstraction - wifi-densepose-wasm: WebAssembly bindings - wifi-densepose-cli: Command-line interface Documentation: - ADR-001: Workspace structure - ADR-002: Signal processing library selection - ADR-003: Neural network inference strategy - DDD domain model with bounded contexts Testing: - 69 tests passing across all crates - Signal processing: 45 tests - Neural networks: 21 tests - Core: 3 doc tests Performance targets: - 10x faster CSI processing (~0.5ms vs ~5ms) - 5x lower memory usage (~100MB vs ~500MB) - WASM support for browser deployment
This commit is contained in:
0
v1/src/models/__init__.py
Normal file
0
v1/src/models/__init__.py
Normal file
279
v1/src/models/densepose_head.py
Normal file
279
v1/src/models/densepose_head.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""DensePose head for WiFi-DensePose system."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any, Tuple, List
|
||||
|
||||
|
||||
class DensePoseError(Exception):
    """Raised when the DensePose head is given invalid input (e.g. a
    tensor whose channel count does not match the configured
    ``input_channels``)."""
|
||||
|
||||
|
||||
class DensePoseHead(nn.Module):
    """DensePose head for body part segmentation and UV coordinate regression.

    Consumes a feature map of shape ``(batch, input_channels, H, W)`` and
    produces, from a shared convolutional trunk, two branches:

    - body-part logits ``(batch, num_body_parts + 1, 2H, 2W)`` (the extra
      channel is the background class), and
    - UV coordinates ``(batch, num_uv_coordinates, 2H, 2W)`` squashed to
      ``[0, 1]`` by a sigmoid.

    The 2x spatial upsampling comes from the single stride-2
    ``ConvTranspose2d`` inside each prediction branch.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize DensePose head.

        Args:
            config: Configuration dictionary. Required keys:
                ``input_channels``, ``num_body_parts``,
                ``num_uv_coordinates``. Optional keys (defaults):
                ``hidden_channels`` ([128, 64]), ``kernel_size`` (3),
                ``padding`` (1), ``dropout_rate`` (0.1), ``use_fpn``
                (False), ``fpn_levels`` ([2, 3, 4, 5]),
                ``use_deformable_conv`` (False) and ``output_stride`` (4)
                — the last two are stored but not acted on anywhere in
                this class yet.

        Raises:
            ValueError: If a required key is missing or non-positive.
        """
        super().__init__()

        self._validate_config(config)
        self.config = config

        self.input_channels = config['input_channels']
        self.num_body_parts = config['num_body_parts']
        self.num_uv_coordinates = config['num_uv_coordinates']
        self.hidden_channels = config.get('hidden_channels', [128, 64])
        self.kernel_size = config.get('kernel_size', 3)
        self.padding = config.get('padding', 1)
        self.dropout_rate = config.get('dropout_rate', 0.1)
        # NOTE(review): read but currently unused by any layer below.
        self.use_deformable_conv = config.get('use_deformable_conv', False)
        self.use_fpn = config.get('use_fpn', False)
        self.fpn_levels = config.get('fpn_levels', [2, 3, 4, 5])
        # NOTE(review): read but currently unused by any layer below.
        self.output_stride = config.get('output_stride', 4)

        # Optional Feature Pyramid Network applied before the shared trunk.
        if self.use_fpn:
            self.fpn = self._build_fpn()

        # Shared feature trunk feeding both prediction branches.
        self.shared_conv = self._build_shared_layers()

        # Branch 1: body-part classification (num_body_parts + background).
        self.segmentation_head = self._build_segmentation_head()

        # Branch 2: per-pixel UV coordinate regression.
        self.uv_regression_head = self._build_uv_regression_head()

        self._initialize_weights()

    def _validate_config(self, config: Dict[str, Any]):
        """Validate configuration parameters.

        Raises:
            ValueError: On a missing required key or a non-positive value.
        """
        required_fields = ['input_channels', 'num_body_parts', 'num_uv_coordinates']
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required field: {field}")

        if config['input_channels'] <= 0:
            raise ValueError("input_channels must be positive")

        if config['num_body_parts'] <= 0:
            raise ValueError("num_body_parts must be positive")

        if config['num_uv_coordinates'] <= 0:
            raise ValueError("num_uv_coordinates must be positive")

    def _build_fpn(self) -> nn.Module:
        """Build Feature Pyramid Network: one 1x1 lateral conv per level."""
        return nn.ModuleDict({
            f'level_{level}': nn.Conv2d(self.input_channels, self.input_channels, 1)
            for level in self.fpn_levels
        })

    def _build_shared_layers(self) -> nn.Module:
        """Build shared trunk: [Conv -> BN -> ReLU -> Dropout2d] per hidden dim."""
        layers = []
        in_channels = self.input_channels

        for hidden_dim in self.hidden_channels:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim,
                          kernel_size=self.kernel_size,
                          padding=self.padding),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout2d(self.dropout_rate)
            ])
            in_channels = hidden_dim

        return nn.Sequential(*layers)

    def _build_prediction_head(self, out_channels: int) -> nn.Module:
        """Build one prediction branch: conv -> 2x upsample -> 1x1 projection.

        Shared by the segmentation and UV branches, which differ only in
        their final number of output channels.

        Args:
            out_channels: Channel count of the final 1x1 projection.
        """
        final_hidden = self.hidden_channels[-1] if self.hidden_channels else self.input_channels

        return nn.Sequential(
            nn.Conv2d(final_hidden, final_hidden // 2,
                      kernel_size=self.kernel_size,
                      padding=self.padding),
            nn.BatchNorm2d(final_hidden // 2),
            nn.ReLU(inplace=True),
            nn.Dropout2d(self.dropout_rate),

            # Upsampling to double the spatial resolution.
            nn.ConvTranspose2d(final_hidden // 2, final_hidden // 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(final_hidden // 4),
            nn.ReLU(inplace=True),

            nn.Conv2d(final_hidden // 4, out_channels, kernel_size=1),
        )

    def _build_segmentation_head(self) -> nn.Module:
        """Build segmentation head; +1 output channel for the background class."""
        return self._build_prediction_head(self.num_body_parts + 1)

    def _build_uv_regression_head(self) -> nn.Module:
        """Build UV regression head for coordinate prediction."""
        return self._build_prediction_head(self.num_uv_coordinates)

    def _initialize_weights(self):
        """Initialize weights: Kaiming for convs, (1, 0) for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass through the DensePose head.

        Args:
            x: Input feature tensor of shape (batch_size, channels, height, width)

        Returns:
            Dictionary containing:
            - segmentation: Body part logits
              (batch_size, num_body_parts + 1, 2 * height, 2 * width)
            - uv_coordinates: UV coordinates in [0, 1]
              (batch_size, num_uv_coordinates, 2 * height, 2 * width)

        Raises:
            DensePoseError: If x's channel dimension != input_channels.
        """
        # Validate input shape
        if x.shape[1] != self.input_channels:
            raise DensePoseError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")

        # Apply FPN if enabled
        if self.use_fpn:
            # Simple FPN processing - a single lateral conv. Use the first
            # configured level; the previous hard-coded 'level_2' raised
            # KeyError whenever 2 was not in fpn_levels.
            x = self.fpn[f'level_{self.fpn_levels[0]}'](x)

        # Shared feature processing
        shared_features = self.shared_conv(x)

        # Segmentation branch
        segmentation_logits = self.segmentation_head(shared_features)

        # UV regression branch
        uv_coordinates = self.uv_regression_head(shared_features)
        uv_coordinates = torch.sigmoid(uv_coordinates)  # Normalize to [0, 1]

        return {
            'segmentation': segmentation_logits,
            'uv_coordinates': uv_coordinates
        }

    def compute_segmentation_loss(self, pred_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute segmentation loss.

        Args:
            pred_logits: Predicted segmentation logits
            target: Target segmentation masks (class indices; -1 is ignored)

        Returns:
            Computed cross-entropy loss
        """
        return F.cross_entropy(pred_logits, target, ignore_index=-1)

    def compute_uv_loss(self, pred_uv: torch.Tensor, target_uv: torch.Tensor) -> torch.Tensor:
        """Compute UV coordinate regression loss.

        Args:
            pred_uv: Predicted UV coordinates
            target_uv: Target UV coordinates

        Returns:
            Computed L1 loss
        """
        return F.l1_loss(pred_uv, target_uv)

    def compute_total_loss(self, predictions: Dict[str, torch.Tensor],
                           seg_target: torch.Tensor,
                           uv_target: torch.Tensor,
                           seg_weight: float = 1.0,
                           uv_weight: float = 1.0) -> torch.Tensor:
        """Compute total loss combining segmentation and UV losses.

        Args:
            predictions: Dictionary of predictions (from forward())
            seg_target: Target segmentation masks
            uv_target: Target UV coordinates
            seg_weight: Weight for segmentation loss
            uv_weight: Weight for UV loss

        Returns:
            Combined loss (seg_weight * seg + uv_weight * uv)
        """
        seg_loss = self.compute_segmentation_loss(predictions['segmentation'], seg_target)
        uv_loss = self.compute_uv_loss(predictions['uv_coordinates'], uv_target)

        return seg_weight * seg_loss + uv_weight * uv_loss

    def get_prediction_confidence(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Get prediction confidence scores.

        Args:
            predictions: Dictionary of predictions (from forward())

        Returns:
            Dictionary with:
            - segmentation_confidence: per-pixel max softmax probability
            - uv_confidence: 1 / (1 + variance across UV channels)
        """
        seg_logits = predictions['segmentation']
        uv_coords = predictions['uv_coordinates']

        # Segmentation confidence: max probability over classes.
        seg_probs = F.softmax(seg_logits, dim=1)
        seg_confidence = torch.max(seg_probs, dim=1)[0]

        # UV confidence: inverse of the variance across UV channels.
        uv_variance = torch.var(uv_coords, dim=1, keepdim=True)
        uv_confidence = 1.0 / (1.0 + uv_variance)

        return {
            'segmentation_confidence': seg_confidence,
            'uv_confidence': uv_confidence.squeeze(1)
        }

    def post_process_predictions(self, predictions: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Post-process predictions for final output.

        Args:
            predictions: Raw predictions from forward pass

        Returns:
            Dictionary with hard body-part labels ('body_parts', argmax over
            logits), raw 'uv_coordinates', and 'confidence_scores'.
        """
        seg_logits = predictions['segmentation']
        uv_coords = predictions['uv_coordinates']

        # Convert logits to hard class predictions.
        body_parts = torch.argmax(seg_logits, dim=1)

        # Attach confidence scores.
        confidence = self.get_prediction_confidence(predictions)

        return {
            'body_parts': body_parts,
            'uv_coordinates': uv_coords,
            'confidence_scores': confidence
        }
|
||||
301
v1/src/models/modality_translation.py
Normal file
301
v1/src/models/modality_translation.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Modality translation network for WiFi-DensePose system."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class ModalityTranslationError(Exception):
    """Raised when the modality translation network is given invalid
    input (e.g. a tensor whose channel count does not match the
    configured ``input_channels``)."""
|
||||
|
||||
|
||||
class ModalityTranslationNetwork(nn.Module):
    """Neural network for translating CSI data to visual feature space.

    An encoder downsamples the CSI tensor through ``hidden_channels``
    (stride 2 after the first stage), an optional multi-head
    self-attention runs over the bottleneck's spatial positions, and a
    transposed-conv decoder upsamples back, ending in a Tanh so the
    output lies in ``[-1, 1]``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize modality translation network.

        Args:
            config: Configuration dictionary. Required keys:
                ``input_channels``, ``hidden_channels`` (non-empty list),
                ``output_channels``. Optional keys (defaults):
                ``kernel_size`` (3), ``stride`` (1), ``padding`` (1),
                ``dropout_rate`` (0.1), ``activation`` ('relu'),
                ``normalization`` ('batch'), ``use_attention`` (False),
                ``attention_heads`` (8).

        Raises:
            ValueError: If a required key is missing or invalid.
        """
        super().__init__()

        self._validate_config(config)
        self.config = config

        self.input_channels = config['input_channels']
        self.hidden_channels = config['hidden_channels']
        self.output_channels = config['output_channels']
        self.kernel_size = config.get('kernel_size', 3)
        self.stride = config.get('stride', 1)
        self.padding = config.get('padding', 1)
        self.dropout_rate = config.get('dropout_rate', 0.1)
        self.activation = config.get('activation', 'relu')
        self.normalization = config.get('normalization', 'batch')
        self.use_attention = config.get('use_attention', False)
        self.attention_heads = config.get('attention_heads', 8)

        # Encoder: CSI -> feature space.
        self.encoder = self._build_encoder()

        # Decoder: feature space -> visual-like features.
        self.decoder = self._build_decoder()

        # Optional self-attention over the bottleneck.
        if self.use_attention:
            self.attention = self._build_attention()

        self._initialize_weights()

    def _validate_config(self, config: Dict[str, Any]):
        """Validate configuration parameters.

        Raises:
            ValueError: On missing keys, non-positive channel counts, or
                an empty ``hidden_channels`` list.
        """
        required_fields = ['input_channels', 'hidden_channels', 'output_channels']
        for field in required_fields:
            if field not in config:
                raise ValueError(f"Missing required field: {field}")

        if config['input_channels'] <= 0:
            raise ValueError("input_channels must be positive")

        if not config['hidden_channels'] or len(config['hidden_channels']) == 0:
            raise ValueError("hidden_channels must be a non-empty list")

        if config['output_channels'] <= 0:
            raise ValueError("output_channels must be positive")

    def _build_encoder(self) -> nn.ModuleList:
        """Build encoder: one [Conv -> Norm -> Act -> Dropout] block per
        hidden dim; the first block uses the configured stride, later
        blocks downsample with stride 2."""
        layers = nn.ModuleList()

        in_channels = self.input_channels

        for i, out_channels in enumerate(self.hidden_channels):
            layer_block = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=self.kernel_size,
                          stride=self.stride if i == 0 else 2,
                          padding=self.padding),
                self._get_normalization(out_channels),
                self._get_activation(),
                nn.Dropout2d(self.dropout_rate)
            )
            layers.append(layer_block)
            in_channels = out_channels

        return layers

    def _build_decoder(self) -> nn.ModuleList:
        """Build decoder: stride-2 transposed convs mirroring the encoder,
        then a final conv + Tanh projection to ``output_channels``."""
        layers = nn.ModuleList()

        # Start with the bottleneck channel size.
        in_channels = self.hidden_channels[-1]

        # Progressive upsampling (reverse of encoder).
        for i, out_channels in enumerate(reversed(self.hidden_channels[:-1])):
            layer_block = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels,
                                   kernel_size=self.kernel_size,
                                   stride=2,
                                   padding=self.padding,
                                   output_padding=1),
                self._get_normalization(out_channels),
                self._get_activation(),
                nn.Dropout2d(self.dropout_rate)
            )
            layers.append(layer_block)
            in_channels = out_channels

        # Final output layer.
        final_layer = nn.Sequential(
            nn.Conv2d(in_channels, self.output_channels,
                      kernel_size=self.kernel_size,
                      padding=self.padding),
            nn.Tanh()  # Normalize output to [-1, 1]
        )
        layers.append(final_layer)

        return layers

    def _get_normalization(self, channels: int) -> nn.Module:
        """Get normalization layer ('batch', 'instance', 'layer', else none)."""
        if self.normalization == 'batch':
            return nn.BatchNorm2d(channels)
        elif self.normalization == 'instance':
            return nn.InstanceNorm2d(channels)
        elif self.normalization == 'layer':
            # GroupNorm with one group == LayerNorm over channels.
            return nn.GroupNorm(1, channels)
        else:
            return nn.Identity()

    def _get_activation(self) -> nn.Module:
        """Get activation function ('relu', 'leaky_relu', 'gelu'; ReLU fallback)."""
        if self.activation == 'relu':
            return nn.ReLU(inplace=True)
        elif self.activation == 'leaky_relu':
            return nn.LeakyReLU(0.2, inplace=True)
        elif self.activation == 'gelu':
            return nn.GELU()
        else:
            return nn.ReLU(inplace=True)

    def _build_attention(self) -> nn.Module:
        """Build multi-head self-attention over the bottleneck features."""
        return nn.MultiheadAttention(
            embed_dim=self.hidden_channels[-1],
            num_heads=self.attention_heads,
            dropout=self.dropout_rate,
            batch_first=True
        )

    def _initialize_weights(self):
        """Initialize weights: Kaiming for (transposed) convs, (1, 0) for BN."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _apply_attention(self, feature_map: torch.Tensor):
        """Run self-attention over the spatial positions of a feature map.

        Shared by decode() and get_intermediate_features() so the
        reshape/attend/restore logic lives in one place.

        Args:
            feature_map: Tensor of shape (batch, channels, height, width);
                channels must equal the attention embed_dim
                (hidden_channels[-1]).

        Returns:
            Tuple of (attended feature map with the original shape,
            attention weights as returned by nn.MultiheadAttention).
        """
        batch_size, channels, height, width = feature_map.shape
        # (B, C, H, W) -> (B, H*W, C): sequence of spatial positions.
        flat = feature_map.view(batch_size, channels, -1).transpose(1, 2)
        attended, attention_weights = self.attention(flat, flat, flat)
        restored = attended.transpose(1, 2).view(batch_size, channels, height, width)
        return restored, attention_weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input CSI tensor of shape (batch_size, channels, height, width)

        Returns:
            Translated features tensor in [-1, 1] (final Tanh)

        Raises:
            ModalityTranslationError: If x's channel dimension does not
                match input_channels.
        """
        # Validate input shape
        if x.shape[1] != self.input_channels:
            raise ModalityTranslationError(f"Expected {self.input_channels} input channels, got {x.shape[1]}")

        # Encode CSI data
        encoded_features = self.encode(x)

        # Decode to visual-like features
        decoded = self.decode(encoded_features)

        return decoded

    def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Encode input through encoder layers.

        Args:
            x: Input tensor

        Returns:
            List of feature maps, one per encoder layer (last is the
            bottleneck fed to the decoder)
        """
        features = []
        current = x

        for layer in self.encoder:
            current = layer(current)
            features.append(current)

        return features

    def decode(self, encoded_features: List[torch.Tensor]) -> torch.Tensor:
        """Decode features through decoder layers.

        Args:
            encoded_features: List of encoded feature maps (as produced
                by encode(); only the last element is consumed)

        Returns:
            Decoded output tensor
        """
        # Start from the bottleneck feature map.
        current = encoded_features[-1]

        # Apply attention if enabled.
        if self.use_attention:
            current, _ = self._apply_attention(current)

        # Apply decoder layers.
        for layer in self.decoder:
            current = layer(current)

        return current

    def compute_translation_loss(self, predicted: torch.Tensor, target: torch.Tensor, loss_type: str = 'mse') -> torch.Tensor:
        """Compute translation loss between predicted and target features.

        Args:
            predicted: Predicted feature tensor
            target: Target feature tensor
            loss_type: Type of loss ('mse', 'l1', 'smooth_l1'; anything
                else falls back to MSE)

        Returns:
            Computed loss tensor
        """
        if loss_type == 'mse':
            return F.mse_loss(predicted, target)
        elif loss_type == 'l1':
            return F.l1_loss(predicted, target)
        elif loss_type == 'smooth_l1':
            return F.smooth_l1_loss(predicted, target)
        else:
            return F.mse_loss(predicted, target)

    def get_feature_statistics(self, features: torch.Tensor) -> Dict[str, float]:
        """Get statistics of feature tensor.

        Args:
            features: Feature tensor to analyze

        Returns:
            Dictionary with 'mean', 'std', 'min', 'max' and 'sparsity'
            (fraction of exact zeros)
        """
        with torch.no_grad():
            return {
                'mean': features.mean().item(),
                'std': features.std().item(),
                'min': features.min().item(),
                'max': features.max().item(),
                'sparsity': (features == 0).float().mean().item()
            }

    def get_intermediate_features(self, x: torch.Tensor) -> Dict[str, Any]:
        """Get intermediate features for visualization.

        Args:
            x: Input tensor

        Returns:
            Dictionary with 'encoder_features', 'decoder_features' and —
            when attention is enabled — 'attention_weights'
        """
        result = {}

        # Encoder features.
        encoder_features = self.encode(x)
        result['encoder_features'] = encoder_features

        # Decoder features, mirroring decode() but recording each stage.
        decoder_features = []
        current = encoder_features[-1]

        if self.use_attention:
            current, attention_weights = self._apply_attention(current)
            result['attention_weights'] = attention_weights

        for layer in self.decoder:
            current = layer(current)
            decoder_features.append(current)

        result['decoder_features'] = decoder_features

        return result
|
||||
Reference in New Issue
Block a user