first commit
This commit is contained in:
22
taming/modules/losses/segmentation.py
Normal file
22
taming/modules/losses/segmentation.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class BCELoss(nn.Module):
|
||||
def forward(self, prediction, target):
|
||||
loss = F.binary_cross_entropy_with_logits(prediction,target)
|
||||
return loss, {}
|
||||
|
||||
|
||||
class BCELossWithQuant(nn.Module):
|
||||
def __init__(self, codebook_weight=1.):
|
||||
super().__init__()
|
||||
self.codebook_weight = codebook_weight
|
||||
|
||||
def forward(self, qloss, target, prediction, split):
|
||||
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
|
||||
loss = bce_loss + self.codebook_weight*qloss
|
||||
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
"{}/bce_loss".format(split): bce_loss.detach().mean(),
|
||||
"{}/quant_loss".format(split): qloss.detach().mean()
|
||||
}
|
||||
Reference in New Issue
Block a user