Files
Magic123/nerf/network_grid.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

217 lines
7.6 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp, biased_softplus
from .renderer import NeRFRenderer, MLP
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)
class NeRFNetwork(NeRFRenderer):
def __init__(self,
             opt,
             num_layers=3,
             hidden_dim=64,
             num_layers_bg=2,
             hidden_dim_bg=32,
             level_dim=2
             ):
    """Grid-encoded NeRF network.

    opt: experiment options object; attributes present on opt override the
         constructor defaults (num_layers, hidden_dim, level_dim, *_bg).
    num_layers / hidden_dim: sigma MLP depth / width.
    num_layers_bg / hidden_dim_bg: background MLP depth / width.
    level_dim: feature channels per grid level (used by the tiled grid).
    """
    super().__init__(opt)

    # opt-provided values take precedence over the constructor defaults.
    self.num_layers = getattr(opt, 'num_layers', num_layers)
    self.hidden_dim = getattr(opt, 'hidden_dim', hidden_dim)
    self.level_dim = getattr(opt, 'level_dim', level_dim)
    num_layers_bg = getattr(opt, 'num_layers_bg', num_layers_bg)
    hidden_dim_bg = getattr(opt, 'hidden_dim_bg', hidden_dim_bg)

    if self.opt.grid_type == 'hashgrid':
        self.encoder, self.in_dim = get_encoder(
            'hashgrid', input_dim=3, log2_hashmap_size=19,
            desired_resolution=2048 * self.bound, interpolation='smoothstep')
    elif self.opt.grid_type == 'tilegrid':
        self.encoder, self.in_dim = get_encoder(
            'tiledgrid',
            input_dim=3,
            level_dim=self.level_dim,
            log2_hashmap_size=16,
            num_levels=16,
            desired_resolution=2048 * self.bound,
        )
    else:
        # Fail fast instead of hitting an AttributeError on self.encoder later.
        raise ValueError(f"unsupported grid_type: {self.opt.grid_type}")

    # BUG FIX: the original passed the raw constructor defaults (hidden_dim,
    # num_layers) here, silently ignoring the opt.hidden_dim / opt.num_layers
    # overrides stored above. Outputs: 1 sigma channel + 3 albedo channels.
    self.sigma_net = MLP(self.in_dim, 4, self.hidden_dim, self.num_layers, bias=True)

    # masking: number of fine grid levels zeroed out for coarse-to-fine training
    self.grid_levels_mask = 0

    # background network
    if self.opt.bg_radius > 0:
        self.num_layers_bg = num_layers_bg
        self.hidden_dim_bg = hidden_dim_bg
        # use a very simple network to avoid it learning the prompt...
        self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
        self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
    else:
        self.bg_net = None
def common_forward(self, x):
    """Encode positions x and run the sigma MLP.

    Returns (sigma, albedo): density after density_activation (with the
    density blob added to the raw logit) and sigmoid-activated RGB albedo.
    """
    feats = self.encoder(x, bound=self.bound, max_level=self.max_level)

    # Coarse-to-fine training: zero the features of the finest
    # `grid_levels_mask` grid levels (each level contributes level_dim dims).
    if self.grid_levels_mask > 0:
        num_kept = self.in_dim - self.grid_levels_mask * self.level_dim
        keep = torch.arange(self.in_dim, device=feats.device) < num_kept  # (in_dim,)
        feats = feats * keep.reshape(1, self.in_dim).float()  # (N, in_dim)

    out = self.sigma_net(feats)
    sigma = self.density_activation(out[..., 0] + self.density_blob(x))
    albedo = torch.sigmoid(out[..., 1:])
    return sigma, albedo
def forward(self, x, d, l=None, ratio=1, shading='albedo'):
    """Query sigma and shaded color at positions x.

    x: [N, 3], in [-bound, bound]
    d: [N, 3], view direction, normalized in [-1, 1] (unused by this model)
    l: [3], point light direction, normalized in [-1, 1]
    ratio: scalar ambient ratio; 1 == no shading (albedo only),
           0 == only shading (textureless)
    shading: 'albedo' | 'normal' | 'textureless' | 'lambertian'

    Returns (sigma, color, normal); normal is None in 'albedo' mode.
    """
    sigma, albedo = self.common_forward(x)

    # Pure-albedo mode needs no normals at all.
    if shading == 'albedo':
        return sigma, albedo, None

    normal = self.normal(x)
    if shading == 'normal':
        # Map normals from [-1, 1] into displayable [0, 1] colors.
        color = (normal + 1) / 2
    else:
        # Lambertian term blended with the ambient ratio, clamped to >= 0.
        lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]
        if shading == 'textureless':
            color = lambertian.unsqueeze(-1).repeat(1, 3)
        else:  # 'lambertian'
            color = albedo * lambertian.unsqueeze(-1)
    return sigma, color, normal
def density(self, x):
    """Density/albedo query without shading.

    x: [N, 3], in [-bound, bound]
    Returns a dict with 'sigma' and 'albedo'.
    """
    sigma, albedo = self.common_forward(x)
    return dict(sigma=sigma, albedo=albedo)
