23 lines
816 B
Python
23 lines
816 B
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class BCELoss(nn.Module):
    """Binary cross-entropy loss computed directly on raw logits.

    ``forward`` delegates to ``F.binary_cross_entropy_with_logits`` with its
    default arguments (so the result is reduced with PyTorch's default
    ``reduction='mean'``). It returns a ``(loss, logs)`` pair; the logs
    dict is always empty.
    """

    def forward(self, prediction, target):
        """Return ``(bce_loss, {})`` for logits ``prediction`` vs ``target``.

        Args:
            prediction: unnormalized logits (no sigmoid applied beforehand).
            target: target probabilities; PyTorch requires it to have the
                same shape as ``prediction``.
        """
        # The empty dict presumably keeps the return shape parallel to other
        # loss modules that return (loss, logs) — confirm against callers.
        return F.binary_cross_entropy_with_logits(prediction, target), {}
|
|
|
|
|
|
class BCELossWithQuant(nn.Module):
    """BCE-with-logits loss plus a weighted quantization loss term.

    Args:
        codebook_weight: multiplier applied to ``qloss`` before it is added
            to the BCE term.
    """

    def __init__(self, codebook_weight=1.):
        super().__init__()
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        """Combine BCE on ``prediction`` with the weighted ``qloss``.

        Args:
            qloss: quantization loss; it must support ``.detach()`` (so
                presumably a tensor — confirm with callers).
            target: target probabilities, same shape as ``prediction``.
            prediction: unnormalized logits.
            split: prefix used for the log keys (presumably a phase name
                such as "train"/"val" — confirm with callers).

        Returns:
            ``(total, logs)``: ``total`` is the differentiable combined loss;
            ``logs`` maps ``"<split>/total_loss"``, ``"<split>/bce_loss"``
            and ``"<split>/quant_loss"`` to detached means of the
            corresponding terms.
        """
        bce = F.binary_cross_entropy_with_logits(prediction, target)
        weighted_quant = self.codebook_weight * qloss
        total = bce + weighted_quant

        # Detached copies for logging only, so they carry no autograd graph.
        logs = {
            f"{split}/total_loss": total.detach().mean(),
            f"{split}/bce_loss": bce.detach().mean(),
            f"{split}/quant_loss": qloss.detach().mean(),
        }
        return total, logs
|