import os
import glob
import tqdm
import random
import logging
import gc

import numpy as np
import psutil  # used by get_CPU_mem below
import imageio, imageio_ffmpeg
import time

import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
from torchmetrics.functional import pearson_corrcoef

from rich.console import Console
from torch_ema import ExponentialMovingAverage

from packaging import version as pver

from nerf.clip import CLIP
from easydict import EasyDict as edict

logger = logging.getLogger(__name__)

class AverageMeters(object):
    """Computes and stores the running average and current value of each metric."""

    def __init__(self, keys=None):
        self.meters = edict()
        # copy the key list so that appending new keys in update() never
        # mutates a shared default argument
        self.keys = list(keys) if keys is not None else ['loss']
        self.reset()

    def reset(self):
        for key in self.keys:
            self.reset_by_key(key)

    def reset_by_key(self, key):
        self.meters[key] = {}
        self.meters[key].val = 0
        self.meters[key].avg = 0
        self.meters[key].sum = 0
        self.meters[key].count = 0

    def update(self, in_dict, n=1):
        for key, val in in_dict.items():
            if key not in self.keys:
                self.keys.append(key)
                self.reset_by_key(key)
            self.meters[key].val = val
            self.meters[key].sum += val * n
            self.meters[key].count += n
            self.meters[key].avg = self.meters[key].sum / self.meters[key].count
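
# Minimal usage sketch for AverageMeters (illustrative only):
#   meter = AverageMeters()
#   meter.update({'loss': 0.5, 'loss_rgb': 0.2})
#   meter.update({'loss': 0.3})
#   meter.meters['loss'].avg   # -> 0.4, running average over both updates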

def setup_workspace(opt):
    if opt.workspace is None or opt.workspace == '':
        opt.workspace = 'out/'
        if opt.text:
            opt.workspace += '_'.join(opt.text.split(' '))
        if opt.image:
            opt.workspace += '_'.join('_'.join(opt.image.split('/')[-2:]).split('.')[:-1])
        opt.workspace += '+' + time.strftime('%Y%m%d-%H%M%S')
    opt.runname = os.path.basename(opt.workspace)
    os.makedirs(opt.workspace, exist_ok=True)
    opt.log_path = os.path.join(opt.workspace, f"log_{opt.runname}.txt")
    opt.ckpt_path = os.path.join(opt.workspace, 'checkpoints')
    opt.best_path = f"{opt.ckpt_path}/{opt.runname}.pth"
    os.makedirs(opt.ckpt_path, exist_ok=True)

def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')


def safe_normalize(x, eps=1e-20):
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
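
# Note: similar in spirit to F.normalize(x, dim=-1); clamping the squared norm
# before the square root guards against division by zero for zero-length vectors.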

@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    ''' get rays
    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: [4]
        H, W, N: int
        error_map: [B, 128 * 128], sample probability based on training error
    Returns:
        rays_o, rays_d: [B, N, 3]
        inds: [B, N]
    '''

    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
    i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
    j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5

    results = {}

    if N > 0:
        N = min(N, H*W)

        if error_map is None:
            inds = torch.randint(0, H*W, size=[N], device=device)  # may duplicate
            inds = inds.expand([B, N])
        else:
            # weighted sample on a low-reso grid
            inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False)  # [B, N], but in [0, 128*128)

            # map to the original resolution with random perturb.
            inds_x, inds_y = inds_coarse // 128, inds_coarse % 128  # `//` will throw a warning in torch 1.10... anyway.
            sx, sy = H / 128, W / 128
            inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
            inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
            inds = inds_x * W + inds_y

            results['inds_coarse'] = inds_coarse  # need this when updating error_map

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        results['inds'] = inds

    else:
        inds = torch.arange(H*W, device=device).expand([B, H*W])

    zs = -torch.ones_like(i)
    xs = -(i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    # directions = safe_normalize(directions)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)

    rays_o = poses[..., :3, 3]  # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d

    return results
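
# Usage sketch (illustrative; the fov and resolution below are made up, not
# values used elsewhere in this file):
#   poses = torch.eye(4)[None]                       # [1, 4, 4] cam2world
#   H = W = 64
#   focal = H / (2 * np.tan(np.deg2rad(60) / 2))     # 60 degree vertical fov
#   rays = get_rays(poses, (focal, focal, W / 2, H / 2), H, W, N=-1)
#   rays['rays_o'].shape  # -> torch.Size([1, 4096, 3]); same for rays['rays_d']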

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True


def save_tensor2image(x: torch.Tensor, path, channel_last=False, quality=75, **kwargs):
    # make_grid expects channel-first tensors; permute if the input is channel-last
    if x.ndim == 4 and channel_last:
        x = x.permute(0, 3, 1, 2)
    TF.to_pil_image(make_grid(x, value_range=(0, 1), **kwargs)).save(path, quality=quality)


@torch.jit.script
def linear_to_srgb(x):
    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)


@torch.jit.script
def srgb_to_linear(x):
    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
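
# Round-trip sanity check (illustrative): the two gamma conversions are inverses
# up to the truncated 1/2.4 = 0.41666 exponent:
#   x = torch.rand(8)
#   torch.allclose(srgb_to_linear(linear_to_srgb(x)), x, atol=1e-3)  # -> True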

def nonzero_normalize_depth(depth, mask=None):
    if mask is not None:
        if (depth[mask] > 0).sum() > 0:
            nonzero_depth_min = depth[mask][depth[mask] > 0].min()
        else:
            nonzero_depth_min = 0
    else:
        if (depth > 0).sum() > 0:
            nonzero_depth_min = depth[depth > 0].min()
        else:
            nonzero_depth_min = 0
    if nonzero_depth_min == 0:
        return depth
    else:
        depth = (depth - nonzero_depth_min) / depth.max()
        return depth.clamp(0, 1)
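
# Behavior sketch (illustrative): zeros are treated as "no depth"; the smallest
# non-zero value is shifted to 0 and the result is scaled by the pre-shift max:
#   nonzero_normalize_depth(torch.tensor([0., 2., 4.]))  # -> [0., 0., 0.5]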

class Trainer(object):
    def __init__(self,
                 argv,  # command line args
                 name,  # name of this experiment
                 opt,  # extra conf
                 model,  # network
                 guidance,  # guidance network
                 criterion=None,  # loss function, if None, assume inline implementation in train_step
                 optimizer=None,  # optimizer
                 ema_decay=None,  # if use EMA, set the decay
                 lr_scheduler=None,  # scheduler
                 metrics=[],  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
                 local_rank=0,  # which GPU am I
                 world_size=1,  # total num of GPUs
                 device=None,  # device to use, usually setting to None is OK. (auto choose device)
                 mute=False,  # whether to mute all print
                 fp16=False,  # amp optimize level
                 max_keep_ckpt=1,  # max num of saved ckpts in disk
                 workspace='workspace',  # workspace to save logs & ckpts
                 best_mode='min',  # the smaller/larger result, the better
                 use_loss_as_metric=True,  # use loss as the first metric
                 report_metric_at_train=False,  # also report metrics at training
                 use_checkpoint="latest",  # which ckpt to use at init time
                 use_tensorboard=True,  # whether to use tensorboard for logging
                 scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
                 ):

        self.argv = argv
        self.name = name
        self.opt = opt
        self.mute = mute
        self.metrics = metrics
        self.local_rank = local_rank
        self.world_size = world_size
        self.workspace = workspace
        self.ema_decay = ema_decay
        self.fp16 = fp16
        self.best_mode = best_mode
        self.use_loss_as_metric = use_loss_as_metric
        self.report_metric_at_train = report_metric_at_train
        self.max_keep_ckpt = max_keep_ckpt
        self.use_checkpoint = use_checkpoint
        self.use_tensorboard = use_tensorboard
        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        self.scheduler_update_every_step = scheduler_update_every_step
        self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        self.console = Console()

        model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        self.model = model

        # guide model
        self.guidance = guidance
        self.embeddings = {}

        # text prompt / images
        if self.guidance is not None:
            for key in self.guidance:
                for p in self.guidance[key].parameters():
                    p.requires_grad = False
                self.embeddings[key] = {}
            self.prepare_embeddings()

        if isinstance(criterion, nn.Module):
            criterion.to(self.device)
        self.criterion = criterion

        if optimizer is None:
            self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4)  # naive adam
        else:
            self.optimizer = optimizer(self.model)

        if lr_scheduler is None:
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1)  # fake scheduler
        else:
            self.lr_scheduler = lr_scheduler(self.optimizer)

        if ema_decay:
            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
        else:
            self.ema = None

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        # variable init
        self.total_train_t = 0
        self.epoch = 0
        self.global_step = 0
        self.local_step = 0
        self.novel_view_step = 0
        self.stats = {
            "loss": [],
            "valid_loss": [],
            "results": [],  # metrics[0], or valid_loss
            "checkpoints": [],  # record path of saved ckpt, to automatically remove old ckpt
            "best_result": None,
        }
        self.loss_meter = AverageMeters()
        # auto fix
        if len(metrics) == 0 or self.use_loss_as_metric:
            self.best_mode = 'min'

        logger.info(f'[INFO] cmdline: {self.argv}')
        logger.info(f'args:\n{self.opt}')
        logger.info(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
        logger.info(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
        logger.info(f'[INFO] #Optimizer: \n{self.optimizer}')
        logger.info(f'[INFO] #Scheduler: \n{self.lr_scheduler}')

        if self.workspace is not None:
            if self.use_checkpoint == "scratch":
                logger.info("[INFO] Training from scratch ...")
            elif self.use_checkpoint == "latest":
                logger.info("[INFO] Loading latest checkpoint ...")
                self.load_checkpoint()
            elif self.use_checkpoint == "latest_model":
                logger.info("[INFO] Loading latest checkpoint (model only)...")
                self.load_checkpoint(model_only=True)
            elif self.use_checkpoint == "best":
                if os.path.exists(self.opt.best_path):
                    logger.info("[INFO] Loading best checkpoint ...")
                    self.load_checkpoint(self.opt.best_path)
                else:
                    logger.info(f"[INFO] {self.opt.best_path} not found, loading latest ...")
                    self.load_checkpoint()
            else:  # path to ckpt
                logger.info(f"[INFO] Loading {self.use_checkpoint} ...")
                self.load_checkpoint(self.use_checkpoint)

    # calculate the text embs.
    @torch.no_grad()
    def prepare_embeddings(self):

        # text embeddings (stable-diffusion)
        if self.opt.text is not None:

            dir_texts = ['front', 'side', 'back']
            if 'SD' in self.guidance:
                self.embeddings['SD']['default'] = self.guidance['SD'].get_all_text_embeds([self.opt.text])
                neg_embedding = self.guidance['SD'].get_all_text_embeds([self.opt.negative])

                for idx, d in enumerate(dir_texts):
                    text = f"{self.opt.text}, {d} view"
                    self.embeddings['SD'][d] = self.guidance['SD'].get_all_text_embeds([text])
                    if self.opt.dir_texts_neg:
                        text_neg = self.opt.negative + ', '.join([text+' view' for i, text in enumerate(dir_texts) if i != idx])
                        logger.info(f'dir_texts of {d}\n positive text: {text},\n negative text: {text_neg}')
                        neg_embedding = self.guidance['SD'].get_all_text_embeds([text_neg])
                    self.embeddings['SD'][d] = torch.cat((neg_embedding, self.embeddings['SD'][d]), dim=0)

            if 'IF' in self.guidance:
                self.embeddings['IF']['default'] = self.guidance['IF'].get_text_embeds([self.opt.text])
                neg_embedding = self.guidance['IF'].get_text_embeds([self.opt.negative])

                for idx, d in enumerate(dir_texts):
                    text = f"{self.opt.text}, {d} view"
                    self.embeddings['IF'][d] = self.guidance['IF'].get_text_embeds([text])
                    if self.opt.dir_texts_neg:
                        text_neg = self.opt.negative + ', '.join([text+' view' for i, text in enumerate(dir_texts) if i != idx])
                        logger.info(f'dir_texts of {d}\n positive text: {text},\n negative text: {text_neg}')
                        neg_embedding = self.guidance['IF'].get_text_embeds([text_neg])
                    self.embeddings['IF'][d] = torch.cat((neg_embedding, self.embeddings['IF'][d]), dim=0)

            # if 'clip' in self.guidance:
            #     self.embeddings['clip']['text'] = self.guidance['clip'].get_text_embeds(self.opt.text)

        if self.opt.images is not None:

            h = int(self.opt.known_view_scale * self.opt.h)
            w = int(self.opt.known_view_scale * self.opt.w)

            # load processed image and remove edges
            rgbas = []
            rgbas_hw = []
            mask_no_edges = []
            for image in self.opt.images:
                rgba = cv2.cvtColor(cv2.imread(image, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
                rgbas.append(rgba)
                rgba_hw = cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA).astype(np.float32) / 255
                rgbas_hw.append(rgba_hw)
                if self.opt.rm_edge:
                    alpha = np.uint8(rgba_hw[..., 3] * 255.)
                    dilate = cv2.dilate(alpha, np.ones((self.opt.edge_width, self.opt.edge_width), np.uint8))
                    edge = cv2.absdiff(alpha, dilate).astype(np.float32) / 255
                    mask_no_edge = rgba_hw[..., 3] > 0.5
                    mask_no_edge[edge > self.opt.edge_threshold] = False
                    mask_no_edges.append(mask_no_edge)
            rgba_hw = np.stack(rgbas_hw)
            mask = rgba_hw[..., 3] > 0.5
            if len(mask_no_edges) > 0:
                mask_no_edge = np.stack(mask_no_edges)
            else:
                mask_no_edge = mask

            # rgb: composite onto a white background
            rgb_hw = rgba_hw[..., :3] * rgba_hw[..., 3:] + (1 - rgba_hw[..., 3:])
            self.rgb = torch.from_numpy(rgb_hw).permute(0, 3, 1, 2).contiguous().to(self.device)
            self.mask = torch.from_numpy(mask).to(self.device)
            self.opacity = torch.from_numpy(mask_no_edge).to(self.device).to(torch.float32).unsqueeze(0)
            logger.info(f'[INFO] dataset: load image prompt {self.opt.images} {self.rgb.shape}')

            # load depth
            depth_paths = [image.replace('rgba', 'depth') for image in self.opt.images]
            if os.path.exists(depth_paths[0]):
                depths = [cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) for depth_path in depth_paths]
                depth = np.stack([cv2.resize(depth, (w, h), interpolation=cv2.INTER_AREA) for depth in depths])
                self.depth = 1 - torch.from_numpy(depth.astype(np.float32) / 255).to(self.device)
                if len(self.depth.shape) == 4 and self.depth.shape[-1] > 1:
                    self.depth = self.depth[..., 0]
                    logger.info(f'[WARN] dataset: {depth_paths[0]} has more than one channel, only use the first channel')
                if self.opt.normalize_depth:
                    self.depth = nonzero_normalize_depth(self.depth, self.mask)
                save_tensor2image(self.depth, os.path.join(self.workspace, 'depth_resized.jpg'))
                self.depth = self.depth[self.mask]
                logger.info(f'[INFO] dataset: load depth prompt {depth_paths} {self.depth.shape}')
            else:
                self.depth = None
                logger.info(f'[WARN] dataset: {depth_paths[0]} is not found')

            # load normal
            normal_paths = [image.replace('rgba', 'normal') for image in self.opt.images]
            if os.path.exists(normal_paths[0]):
                normals = []
                for normal_path in normal_paths:
                    normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
                    if normal.shape[-1] == 4:
                        normal = cv2.cvtColor(normal, cv2.COLOR_BGRA2RGB)
                    normals.append(normal)
                normal = np.stack([cv2.resize(normal, (w, h), interpolation=cv2.INTER_AREA) for normal in normals])
                self.normal = torch.from_numpy(normal.astype(np.float32) / 255).to(self.device)
                save_tensor2image(self.normal, os.path.join(self.workspace, 'normal_resized.jpg'), channel_last=True)
                logger.info(f'[INFO] dataset: load normal prompt {normal_paths} {self.normal.shape}')
                self.normal = self.normal[self.mask]
            else:
                self.normal = None
                logger.info(f'[WARN] dataset: {normal_paths[0]} is not found')

            # save for debug
            save_tensor2image(self.rgb, os.path.join(self.workspace, 'rgb_resized.png'), channel_last=False)
            save_tensor2image(self.opacity, os.path.join(self.workspace, 'opacity_resized.png'), channel_last=False)

            # encode embeddings for zero123
            if 'zero123' in self.guidance:
                rgba_256 = np.stack([cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas])
                rgbs_256 = rgba_256[..., :3] * rgba_256[..., 3:] + (1 - rgba_256[..., 3:])
                rgb_256 = torch.from_numpy(rgbs_256).permute(0, 3, 1, 2).contiguous().to(self.device)
                guidance_embeds = self.guidance['zero123'].get_img_embeds(rgb_256)
                self.embeddings['zero123']['default'] = {
                    'zero123_ws': self.opt.zero123_ws,
                    'c_crossattn': guidance_embeds[0],
                    'c_concat': guidance_embeds[1],
                    'ref_polars': self.opt.ref_polars,
                    'ref_azimuths': self.opt.ref_azimuths,
                    'ref_radii': self.opt.ref_radii,
                }

            # if 'clip' in self.guidance:
            #     self.embeddings['clip']['image'] = self.guidance['clip'].get_img_embeds(self.rgb)

            # encode image for clip
            if self.opt.use_clip:
                self.rgb_clip_embed = self.guidance.get_clip_img_embeds(self.rgb)
                # debug: sanity-check that the CLIP loss is differentiable under AMP
                scaler = torch.cuda.amp.GradScaler()
                image = torch.randn((1, 3, 512, 512), device=self.device, requires_grad=True)
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    loss = self.guidance.clip_loss(self.rgb_clip_embed, image)
                scaler.scale(loss).backward()
            else:
                self.rgb_clip_embed = None
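
    # Resulting layout of self.embeddings after the method above (sketch):
    #   self.embeddings['SD']      -> {'default', 'front', 'side', 'back'} text embeddings
    #   self.embeddings['IF']      -> same keys as 'SD'
    #   self.embeddings['zero123'] -> {'default': {'c_crossattn', 'c_concat', 'zero123_ws',
    #                                  'ref_polars', 'ref_azimuths', 'ref_radii'}}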

    # ------------------------------
    @torch.no_grad()
    def match_known(self, **kwargs):
        self.model.eval()
        data = self.default_view_data
        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        mvp = data['mvp']  # [B, 4, 4]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        ambient_ratio = 1.0
        shading = self.opt.known_shading
        binarize = False
        bg_color = self.get_bg_color(self.opt.bg_color_known, B*N, rays_o.device)

        # add camera noise to avoid grid-like artifacts
        # * (1 - self.global_step / self.opt.iters)
        noise_scale = self.opt.known_view_noise_scale
        rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
        rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale

        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True,
                                    bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
        pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()  # [1, 3, H, W]
        pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)

        rgb_loss = self.opt.lambda_rgb * F.mse_loss(pred_rgb * self.opacity, self.rgb * self.opacity)
        mask_loss = self.opt.lambda_mask * F.mse_loss(pred_mask, self.mask.to(torch.float32).unsqueeze(0))
        return pred_rgb, pred_mask, rgb_loss, mask_loss

    def get_bg_color(self, bg_type, N, device):
        if bg_type is None:
            return None
        elif isinstance(bg_type, str):
            if bg_type == 'pixelnoise':
                bg_color = torch.rand((N, 3), device=device)
            elif bg_type == 'noise':
                bg_color = torch.rand((1, 3), device=device).repeat(N, 1)
            elif bg_type == 'white':
                bg_color = torch.ones((N, 3), device=device)
            else:
                raise NotImplementedError(f"background color type {bg_type} is not implemented")
            return bg_color
        elif isinstance(bg_type, Tensor):
            return bg_type.to(device)
        else:
            raise NotImplementedError(f"{bg_type} is not implemented")
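
    # Usage sketch (illustrative):
    #   self.get_bg_color('white', H * W, device)       # [H*W, 3] of ones
    #   self.get_bg_color('noise', H * W, device)       # one random color repeated per pixel
    #   self.get_bg_color('pixelnoise', H * W, device)  # independent random color per pixel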

    # def margin_rank_loss(self, depth):
    #     # high res, only calc on fg
    #     output = depth.squeeze().view(-1)
    #     output = output[self.fg_idx]
    #     num = output.shape[0]  # [n, 1]
    #     # print(num)
    #     output = output.reshape(1, -1)
    #     o1 = output.expand(num, -1).reshape(-1)
    #     o2 = output.T.expand(-1, num).reshape(-1)
    #     return F.margin_ranking_loss(o1, o2, self.rank_loss_target)

    def train_step(self, data):
        # perform RGBD loss instead of SDS if is image-conditioned
        do_rgbd_loss = self.opt.images is not None and \
            ((self.global_step < self.opt.known_iters) or (self.global_step % self.opt.known_view_interval == 0))

        # override random camera with fixed known camera
        if do_rgbd_loss:
            data = self.default_view_data

        # progressively relaxing view range
        if self.opt.progressive_view:
            r = min(1.0, 0.2 + self.global_step / (0.5 * self.opt.iters))
            self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r,
                                  self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r]
            self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r,
                                    self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r]
            self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r,
                                     self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r]
            self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r,
                                   self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r]

        # progressively increase max_level
        if self.opt.progressive_level:
            self.model.max_level = min(1.0, 0.25 + self.global_step / (0.5 * self.opt.iters))

        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        mvp = data['mvp']  # [B, 4, 4]

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        # when ref_data has B images > opt.batch_size
        if B > self.opt.batch_size:
            # choose batch_size images out of those B images
            choice = torch.randperm(B)[:self.opt.batch_size]
            B = self.opt.batch_size
            rays_o = rays_o[choice]
            rays_d = rays_d[choice]
            mvp = mvp[choice]

        if do_rgbd_loss:
            ambient_ratio = 1.0
            shading = 'lambertian'  # use lambertian instead of albedo to get normal
            as_latent = False
            binarize = False
            bg_color = self.get_bg_color(self.opt.bg_color_known, B*N, rays_o.device)

            # add camera noise to avoid grid-like artifacts
            if self.opt.known_view_noise_scale > 0:
                noise_scale = self.opt.known_view_noise_scale  # * (1 - self.global_step / self.opt.iters)
                rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
                rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale

        elif self.global_step < (self.opt.latent_iter_ratio * self.opt.iters):
            ambient_ratio = 1.0
            shading = 'normal'
            as_latent = True
            binarize = False
            bg_color = None

        else:
            if self.global_step < (self.opt.normal_iter_ratio * self.opt.iters):
                ambient_ratio = 1.0
                shading = 'normal'
            elif self.global_step < (self.opt.textureless_iter_ratio * self.opt.iters):
                ambient_ratio = 0.1 + 0.9 * random.random()
                shading = 'textureless'
            elif self.global_step < (self.opt.albedo_iter_ratio * self.opt.iters):
                ambient_ratio = 1.0
                shading = 'albedo'
            else:
                # random shading
                ambient_ratio = 0.1 + 0.9 * random.random()
                rand = random.random()
                if rand > 0.8:
                    shading = 'textureless'
                else:
                    shading = 'lambertian'

            as_latent = False

            # random weights binarization (like mobile-nerf) [NOT WORKING NOW]
            # binarize_thresh = min(0.5, -0.5 + self.global_step / self.opt.iters)
            # binarize = random.random() < binarize_thresh
            binarize = False

            # random background
            rand = random.random()
            if self.opt.bg_radius > 0 and rand > 0.5:
                bg_color = None  # use bg_net
            else:
                bg_color = torch.rand(3).to(self.device)  # single color random bg

        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
        pred_depth = outputs['depth'].reshape(B, 1, H, W)
        if self.opt.normalize_depth:
            pred_depth = nonzero_normalize_depth(pred_depth)
        pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)
        if 'normal_image' in outputs:
            pred_normal = outputs['normal_image'].reshape(B, H, W, 3)
        else:
            pred_normal = None

        if as_latent:
            # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D)
            pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous()  # [B, 4, H, W]
        else:
            pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()  # [B, 3, H, W]
        out_dict = {
            'rgb': pred_rgb,
            'depth': pred_depth,
            'mask': pred_mask,
            'normal': pred_normal,
        }

        # Loss
        # initialize every loss term to zero so the sums below are always defined
        (loss_rgb, loss_mask, loss_normal, loss_depth,
         loss_sds, loss_if, loss_zero123, loss_clip,
         loss_entropy, loss_opacity, loss_orient, loss_smooth,
         loss_smooth2d, loss_smooth3d, loss_mesh_normal, loss_mesh_lap) = torch.zeros(16, device=self.device)
        # known view loss
        if do_rgbd_loss:
            gt_mask = self.mask  # [B, H, W]
            gt_rgb = self.rgb  # [B, 3, H, W]
            gt_opacity = self.opacity  # [B, 1, H, W]
            gt_normal = self.normal  # [B, H, W, 3]
            gt_depth = self.depth  # [B, H, W]

            if len(gt_rgb) > self.opt.batch_size:
                gt_mask = gt_mask[choice]
                gt_rgb = gt_rgb[choice]
                gt_opacity = gt_opacity[choice]
                gt_normal = gt_normal[choice]
                gt_depth = gt_depth[choice]

            # color loss
            loss_rgb = self.opt.lambda_rgb * F.mse_loss(pred_rgb * gt_opacity, gt_rgb * gt_opacity)

            # mask loss
            loss_mask = self.opt.lambda_mask * F.mse_loss(pred_mask, gt_mask.to(torch.float32).unsqueeze(0))

            # normal loss
            if self.opt.lambda_normal > 0 and 'normal_image' in outputs and self.normal is not None:
                pred_normal = pred_normal[self.mask]
                lambda_normal = self.opt.lambda_normal * min(1, self.global_step / self.opt.iters)
                loss_normal = lambda_normal * (1 - F.cosine_similarity(pred_normal, self.normal).mean()) / 2

            # relative depth loss
            if self.opt.lambda_depth > 0 and self.depth is not None:
                valid_pred_depth = pred_depth[:, 0][self.mask]
                loss_depth = self.opt.lambda_depth * (1 - pearson_corrcoef(valid_pred_depth, self.depth)) / 2

            loss = loss_rgb + loss_mask + loss_normal + loss_depth

        # novel view loss
        else:
            save_guidance_path = os.path.join(self.opt.workspace, 'guidance', f'train_step{self.global_step}_guidance.jpg') if self.opt.save_guidance_every > 0 and self.novel_view_step % self.opt.save_guidance_every == 0 else None
            if 'SD' in self.guidance:
                # interpolate text_z over azimuth: blend front/side/back embeddings
                azimuth = data['azimuth']  # [-180, 180]

                # ENHANCE: remove loop to handle batch size > 1
                text_z = []
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['SD']['front']
                        end_z = self.embeddings['SD']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['SD']['side']
                        end_z = self.embeddings['SD']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)
                text_z = torch.stack(text_z, dim=0).transpose(0, 1).flatten(0, 1)
                text_z_sds = text_z[:, :-1]
                loss_sds, _ = self.guidance['SD'].train_step(text_z_sds, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale['SD'], grad_scale=self.opt.lambda_guidance['SD'],
                                                             density=pred_mask if self.opt.gudiance_spatial_weighting else None,
                                                             save_guidance_path=save_guidance_path
                                                             )
                # if self.opt.lambda_clip > 0:
                #     lambda_clip = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_clip
                #     if self.opt.clip_image_loss:
                #         loss_clip = lambda_clip * self.guidance.clip_loss(self.rgb_clip_embed, pred_rgb)
                #     else:
                #         loss_clip = lambda_clip * self.guidance.clip_loss(text_z_clip, pred_rgb)

            if 'IF' in self.guidance:
                # interpolate text_z over azimuth, as in the SD branch above
                azimuth = data['azimuth']  # [-180, 180]

                # ENHANCE: remove loop to handle batch size > 1
                text_z = []
                for b in range(azimuth.shape[0]):
                    if azimuth[b] >= -90 and azimuth[b] < 90:
                        if azimuth[b] >= 0:
                            r = 1 - azimuth[b] / 90
                        else:
                            r = 1 + azimuth[b] / 90
                        start_z = self.embeddings['IF']['front']
                        end_z = self.embeddings['IF']['side']
                    else:
                        if azimuth[b] >= 0:
                            r = 1 - (azimuth[b] - 90) / 90
                        else:
                            r = 1 + (azimuth[b] + 90) / 90
                        start_z = self.embeddings['IF']['side']
                        end_z = self.embeddings['IF']['back']
                    text_z.append(r * start_z + (1 - r) * end_z)
                text_z = torch.cat(text_z, dim=1).reshape(B, 2, start_z.shape[-2]-1, start_z.shape[-1]).transpose(0, 1).flatten(0, 1)
                loss_if = self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale['IF'], grad_scale=self.opt.lambda_guidance['IF'])

            if 'zero123' in self.guidance:

                polar = data['polar']
                azimuth = data['azimuth']
                radius = data['radius']

                loss_zero123 = self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale['zero123'],
                                                                   as_latent=as_latent, grad_scale=self.opt.lambda_guidance['zero123'], save_guidance_path=save_guidance_path)

            if 'clip' in self.guidance:

                azimuth = data['azimuth']
                # empirical, far view should apply smaller CLIP loss
                lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance['clip']
                loss_clip = self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance)

            loss = loss_sds + loss_if + loss_zero123 + loss_clip
        # regularizations
        if not self.opt.dmtet:

            if self.opt.lambda_opacity > 0:
                loss_opacity = self.opt.lambda_opacity * (outputs['weights_sum'] ** 2).mean()

            if self.opt.lambda_entropy > 0:
                lambda_entropy = self.opt.lambda_entropy * min(1, 2 * self.global_step / self.opt.iters)
                alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5)
                # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
                loss_entropy = lambda_entropy * (- alphas * torch.log2(alphas) -
                                                 (1 - alphas) * torch.log2(1 - alphas)).mean()

            if self.opt.lambda_normal_smooth > 0 and 'normal_image' in outputs:
                pred_vals = outputs['normal_image'].reshape(B, H, W, 3)
                # total variation over neighboring pixels
                loss_smooth = (pred_vals[:, 1:, :, :] - pred_vals[:, :-1, :, :]).square().mean() + \
                              (pred_vals[:, :, 1:, :] - pred_vals[:, :, :-1, :]).square().mean()
                loss_smooth = self.opt.lambda_normal_smooth * loss_smooth

            if self.opt.lambda_normal_smooth2d > 0 and 'normal_image' in outputs:
                pred_vals = outputs['normal_image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()
                smoothed_vals = TF.gaussian_blur(pred_vals, kernel_size=9)
                loss_smooth2d = self.opt.lambda_normal_smooth2d * F.mse_loss(pred_vals, smoothed_vals)

            if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
                loss_orient = self.opt.lambda_orient * outputs['loss_orient']

            if self.opt.lambda_3d_normal_smooth > 0 and 'loss_normal_perturb' in outputs:
                loss_smooth3d = self.opt.lambda_3d_normal_smooth * outputs['loss_normal_perturb']

            loss += loss_opacity + loss_entropy + loss_smooth + loss_smooth2d + loss_orient + loss_smooth3d

        else:
            if self.opt.lambda_mesh_normal > 0:
                loss_mesh_normal = self.opt.lambda_mesh_normal * outputs['loss_normal']

            if self.opt.lambda_mesh_lap > 0:
                loss_mesh_lap = self.opt.lambda_mesh_lap * outputs['loss_lap']
            loss += loss_mesh_normal + loss_mesh_lap

        losses_dict = {
            'loss': loss.item(),
            'loss_sds': loss_sds.item(),
            'loss_if': loss_if.item(),
            'loss_zero123': loss_zero123.item(),
            'loss_clip': loss_clip.item(),
            'loss_rgb': loss_rgb.item(),
            'loss_mask': loss_mask.item(),
            'loss_normal': loss_normal.item(),
            'loss_depth': loss_depth.item(),
            'loss_opacity': loss_opacity.item(),
            'loss_entropy': loss_entropy.item(),
            'loss_smooth': loss_smooth.item(),
            'loss_smooth2d': loss_smooth2d.item(),
            'loss_smooth3d': loss_smooth3d.item(),
            'loss_orient': loss_orient.item(),
            'loss_mesh_normal': loss_mesh_normal.item(),
            'loss_mesh_lap': loss_mesh_lap.item(),
        }
        # if loss_guidance_dict:
        #     for key, val in loss_guidance_dict.items():
        #         losses_dict[key] = val.item() if isinstance(val, torch.Tensor) else val

        if out_dict['normal'] is not None:
            out_dict['normal'] = out_dict['normal'].permute(0, 3, 1, 2).contiguous()

        # save for debug purpose
        if self.opt.save_train_every > 0 and self.global_step % self.opt.save_train_every == 0:
            image_save_path = os.path.join(self.workspace, 'train_debug')
            os.makedirs(image_save_path, exist_ok=True)
            for key, value in out_dict.items():
                if value is not None:
                    value = ((value - value.min()) / (value.max() - value.min() + 1e-6)).detach().mul(255).to(torch.uint8)
                    try:
                        save_tensor2image(value, os.path.join(image_save_path, f'train_{self.global_step:06d}_{key}.jpg'), channel_last=False)
                    except Exception:
                        pass
        return loss, losses_dict, out_dict

    def post_train_step(self):

        # unscale grad before modifying it!
        # ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
        self.scaler.unscale_(self.optimizer)

        # clip grad
        if self.opt.grad_clip >= 0:
            torch.nn.utils.clip_grad_value_(self.model.parameters(), self.opt.grad_clip)

        if not self.opt.dmtet and self.opt.backbone == 'grid':

            if self.opt.lambda_tv > 0:
                lambda_tv = min(1.0, self.global_step / (0.5 * self.opt.iters)) * self.opt.lambda_tv
                self.model.encoder.grad_total_variation(lambda_tv, None, self.model.bound)
            if self.opt.lambda_wd > 0:
                self.model.encoder.grad_weight_decay(self.opt.lambda_wd)

    def eval_step(self, data):

        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        mvp = data['mvp']

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        shading = data['shading'] if 'shading' in data else 'lambertian'
        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
        light_d = data['light_d'] if 'light_d' in data else None

        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading)
        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W, 1)
        if self.opt.normalize_depth:
            pred_depth = nonzero_normalize_depth(pred_depth)
        if 'normal_image' in outputs:
            pred_normal = outputs['normal_image'].reshape(B, H, W, 3)
        else:
            pred_normal = None
        out_dict = {
            shading: pred_rgb,
            'depth': pred_depth,
            'normal_image': pred_normal,
        }
        # dummy loss to keep the evaluation interface uniform
        loss = torch.zeros([1], device=pred_rgb.device, dtype=pred_rgb.dtype)
        return out_dict, loss

    def test_step(self, data, bg_color=None, perturb=False, shading='lambertian'):
        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        mvp = data['mvp']

        B, N = rays_o.shape[:2]
        H, W = data['H'], data['W']

        bg_color = self.get_bg_color(bg_color, B*N, rays_o.device)

        shading = data['shading'] if 'shading' in data else shading
        ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
        light_d = data['light_d'] if 'light_d' in data else None

        outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, bg_color=bg_color)

        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W, 1)
        pred_mask = outputs['weights_sum'].reshape(B, H, W, 1)
        pred_depth = nonzero_normalize_depth(pred_depth)
        if 'normal_image' in outputs:
            pred_normal = outputs['normal_image'].reshape(B, H, W, 3)
            # composite the normal map onto a white background
            pred_normal = pred_normal * pred_mask + (1.0 - pred_mask)
        else:
            pred_normal = None
        out_dict = {
            shading: pred_rgb,
            'depth': pred_depth,
            'normal_image': pred_normal,
            'mask': pred_mask,
        }
        return out_dict

    def save_mesh(self, loader=None, save_path=None):

        if save_path is None:
            save_path = os.path.join(self.workspace, 'mesh')

        logger.info(f"==> Saving mesh to {save_path}")

        os.makedirs(save_path, exist_ok=True)

        self.model.export_mesh(save_path, resolution=self.opt.mcubes_resolution, decimate_target=self.opt.decimate_target)

        logger.info("==> Finished saving mesh.")

    ### ------------------------------

    def train(self, train_loader, valid_loader, test_loader, max_epochs):

        if self.use_tensorboard and self.local_rank == 0:
            self.writer = SummaryWriter(os.path.join(self.workspace, "run", self.name))

        # init from nerf should be performed after Shap-E, since Shap-E will rescale dmtet
        if self.opt.dmtet and (self.opt.init_ckpt and os.path.exists(self.opt.init_ckpt)):
            reset_scale = False if self.opt.use_shape else True
            old_sdf = self.model.get_sdf_from_nerf(reset_scale)
            if not self.opt.tet_mlp:
                self.model.dmtet.init_tet_from_sdf(old_sdf)
            self.test(valid_loader, name='init_ckpt', write_video=False, save_each_frame=False, subfolder='check_init')
        else:
            old_sdf = None

        if self.opt.use_shape and self.opt.dmtet:
            os.makedirs(os.path.join(self.opt.workspace, 'shape'), exist_ok=True)
            best_loss = torch.inf
            best_idx = 0
            for idx, (sdf, color) in enumerate(zip(self.opt.rpsts, self.opt.colors)):
                self.model.init_tet_from_sdf_color(sdf)
                pred_rgb, pred_mask, rgb_loss, mask_loss = self.match_known()
                best_loss = min(best_loss, mask_loss)
                if best_loss == mask_loss:
                    best_idx = idx
                logger.info(f"==> Current best match shape known sdf idx: {best_idx}")
                save_tensor2image(pred_mask, os.path.join(self.opt.workspace, 'shape', f"match_shape_known_{idx}_rgb.jpg"), channel_last=False)
                self.test(valid_loader, name=f'idx_{idx}', write_video=False, save_each_frame=False, subfolder='check_init')

            sdf = self.opt.rpsts[best_idx]
            self.model.init_tet_from_sdf_color(sdf, self.opt.colors[best_idx])
            self.test(valid_loader, name='shape_only', write_video=False, save_each_frame=False, subfolder='check_init')

            # Enable mixture model
            if self.opt.base_mesh:
                logger.info(f"==> Enable mixture model with base mesh {self.opt.base_mesh}")
                mesh_sdf = self.model.dmtet.get_sdf_from_mesh(self.opt.base_mesh)
                sdf = (mesh_sdf.clamp(0, 1) + sdf.clamp(0, 1)).clamp(0, 1)

            if old_sdf is not None:
                sdf = (sdf.clamp(0, 1) + old_sdf.clamp(0, 1)).clamp(0, 1)

            self.model.init_tet_from_sdf_color(sdf, self.opt.colors[best_idx])
            self.test(valid_loader, name='shape_merge', write_video=False, save_each_frame=False, subfolder='check_init')

            del best_loss, best_idx, pred_rgb, pred_mask, rgb_loss, mask_loss
            self.opt.rpsts = None
            gc.collect()
            torch.cuda.empty_cache()

        # init shape for NeRF. NOTE: Does not work yet.. in progress.
        # if self.opt.use_shape and not self.opt.dmtet:
        #     os.makedirs(os.path.join(self.opt.workspace, 'shape'), exist_ok=True)
        #     best_loss = torch.inf
        #     best_idx = 0
        #     for idx, (density, color) in enumerate(zip(self.opt.rpsts, self.opt.colors)):
        #         self.model.init_nerf_from_sdf_color(density, color, self.opt.points, lr=0.001)
        #         pred_rgb, pred_mask, rgb_loss, mask_loss = self.match_known()
        #         best_loss = min(best_loss, mask_loss)
        #         if best_loss == mask_loss:
        #             best_idx = idx
        #         logger.info(f"==> Current best match shape known sdf idx: {best_idx}")
        #         save_tensor2image(pred_mask, os.path.join(self.opt.workspace, 'shape', f"match_shape_known_{idx}_rgb.jpg"), channel_last=False)
        #         self.evaluate_one_epoch(valid_loader, f'idx_{idx}')
        #     self.model.init_nerf_from_sdf_color(self.opt.rpsts[best_idx], self.opt.colors[best_idx])
        #     self.evaluate_one_epoch(valid_loader, f'init_from_shape_{idx}')
        #
        #     del best_loss, best_idx, pred_rgb, pred_mask, rgb_loss, mask_loss
        #     self.opt.rpsts = None
        #     self.opt.colors = None
        #     self.opt.points = None
        #     gc.collect()
        #     torch.cuda.empty_cache()

        start_t = time.time()

        for epoch in range(self.epoch + 1, max_epochs + 1):
            self.epoch = epoch

            self.train_one_epoch(train_loader, max_epochs)

            if self.workspace is not None and self.local_rank == 0:
                self.save_checkpoint(full=True, best=False)

            if self.epoch % self.opt.eval_interval == 0:
                self.evaluate_one_epoch(valid_loader)
                self.save_checkpoint(full=False, best=True)

            if self.epoch % self.opt.test_interval == 0 or self.epoch == max_epochs:
                self.test(test_loader, img_folder='images' if self.epoch == max_epochs else f'images_ep{self.epoch:04d}')

        end_t = time.time()

        self.total_train_t = end_t - start_t + self.total_train_t

        logger.info(f"[INFO] training takes {self.total_train_t / 60:.4f} minutes.")

        if self.use_tensorboard and self.local_rank == 0:
            self.writer.close()

    def evaluate(self, loader, name=None):
        self.use_tensorboard, use_tensorboard = False, self.use_tensorboard
        self.evaluate_one_epoch(loader, name)
        self.use_tensorboard = use_tensorboard

    def test(self, loader, save_path=None, name=None,
             write_video=True, save_each_frame=True, shading='lambertian',
             subfolder='results', img_folder='images'
             ):

        if save_path is None:
            save_path = os.path.join(self.workspace, subfolder)
        image_save_path = os.path.join(self.workspace, subfolder, img_folder)

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        os.makedirs(save_path, exist_ok=True)
        os.makedirs(image_save_path, exist_ok=True)

        logger.info(f"==> Start Test, saving {shading} results to {save_path}")

        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
        self.model.eval()

        all_outputs = {}
        with torch.no_grad():
            for i, data in enumerate(loader):
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    outputs = self.test_step(data, bg_color=self.opt.bg_color_test, shading=shading)
                for key, value in outputs.items():
                    if value is not None:
                        value = ((value - value.min()) / (value.max() - value.min() + 1e-6)).detach().mul(255).to(torch.uint8)
                        if save_each_frame:
                            save_tensor2image(value, os.path.join(image_save_path, f'{name}_{i:04d}_{key}.jpg'), channel_last=True)
                        if key not in all_outputs.keys():
                            all_outputs[key] = []
                        all_outputs[key].append(value)
                pbar.update(loader.batch_size)

        for key, value in all_outputs.items():
            all_outputs[key] = torch.cat(value, dim=0)

        if write_video:
            for key, value in all_outputs.items():
                # current version torchvision does not support writing a single-channel video
                # torchvision.io.write_video(os.path.join(save_path, f'{name}_{key}.mp4'), all_outputs[key].detach().cpu(), fps=25)
                imageio.mimwrite(os.path.join(save_path, f'{name}_{key}.mp4'), all_outputs[key].detach().cpu().numpy(), fps=25, quality=8, macro_block_size=1)
        for key, value in all_outputs.items():
            save_tensor2image(value, os.path.join(save_path, f'{name}_{key}.jpg'), channel_last=True)
        logger.info("==> Finished Test.")

    # [GUI] train text step.
    def train_gui(self, train_loader, step=16):

        self.model.train()

        total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)

        loader = iter(train_loader)

        for _ in range(step):

            # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(train_loader)
                data = next(loader)

            # update grid every 16 steps
            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model.update_extra_state()

            self.global_step += 1

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                loss, loss_dicts, outputs = self.train_step(data)

            self.scaler.scale(loss).backward()
            self.post_train_step()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            self.loss_meter.update(loss_dicts)

        if self.ema is not None:
            self.ema.update()

        average_loss = self.loss_meter.meters['loss'].avg

        if not self.scheduler_update_every_step:
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        outputs = {
            'loss': average_loss,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

        return outputs

    # [GUI] test on a single image
    def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):

        # render resolution (may need downscaling for a better frame rate)
        rH = int(H * downscale)
        rW = int(W * downscale)
        intrinsics = intrinsics * downscale

        pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
        mvp = torch.from_numpy(mvp).unsqueeze(0).to(self.device)

        rays = get_rays(pose, intrinsics, rH, rW, -1)

        # from degree theta/phi to 3D normalized vec
        light_d = np.deg2rad(light_d)
        light_d = np.array([
            np.sin(light_d[0]) * np.sin(light_d[1]),
            np.cos(light_d[0]),
            np.sin(light_d[0]) * np.cos(light_d[1]),
        ], dtype=np.float32)
        light_d = torch.from_numpy(light_d).to(self.device)

        data = {
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'mvp': mvp,
            'H': rH,
            'W': rW,
            'light_d': light_d,
            'ambient_ratio': ambient_ratio,
            'shading': shading,
        }

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=self.fp16):
                # here spp is used as perturb random seed!
                outputs = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp)

        if self.ema is not None:
            self.ema.restore()

        # interpolation to the original resolution
        if downscale != 1:
            # have to permute twice with torch...
            outputs[shading] = F.interpolate(outputs[shading].permute(0, 3, 1, 2), size=(H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
            outputs['depth'] = F.interpolate(outputs['depth'].unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)

            if outputs['normal_image'] is not None:
                outputs['normal_image'] = F.interpolate(outputs['normal_image'].unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)

        return outputs

    def train_one_epoch(self, loader, max_epochs):
        logger.info(f"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Start Training {self.workspace} Epoch {self.epoch}/{max_epochs}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")

        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()

        self.model.train()

        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
        # ref: https://pytorch.org/docs/stable/data.html
        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        self.local_step = 0

        for data in loader:

            # update grid every 16 steps
            if (self.model.cuda_ray or self.model.taichi_ray) and self.global_step % self.opt.update_extra_interval == 0:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model.update_extra_state()

            # update grid levels mask
            if self.opt.grid_levels_mask > 0:
                if self.global_step > self.opt.grid_levels_mask_iters:
                    self.model.grid_levels_mask = 0
                else:
                    self.model.grid_levels_mask = self.opt.grid_levels_mask

            self.local_step += 1
            self.global_step += 1

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                loss, losses_dict, outputs = self.train_step(data)

            # hooked grad clipping for RGB space
            if self.opt.grad_clip_rgb >= 0:
                def _hook(grad):
                    if self.opt.fp16:
                        # correctly handle the scale
                        grad_scale = self.scaler._get_scale_async()
                        return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb)
                    else:
                        return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb)
                outputs['rgb'].register_hook(_hook)
                # if (self.global_step <= self.opt.known_iters or self.global_step % self.opt.known_view_interval == 0) and self.opt.image is not None and self.opt.joint_known_unknown and known_rgbs is not None:
                #     known_rgbs.register_hook(_hook)
                # pred_rgbs.retain_grad()

            self.scaler.scale(loss).backward()

            self.post_train_step()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            self.loss_meter.update(losses_dict)
            if self.local_rank == 0:
                # if self.report_metric_at_train:
                #     for metric in self.metrics:
                #         metric.update(preds, truths)

                if self.use_tensorboard:
                    for key, val in losses_dict.items():
                        self.writer.add_scalar(f"train/{key}", val, self.global_step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)

                if self.global_step % self.opt.log_every == 0:
                    strings = f"==> Train [Step] {self.global_step}/{self.opt.iters}"
                    for key, value in losses_dict.items():
                        strings += f", {key}={value:.4f}"
                    logger.info(strings)
                    strings = f"==> Train [Avg] {self.global_step}/{self.opt.iters}"
                    for key in self.loss_meter.meters.keys():
                        strings += f", {key}={self.loss_meter.meters[key].avg:.4f}"
                    logger.info(strings)

        if self.ema is not None:
            self.ema.update()

        average_loss = self.loss_meter.meters['loss'].avg
        self.stats["loss"].append(average_loss)

        if self.local_rank == 0:
            # pbar.close()
            if self.report_metric_at_train:
                for metric in self.metrics:
                    logger.info(metric.report())
                    if self.use_tensorboard:
                        metric.write(self.writer, self.epoch, prefix="train")
                    metric.clear()

        if not self.scheduler_update_every_step:
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        # visualize training
        if self.local_rank == 0:
            # save image
            save_path = os.path.join(self.workspace, 'training')
            os.makedirs(save_path, exist_ok=True)
            name = f'train_{self.name}_ep{self.epoch:04d}'
            for key, value in outputs.items():
                if value is not None:
                    save_tensor2image(value, os.path.join(save_path, f'{name}_{key}.jpg'), channel_last=False)
        gpu_mem = get_GPU_mem()[0]
        logger.info(f"==> Finished Epoch {self.epoch}/{max_epochs}. GPU={gpu_mem:.1f}GB.")

    def evaluate_one_epoch(self, loader, name=None):
        logger.info(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        if self.local_rank == 0:
            pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

        with torch.no_grad():
            self.local_step = 0

            all_outputs = {}
            for data in loader:
                self.local_step += 1

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    outputs, loss = self.eval_step(data)

                # all_gather/reduce the statistics (NCCL only support all_*)
                if self.world_size > 1:
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    loss = loss / self.world_size

                    for key, value in outputs.items():
                        if value is not None:
                            gathered = [torch.zeros_like(value) for _ in range(self.world_size)]
                            dist.all_gather(gathered, value)
                            outputs[key] = torch.cat(gathered, dim=0)

                loss_val = loss.item()
                total_loss += loss_val

                # only rank = 0 will perform evaluation.
                if self.local_rank == 0:

                    # save image
                    save_path = os.path.join(self.workspace, 'validation')

                    # logger.info(f"==> Saving validation image to {save_path}")
                    os.makedirs(save_path, exist_ok=True)

                    for key, value in outputs.items():
                        if value is not None:
                            value = ((value - value.min()) / (value.max() - value.min() + 1e-6)).detach().mul(255).to(torch.uint8)
                            # save_tensor2image(value, os.path.join(save_path, f'{name}_{self.local_step:04d}_{key}.jpg'))
                            if key not in all_outputs.keys():
                                all_outputs[key] = []
                            all_outputs[key].append(value)

                    pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
                    pbar.update(loader.batch_size)

        average_loss = total_loss / self.local_step
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if not self.use_loss_as_metric and len(self.metrics) > 0:
                result = self.metrics[0].measure()
                self.stats["results"].append(result if self.best_mode == 'min' else - result)  # if max mode, use -result
            else:
                self.stats["results"].append(average_loss)  # if no metric, choose best by min loss

            for metric in self.metrics:
                logger.info(metric.report())
                if self.use_tensorboard:
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()

            for key, value in all_outputs.items():
                all_outputs[key] = torch.cat(value, dim=0)
                save_tensor2image(all_outputs[key], os.path.join(save_path, f'{name}_{key}.jpg'), channel_last=True)
        if self.ema is not None:
            self.ema.restore()

        logger.info(f"++> Evaluate epoch {self.epoch} Finished.")

    def save_checkpoint(self, name=None, full=False, best=False):

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        state = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'stats': self.stats,
        }

        if self.model.cuda_ray:
            state['mean_density'] = self.model.mean_density

        if self.opt.dmtet:
            state['tet_scale'] = self.model.dmtet.tet_scale.cpu().numpy()

        if full:
            state['optimizer'] = self.optimizer.state_dict()
            state['lr_scheduler'] = self.lr_scheduler.state_dict()
            state['scaler'] = self.scaler.state_dict()
            if self.ema is not None:
                state['ema'] = self.ema.state_dict()

        if not best:

            state['model'] = self.model.state_dict()

            file_path = f"{name}.pth"

            self.stats["checkpoints"].append(file_path)

            if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
                old_ckpt = os.path.join(self.opt.ckpt_path, self.stats["checkpoints"].pop(0))
                if os.path.exists(old_ckpt):
                    os.remove(old_ckpt)

            torch.save(state, os.path.join(self.opt.ckpt_path, file_path))

        else:
            if len(self.stats["results"]) > 0:
                # always save best since loss cannot reflect performance.
                # logger.info(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
                # self.stats["best_result"] = self.stats["results"][-1]

                # save ema results
                if self.ema is not None:
                    self.ema.store()
                    self.ema.copy_to()

                state['model'] = self.model.state_dict()

                if self.ema is not None:
                    self.ema.restore()

                torch.save(state, self.opt.best_path)
            else:
                logger.info("[WARN] no evaluated results found, skip saving best checkpoint.")

    def load_checkpoint(self, checkpoint=None, model_only=False):
        if checkpoint is None:
            checkpoint_list = sorted(glob.glob(f'{self.opt.ckpt_path}/*.pth'))
            if checkpoint_list:
                checkpoint = checkpoint_list[-1]
                logger.info(f"[INFO] Latest checkpoint is {checkpoint}")
            else:
                logger.info("[WARN] No checkpoint found, model randomly initialized.")
                return

        checkpoint_dict = torch.load(checkpoint, map_location=self.device)

        if 'model' not in checkpoint_dict:
            self.model.load_state_dict(checkpoint_dict)
            logger.info("[INFO] loaded model.")
            return

        missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
        logger.info("[INFO] loaded model.")
        if len(missing_keys) > 0:
            logger.info(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            logger.info(f"[WARN] unexpected keys: {unexpected_keys}")

        if self.ema is not None and 'ema' in checkpoint_dict:
            try:
                self.ema.load_state_dict(checkpoint_dict['ema'])
                logger.info("[INFO] loaded EMA.")
            except Exception:
                logger.info("[WARN] failed to load EMA.")

        if self.model.cuda_ray:
            if 'mean_density' in checkpoint_dict:
                self.model.mean_density = checkpoint_dict['mean_density']

        if self.opt.dmtet:
            if 'tet_scale' in checkpoint_dict:
                new_scale = torch.from_numpy(checkpoint_dict['tet_scale']).to(self.device)
                self.model.dmtet.verts *= new_scale / self.model.dmtet.tet_scale
                self.model.dmtet.tet_scale = new_scale
            # self.model.init_tet()

        if model_only:
            return

        self.stats = checkpoint_dict['stats']
        self.epoch = checkpoint_dict['epoch']
        self.global_step = checkpoint_dict['global_step']
        logger.info(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

        if self.optimizer and 'optimizer' in checkpoint_dict:
            try:
                self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
                logger.info("[INFO] loaded optimizer.")
            except Exception:
                logger.info("[WARN] Failed to load optimizer.")

        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
            try:
                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
                logger.info("[INFO] loaded scheduler.")
            except Exception:
                logger.info("[WARN] Failed to load scheduler.")

        if self.scaler and 'scaler' in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict['scaler'])
                logger.info("[INFO] loaded scaler.")
            except Exception:
                logger.info("[WARN] Failed to load scaler.")

def get_CPU_mem():
    return psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3


def get_GPU_mem():
    num = torch.cuda.device_count()
    mem, mems = 0, []
    for i in range(num):
        mem_free, mem_total = torch.cuda.mem_get_info(i)
        mems.append(int(((mem_total - mem_free) / 1024 ** 3) * 1000) / 1000)
        mem += mems[-1]
    return mem, mems
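
# Usage sketch (illustrative):
#   total_gb, per_gpu_gb = get_GPU_mem()
#   logger.info(f"CPU={get_CPU_mem():.2f}GB | GPU={total_gb:.2f}GB across {len(per_gpu_gb)} device(s)")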