def background(self, d):
    """Background color for view directions d.

    d: [N, 3] directions fed to the frequency encoder.
    Returns sigmoid-activated RGB in (0, 1), shape [N, 3].
    """
    encoded = self.encoder_bg(d)  # [N, C]
    rgbs = torch.sigmoid(self.bg_net(encoded))
    return rgbs
# optimizer utils
def get_params(self, lr):
    """Build optimizer parameter groups with per-module learning rates.

    lr: base learning rate; the grid encoder trains 10x faster, the sigma
    net is scaled by opt.lr_scale_nerf, background/dmtet use lr as-is.
    """
    groups = [
        {'params': self.encoder.parameters(), 'lr': lr * 10},
        {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
    ]
    if self.opt.bg_radius > 0:
        groups.append({'params': self.bg_net.parameters(), 'lr': lr})
    if self.opt.dmtet:
        groups.append({'params': self.dmtet.parameters(), 'lr': lr})
    return groups
def reset_sigmanet(self):
    """Re-initialize the sigma/albedo MLP weights via MLP.reset_parameters."""
    self.sigma_net.reset_parameters()
def init_nerf_from_sdf_color(self, rpst, albedo,
                             points=None, pretrain_iters=10000, lr=0.001, rpst_type='sdf',
                             ):
    """Pretrain the sigma/albedo MLP to match a target field sampled at `points`.

    rpst: target representation per point, squeezed and clamped to [0, 1];
          interpreted as an SDF-like zero-level-set target when
          rpst_type == 'sdf', otherwise as raw sigma targets.
    albedo: optional per-point RGB targets; when given, a color MSE term is
            added to the loss.
    points: [N, 3] query positions fed to self.density.
    pretrain_iters / lr: Adam optimization budget; lr decays 10x at 40% and
                         80% of the iterations.
    rpst_type: 'sdf' or anything else (raw sigma matching).
    """
    self.reset_sigmanet()
    # matching optimization
    self.train()
    self.grid_levels_mask = 0  # use all grid levels during pretraining
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
    milestones = [int(0.4 * pretrain_iters), int(0.8 * pretrain_iters)]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    rpst = rpst.squeeze().clamp(0, 1)
    pbar = tqdm(range(pretrain_iters), desc="NeRF sigma optimization")
    rgb_loss = torch.tensor(0, device=rpst.device)
    pred_rpst = None  # guards the accuracy log when pretrain_iters == 0
    for i in pbar:
        output = self.density(points)
        if rpst_type == 'sdf':
            # Shift sigma so the density threshold maps to the zero level set.
            pred_rpst = output['sigma'] - self.density_thresh
        else:
            pred_rpst = output['sigma']
        sdf_loss = loss_fn(pred_rpst, rpst)
        if albedo is not None:
            pred_albedo = output['albedo']
            rgb_loss = loss_fn(pred_albedo, albedo)
            # Weight geometry 10x higher than color.
            loss = 10 * sdf_loss + rgb_loss
        else:
            loss = sdf_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item(), rgb_loss=rgb_loss.item(), sdf_loss=sdf_loss.item())
    # BUG FIX: the original embedded the accuracy expression as literal text
    # inside the f-string (missing braces), so it was never evaluated.
    if pred_rpst is not None:
        accuracy = (pred_rpst[rpst > 0] > 0).sum() / (rpst > 0).sum()
        logger.info(f'lr: {lr} Accuracy: {accuracy}')
    # NOTE(review): a second tqdm bar ("NeRF color optimization") was created
    # here but never iterated; removed as dead code.
def init_tet_from_sdf_color(self, sdf, colors=None, pretrain_iters=5000, lr=0.01):
    """Initialize the DMTet geometry from an SDF, then optionally fit vertex colors.

    sdf: target signed-distance values handed to dmtet.init_tet_from_sdf.
    colors: optional albedo targets for self.dmtet.verts; when given, the
            sigma net is reset and its albedo output is fit by MSE.
    pretrain_iters / lr: budget for the SDF fitting stage.
    """
    self.train()
    self.grid_levels_mask = 0  # disable coarse-to-fine feature masking
    self.dmtet.reset_tet(reset_scale=False)
    self.dmtet.init_tet_from_sdf(sdf, pretrain_iters=pretrain_iters, lr=lr)
    if colors is not None:
        self.reset_sigmanet()
        loss_fn = torch.nn.MSELoss()
        # NOTE(review): the color stage ignores the method's `pretrain_iters`
        # and `lr` arguments — it hard-codes 5000 iterations and lr=0.01.
        # Confirm whether this is intentional.
        pretrain_iters = 5000
        optimizer = torch.optim.Adam(list(self.parameters()), lr=0.01)
        pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")
        for i in pbar:
            pred_albedo = self.density(self.dmtet.verts)['albedo']
            loss = loss_fn(pred_albedo, colors)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item())