import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp, biased_softplus
from .renderer import NeRFRenderer, MLP
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
from tqdm import tqdm
import logging

logger = logging.getLogger(__name__)


class NeRFNetwork(NeRFRenderer):
    """Grid-encoded NeRF: a feature grid followed by a small MLP for density and albedo.

    When ``opt.bg_radius > 0`` the network also owns a small background MLP that maps
    a frequency-encoded direction to an RGB colour.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 level_dim=2,
                 ):
        """Build the grid encoder, the density/albedo MLP and the optional background MLP.

        Every size argument is only a fallback: if ``opt`` has an attribute with the same
        name, the ``opt`` value is used instead.

        Raises:
            ValueError: if ``opt.grid_type`` is neither ``'hashgrid'`` nor ``'tilegrid'``.
        """
        super().__init__(opt)

        self.num_layers = getattr(opt, 'num_layers', num_layers)
        self.hidden_dim = getattr(opt, 'hidden_dim', hidden_dim)
        self.level_dim = getattr(opt, 'level_dim', level_dim)
        num_layers_bg = getattr(opt, 'num_layers_bg', num_layers_bg)
        hidden_dim_bg = getattr(opt, 'hidden_dim_bg', hidden_dim_bg)

        # level_dim is passed to both grid types so that the coarse-to-fine feature
        # mask in common_forward (which strides by self.level_dim) matches the encoder.
        if self.opt.grid_type == 'hashgrid':
            self.encoder, self.in_dim = get_encoder(
                'hashgrid',
                input_dim=3,
                level_dim=self.level_dim,
                log2_hashmap_size=19,
                desired_resolution=2048 * self.bound,
                interpolation='smoothstep',
            )
        elif self.opt.grid_type == 'tilegrid':
            self.encoder, self.in_dim = get_encoder(
                'tiledgrid',
                input_dim=3,
                level_dim=self.level_dim,
                log2_hashmap_size=16,
                num_levels=16,
                desired_resolution=2048 * self.bound,
            )
        else:
            raise ValueError(
                f"Unsupported grid_type {self.opt.grid_type!r}; "
                f"expected 'hashgrid' or 'tilegrid'"
            )

        # Output channel 0 is raw density, channels 1..3 are albedo (see common_forward).
        # Use the opt-resolved sizes, not the raw constructor fallbacks.
        self.sigma_net = MLP(self.in_dim, 4, self.hidden_dim, self.num_layers, bias=True)

        # Number of finest grid levels whose features are zeroed (0 disables masking).
        self.grid_levels_mask = 0

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def common_forward(self, x):
        """Evaluate density and albedo at positions ``x``.

        Args:
            x: query positions, expected in ``[-bound, bound]`` (presumably shape [N, 3]).

        Returns:
            (sigma, albedo): ``sigma`` is ``density_activation`` applied to MLP channel 0
            plus ``density_blob(x)``; ``albedo`` is the sigmoid of MLP channels 1..3.
        """
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Coarse-to-fine training: zero the last `grid_levels_mask * level_dim` feature
        # channels. Assumes the encoder emits features level by level, coarse to fine.
        if self.grid_levels_mask > 0:
            keep = self.in_dim - self.grid_levels_mask * self.level_dim
            h_mask = torch.arange(self.in_dim, device=h.device) < keep  # (in_dim,)
            h = h * h_mask.reshape(1, self.in_dim).float()  # (N, in_dim)

        h = self.sigma_net(h)

        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        """Evaluate density and shaded colour at ``x``.

        Args:
            x: positions in ``[-bound, bound]`` (presumably [N, 3]).
            d: view directions; accepted for interface compatibility, unused here.
            l: light direction, multiplied against the normals; required for
                ``'textureless'`` and ``'lambertian'`` shading.
            ratio: ambient ratio; 1 means albedo only, 0 means diffuse shading only.
            shading: ``'albedo'``, ``'normal'``, ``'textureless'``; any other value is
                treated as ``'lambertian'``.

        Returns:
            (sigma, color, normal); ``normal`` is None for ``'albedo'`` shading.
        """
        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            return sigma, albedo, None

        # self.normal is not defined in this class (presumably provided by NeRFRenderer).
        normal = self.normal(x)
        if shading == 'normal':
            return sigma, (normal + 1) / 2, normal

        lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]
        if shading == 'textureless':
            color = lambertian.unsqueeze(-1).repeat(1, 3)
        else:  # 'lambertian'
            color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Return ``{'sigma': ..., 'albedo': ...}`` evaluated at positions ``x``."""
        sigma, albedo = self.common_forward(x)
        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Return background RGB in ``[0, 1]`` for directions ``d``.

        Only valid when the background network exists (``opt.bg_radius > 0``).
        """
        h = self.encoder_bg(d)  # [N, C]
        h = self.bg_net(h)
        return torch.sigmoid(h)

    # optimizer utils
    def get_params(self, lr):
        """Build optimizer parameter groups with per-module learning rates.

        The grid encoder uses ``10 * lr`` and the sigma MLP uses
        ``lr * opt.lr_scale_nerf``. The background MLP and the DMTet module (if
        enabled) use ``lr``.
        """
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
        ]
        if self.opt.bg_radius > 0:
            # encoder_bg is deliberately not listed (a frequency encoding presumably
            # has no trainable parameters).
            params.append({'params': self.bg_net.parameters(), 'lr': lr})
        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})
        return params

    def reset_sigmanet(self):
        """Re-initialise the density/albedo MLP weights."""
        self.sigma_net.reset_parameters()

    def init_nerf_from_sdf_color(self, rpst, albedo,
                                 points=None, pretrain_iters=10000, lr=0.001, rpst_type='sdf',
                                 ):
        """Fit the NeRF so that its density (and optionally albedo) matches targets at ``points``.

        Args:
            rpst: target representation at ``points``; squeezed and clamped to [0, 1].
            albedo: target albedo at ``points``, or None to fit density only.
            points: query positions passed to ``density``; required.
            pretrain_iters: number of Adam steps; the LR drops 10x at 40% and 80%.
            lr: initial Adam learning rate.
            rpst_type: ``'sdf'`` compares ``sigma - density_thresh`` with the target;
                any other value compares ``sigma`` directly.

        Raises:
            ValueError: if ``points`` is None.
        """
        if points is None:
            raise ValueError("init_nerf_from_sdf_color requires `points` to evaluate the density at")

        self.reset_sigmanet()

        # matching optimization
        self.train()
        self.grid_levels_mask = 0

        loss_fn = torch.nn.MSELoss()
        # NOTE(review): optimises every parameter of the module, not only encoder/sigma_net.
        optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
        milestones = [int(0.4 * pretrain_iters), int(0.8 * pretrain_iters)]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

        rpst = rpst.squeeze().clamp(0, 1)

        pbar = tqdm(range(pretrain_iters), desc="NeRF sigma optimization")
        rgb_loss = torch.tensor(0.0, device=rpst.device)
        pred_rpst = None
        for _ in pbar:
            output = self.density(points)
            if rpst_type == 'sdf':
                pred_rpst = output['sigma'] - self.density_thresh
            else:
                pred_rpst = output['sigma']
            sdf_loss = loss_fn(pred_rpst, rpst)
            if albedo is not None:
                rgb_loss = loss_fn(output['albedo'], albedo)
                loss = 10 * sdf_loss + rgb_loss
            else:
                loss = sdf_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.set_postfix(loss=loss.item(), rgb_loss=rgb_loss.item(), sdf_loss=sdf_loss.item())

        # Fraction of positive-target points the last prediction also marks positive.
        if pred_rpst is not None:
            with torch.no_grad():
                positive = rpst > 0
                num_positive = int(positive.sum().item())
                if num_positive > 0:
                    accuracy = (pred_rpst.detach()[positive] > 0).sum().item() / num_positive
                else:
                    accuracy = float('nan')
            logger.info('lr: %s Accuracy: %.4f', lr, accuracy)

    def init_tet_from_sdf_color(self, sdf, colors=None, pretrain_iters=5000, lr=0.01,
                                color_iters=5000, color_lr=0.01):
        """Fit the DMTet geometry to ``sdf`` and, optionally, the albedo to ``colors``.

        Args:
            sdf: target SDF forwarded to ``dmtet.init_tet_from_sdf``.
            colors: target albedo at ``dmtet.verts``; None skips the colour fit.
            pretrain_iters: iterations for the DMTet SDF fit.
            lr: learning rate for the DMTet SDF fit.
            color_iters: Adam iterations for the colour fit (previously hard-coded 5000).
            color_lr: Adam learning rate for the colour fit (previously hard-coded 0.01).
        """
        self.train()
        self.grid_levels_mask = 0
        self.dmtet.reset_tet(reset_scale=False)
        self.dmtet.init_tet_from_sdf(sdf, pretrain_iters=pretrain_iters, lr=lr)

        if colors is None:
            return

        self.reset_sigmanet()
        loss_fn = torch.nn.MSELoss()
        # NOTE(review): optimises all module parameters; if dmtet.verts depends on
        # trainable DMTet parameters, the colour loss may also move the geometry.
        optimizer = torch.optim.Adam(list(self.parameters()), lr=color_lr)
        pbar = tqdm(range(color_iters), desc="NeRF color optimization")
        for _ in pbar:
            pred_albedo = self.density(self.dmtet.verts)['albedo']
            loss = loss_fn(pred_albedo, colors)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())