first commit

This commit is contained in:
Guocheng Qian
2023-08-02 19:51:43 -07:00
parent c2891c38cc
commit 13e18567fa
202 changed files with 43362 additions and 17 deletions

View File

@@ -0,0 +1,22 @@
import torch.nn as nn
import torch.nn.functional as F
class BCELoss(nn.Module):
def forward(self, prediction, target):
loss = F.binary_cross_entropy_with_logits(prediction,target)
return loss, {}
class BCELossWithQuant(nn.Module):
def __init__(self, codebook_weight=1.):
super().__init__()
self.codebook_weight = codebook_weight
def forward(self, qloss, target, prediction, split):
bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
loss = bce_loss + self.codebook_weight*qloss
return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
"{}/bce_loss".format(split): bce_loss.detach().mean(),
"{}/quant_loss".format(split): qloss.detach().mean()
}