1576 lines
64 KiB
Python
1576 lines
64 KiB
Python
import os
|
|
import math
|
|
import cv2
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from encoding import get_encoder
|
|
import nvdiffrast.torch as dr
|
|
|
|
import mcubes
|
|
import raymarching
|
|
from .utils import custom_meshgrid, safe_normalize
|
|
import logging
|
|
from activation import trunc_exp, biased_softplus
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with ReLU between consecutive layers.

    Args:
        dim_in: Number of input features of the first layer.
        dim_out: Number of output features of the last layer.
        dim_hidden: Width of every intermediate layer.
        num_layers: Total number of linear layers.
        bias: Whether the linear layers carry a bias term.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        # The first layer consumes dim_in, the last one emits dim_out and
        # everything in between is dim_hidden wide.
        layers = []
        for idx in range(num_layers):
            in_features = self.dim_in if idx == 0 else self.dim_hidden
            out_features = self.dim_out if idx == num_layers - 1 else self.dim_hidden
            layers.append(nn.Linear(in_features, out_features, bias=bias))

        self.net = nn.ModuleList(layers)

    def forward(self, x):
        """Apply all layers in order; no activation after the last one."""
        last = self.num_layers - 1
        for idx in range(self.num_layers):
            x = self.net[idx](x)
            if idx != last:
                x = F.relu(x, inplace=True)
        return x

    def reset_parameters(self):
        """Re-initialise every linear layer (Xavier-uniform weights, zero bias)."""

        @torch.no_grad()
        def _init_linear(module):
            # self.apply visits every sub-module; only linear layers are touched
            if not isinstance(module, nn.Linear):
                return
            nn.init.xavier_uniform_(
                module.weight, gain=nn.init.calculate_gain('relu'))
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        self.apply(_init_linear)
|
|
|
|
|
|
def sample_pdf(bins, weights, n_samples, det=False):
    """Draw new z values by inverting the piecewise-linear CDF of ``weights``.

    This implementation is from NeRF.

    Args:
        bins: ``[B, T]`` tensor, old z values.
        weights: ``[B, T - 1]`` tensor, bin weights.
        n_samples: Number of new z values to draw per row.
        det: If true use evenly spaced quantiles, otherwise uniform random
            ones.

    Returns:
        ``[B, n_samples]`` tensor of new z values.
    """
    target_device = weights.device

    # Build the CDF; the small offset prevents NaNs for all-zero rows.
    weights = weights + 1e-5
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Quantiles at which the CDF is inverted.
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        half_step = 0.5 / n_samples
        u = torch.linspace(0. + half_step, 1. - half_step,
                           steps=n_samples).to(target_device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape).to(target_device)
    u = u.contiguous()

    # Locate the CDF interval that brackets every quantile.
    upper = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(upper - 1, min=0)
    above = torch.clamp(upper, max=cdf.shape[-1] - 1)
    bracket = torch.stack([below, above], -1)  # (B, n_samples, 2)

    gather_shape = [bracket.shape[0], bracket.shape[1], cdf.shape[-1]]
    cdf_pair = torch.gather(cdf.unsqueeze(1).expand(gather_shape), 2, bracket)
    bin_pair = torch.gather(bins.unsqueeze(1).expand(gather_shape), 2, bracket)

    # Linear interpolation inside the interval; degenerate intervals get a
    # unit denominator so the division stays finite.
    width = cdf_pair[..., 1] - cdf_pair[..., 0]
    width = torch.where(width < 1e-5, torch.ones_like(width), width)
    frac = (u - cdf_pair[..., 0]) / width

    return bin_pair[..., 0] + frac * (bin_pair[..., 1] - bin_pair[..., 0])
|
|
|
|
@torch.cuda.amp.autocast(enabled=False)
def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):
    """Compute per-ray near/far distances against a bounding volume.

    Args:
        rays_o: ``[B, N, 3]`` ray origins.
        rays_d: ``[B, N, 3]`` ray directions.
        bound: Radius for ``'sphere'`` or half edge length for ``'cube'``.
        type: ``'sphere'`` or ``'cube'``.
        min_near: Lower limit applied to ``near`` in the ``'cube'`` case.

    Returns:
        Tuple ``(near, far)``, each ``[B, N, 1]``.

    Raises:
        ValueError: If ``type`` is neither ``'sphere'`` nor ``'cube'``.
            Previously this fell through to the return statement and failed
            with an unbound-local error.
    """
    if type == 'sphere':
        radius = rays_o.norm(dim=-1, keepdim=True)
        near = radius - bound  # [B, N, 1]
        far = radius + bound

    elif type == 'cube':
        # Slab test; the tiny offset avoids division by exactly zero.
        tmin = (-bound - rays_o) / (rays_d + 1e-15)  # [B, N, 3]
        tmax = (bound - rays_o) / (rays_d + 1e-15)
        near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
        far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
        # far < near means no intersection: set both to "infinity" (1e9 here)
        mask = far < near
        near[mask] = 1e9
        far[mask] = 1e9
        # restrict near to a minimal value
        near = torch.clamp(near, min=min_near)

    else:
        raise ValueError(
            f"unsupported bound type {type!r}, expected 'sphere' or 'cube'")

    return near, far
|
|
|
|
|
|
def plot_pointcloud(pc, color=None):
    """Show a point cloud in a trimesh viewer with an axis and a unit sphere.

    Args:
        pc: ``[N, 3]`` array of points.
        color: Optional ``[N, 3/4]`` per-point colours.
    """
    import trimesh

    # The values are passed as lazy %-style arguments; without placeholders
    # in the message the logging module reports a formatting error instead
    # of emitting the record.
    logger.info('[visualize points] %s %s %s %s',
                pc.shape, pc.dtype, pc.min(0), pc.max(0))

    pc = trimesh.PointCloud(pc, color)
    # axis
    axes = trimesh.creation.axis(axis_length=4)
    # sphere
    sphere = trimesh.creation.icosphere(radius=1)
    trimesh.Scene([pc, axes, sphere]).show()
|
|
|
|
|
|
class DMTet:
    """Marching-tetrahedra surface extraction driven by per-vertex SDF values.

    The lookup tables are created once on ``device``; every tensor passed to
    the methods is expected to live on that same device.
    """

    def __init__(self, device='cuda'):
        self.device = device

        # For each of the 16 occupancy patterns of a tetrahedron: indices
        # (into its six edges) of the up to two triangles to emit; -1 pads
        # unused slots.
        self.triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [1, 0, 2, -1, -1, -1],
            [4, 0, 3, -1, -1, -1],
            [1, 4, 2, 1, 3, 4],
            [3, 1, 5, -1, -1, -1],
            [2, 3, 0, 2, 5, 3],
            [1, 4, 0, 1, 5, 4],
            [4, 2, 5, -1, -1, -1],
            [4, 5, 2, -1, -1, -1],
            [4, 1, 0, 4, 5, 1],
            [3, 2, 0, 3, 5, 2],
            [1, 3, 5, -1, -1, -1],
            [4, 1, 2, 4, 3, 1],
            [3, 0, 4, -1, -1, -1],
            [2, 0, 1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device=self.device)

        # Number of triangles emitted for each occupancy pattern.
        self.num_triangles_table = torch.tensor(
            [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0],
            dtype=torch.long, device=self.device)
        # Flattened vertex pairs of the six edges of a tetrahedron.
        self.base_tet_edges = torch.tensor(
            [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
            dtype=torch.long, device=self.device)

    ###########################################################################
    # Utility functions
    ###########################################################################

    def sort_edges(self, edges_ex2):
        """Order each edge so that the smaller vertex index comes first.

        ``edges_ex2`` is ``[E, 2]``; the result is the stack of the two
        gathered ``[E, 1]`` columns along a new last axis.
        """
        with torch.no_grad():
            # 1 where the pair has to be swapped, 0 otherwise
            swap = (edges_ex2[:, 0] > edges_ex2[:, 1]).long().unsqueeze(dim=1)

            smaller = torch.gather(input=edges_ex2, index=swap, dim=1)
            larger = torch.gather(input=edges_ex2, index=1 - swap, dim=1)

        return torch.stack([smaller, larger], -1)

    def map_uv(self, faces, face_gidx, max_idx):
        """Give every face its own corner of a regular texture atlas.

        Args:
            faces: Not used by the computation.
            face_gidx: Global face index of every emitted triangle.
            max_idx: Upper bound of the global face indices.

        Returns:
            ``(uvs, uv_idx)``: the atlas coordinates (four per cell) and, for
            each face, three indices into them.
        """
        N = int(np.ceil(np.sqrt((max_idx + 1) // 2)))
        cell_start = 1 - (1 / N)
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, cell_start, N,
                           dtype=torch.float32, device=self.device),
            torch.linspace(0, cell_start, N,
                           dtype=torch.float32, device=self.device),
        )  # indexing='ij')

        # each cell only uses 90% of its extent
        pad = 0.9 / N

        right = tex_x + pad
        bottom = tex_y + pad
        uvs = torch.stack([
            tex_x, tex_y,
            right, tex_y,
            right, bottom,
            tex_x, bottom
        ], dim=-1).view(-1, 2)

        # Two faces share one atlas cell.
        cell = torch.div(face_gidx, 2, rounding_mode='trunc')
        col = cell % N
        row = torch.div(cell, N, rounding_mode='trunc')
        tet_idx = row * N + col
        tri_idx = face_gidx % 2

        first_corner = tet_idx * 4
        uv_idx = torch.stack((
            first_corner,
            first_corner + tri_idx + 1,
            first_corner + tri_idx + 2
        ), dim=-1).view(-1, 3)

        return uvs, uv_idx

    ###########################################################################
    # Marching tets implementation
    ###########################################################################

    def __call__(self, pos_nx3, sdf_n, tet_fx4, return_uv=True):
        """Extract the zero level set as a triangle mesh.

        Args:
            pos_nx3: ``[N, 3]`` vertex positions.
            sdf_n: ``[N]`` SDF value per vertex; values > 0 count as occupied.
            tet_fx4: ``[F, 4]`` vertex indices of every tetrahedron.
            return_uv: Whether to also compute atlas coordinates.

        Returns:
            ``(verts, faces, uvs, uv_idx)``; the last two are ``None`` when
            ``return_uv`` is false.
        """
        # Pure index bookkeeping, no gradients needed.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)  # [F,]

            # a valid tet is neither all positive nor all negative
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(
                all_edges, dim=0, return_inverse=True)

            # find out the edges across the surface to interpolate and refine
            unique_edges = unique_edges.long()
            edge_occ = occ_n[unique_edges.reshape(-1)].reshape(-1, 2)
            mask_edges = edge_occ.sum(-1) == 1
            mapping = torch.ones(
                (unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=self.device)
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Everything below depends on positions / SDF values and therefore
        # stays outside of the no_grad region.
        endpoint_ids = interp_v.reshape(-1)
        edges_to_interp = pos_nx3[endpoint_ids].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[endpoint_ids].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1
        denominator = edges_to_interp_sdf.sum(1, keepdim=True)
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator

        # interpolate edges by sdf
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # Encode the four occupancy bits of every valid tet as a table index.
        v_id = torch.pow(2, torch.arange(
            4, dtype=torch.long, device=self.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]
        one_tri = num_triangles == 1
        two_tri = num_triangles == 2

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(
                input=idx_map[one_tri], dim=1,
                index=self.triangle_table[tetindex[one_tri]][:, :3]
            ).reshape(-1, 3),
            torch.gather(
                input=idx_map[two_tri], dim=1,
                index=self.triangle_table[tetindex[two_tri]][:, :6]
            ).reshape(-1, 3),
        ), dim=0)

        if not return_uv:
            return verts, faces, None, None

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(
            num_tets, dtype=torch.long, device=self.device)[valid_tets]
        face_gidx = torch.cat((
            tet_gidx[one_tri] * 2,
            torch.stack(
                (tet_gidx[two_tri] * 2, tet_gidx[two_tri] * 2 + 1),
                dim=-1).view(-1)
        ), dim=0)

        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets * 2)
        return verts, faces, uvs, uv_idx
|
|
|
|
###############################################################################
|
|
# Regularizer
|
|
###############################################################################
|
|
|
|
|
|
def sdf_reg_loss(sdf, all_edges):
    """Cross-entropy regulariser over edges whose endpoints differ in sign.

    Args:
        sdf: Per-vertex SDF values.
        all_edges: Vertex index pairs; reshaped to ``[-1, 2]`` after lookup.

    Returns:
        Sum of two binary cross-entropy terms, each using one endpoint as
        logit and the sign of the other endpoint as target.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits

    pairs = sdf[all_edges.reshape(-1)].reshape(-1, 2)
    # keep only the edges that cross the surface
    crossing = torch.sign(pairs[..., 0]) != torch.sign(pairs[..., 1])
    pairs = pairs[crossing]

    head, tail = pairs[..., 0], pairs[..., 1]
    head_term = bce(head, (tail > 0).float())
    tail_term = bce(tail, (head > 0).float())
    return head_term + tail_term
|
|
|
|
###############################################################################
|
|
# Geometry interface
|
|
###############################################################################
|
|
|
|
|
|
class DMTetGeometry(torch.nn.Module):
    """SDF and vertex deformation defined on a pre-computed tetrahedral grid.

    The SDF / deformation is either predicted by a small MLP on top of a
    hash-grid encoder (``tet_mlp`` true) or stored directly as per-vertex
    parameters (``tet_mlp`` false).

    Args:
        grid_res: Resolution of the tetrahedral grid; selects the file
            ``data/tets/{grid_res}_tets.npz``.
        tet_mlp: Whether to predict SDF / deformation with an MLP.
        opt: Options object; ``opt.base_mesh`` is read here.
        device: Device on which all grid tensors are created.
    """

    def __init__(self, grid_res, tet_mlp, opt, device='cuda'):
        super(DMTetGeometry, self).__init__()

        self.opt = opt
        self.device = device
        self.tet_scale = torch.ones(3, device=device)
        self.grid_res = grid_res
        # Build the marching-tets tables on the requested device; relying on
        # the DMTet default ('cuda') breaks any other choice of ``device``.
        self.marching_tets = DMTet(device=device)

        # np.load keeps the archive open; close it as soon as the arrays
        # have been copied into tensors.
        with np.load('data/tets/{}_tets.npz'.format(self.grid_res)) as tets:
            # for 64/128, [N=36562/277410, 3], in [-0.5, 0.5]^3
            self.verts = torch.tensor(
                tets['vertices'], dtype=torch.float32, device=self.device) * 2
            # for 64/128, [M=192492/1524684, 4], vert indices for each tetrahedron
            self.indices = torch.tensor(
                tets['indices'], dtype=torch.long, device=self.device)
        self.generate_edges()

        self.tet_mlp = tet_mlp
        if tet_mlp:
            self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3)
            self.encoder = self.encoder.to(device)
            # 4 outputs: 1 SDF value + 3 deformation components
            self.mlp = MLP(self.in_dim, 4, 32, 3, False).to(device)
            self.sdf = None
        else:
            sdf = torch.nn.Parameter(
                torch.zeros_like(self.verts[..., 0]), requires_grad=True)
            self.register_parameter('sdf', sdf)
            deform = torch.nn.Parameter(
                torch.zeros_like(self.verts), requires_grad=True)
            self.register_parameter('deform', deform)

        if opt.base_mesh and os.path.exists(opt.base_mesh):
            self.init_tet_from_mesh(opt.base_mesh)

    def reset_tet(self, reset_scale=True):
        """Reset the learned SDF / deformation, optionally the scale too."""
        if self.tet_mlp:
            self.mlp.reset_parameters()
        else:
            self.sdf.data = torch.zeros_like(self.verts[..., 0])
            self.deform.data = torch.zeros_like(self.verts)
        if reset_scale:
            self.reset_tet_scale()

    def get_sdf_from_mesh(self, base_mesh):
        """Return the SDF of ``base_mesh`` sampled at the grid vertices.

        The mesh is centred and rescaled so that its largest extent is 1.5.
        The sign is flipped so that the interior is positive.
        """
        logger.info(f'[INFO] init sdf from base mesh: {base_mesh}')

        import cubvh
        import trimesh
        mesh = trimesh.load(base_mesh, force='mesh')

        scale = 1.5 / np.array(mesh.bounds[1] - mesh.bounds[0]).max()
        center = np.array(mesh.bounds[1] + mesh.bounds[0]) / 2
        mesh.vertices = (mesh.vertices - center) * scale

        # build with numpy.ndarray/torch.Tensor
        BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)
        sdf, face_id, _ = BVH.signed_distance(
            self.verts, return_uvw=False, mode='watertight')
        sdf *= -1  # INNER is POSITIVE
        return sdf

    def init_tet_from_mesh(self, base_mesh):
        """Initialise the SDF from the mesh file at ``base_mesh``."""
        sdf = self.get_sdf_from_mesh(base_mesh)
        self.init_tet_from_sdf(sdf)

    def init_tet_from_sdf(self, sdf, pretrain_iters=5000, lr=1e-3):
        """Make the geometry reproduce the given per-vertex ``sdf``.

        With an MLP the network is reset and fitted to ``sdf`` with Adam and
        an MSE loss for ``pretrain_iters`` steps; otherwise the values are
        copied into the ``sdf`` parameter.
        """
        if self.tet_mlp:
            self.mlp.reset_parameters()
            # pretraining
            loss_fn = torch.nn.MSELoss()
            optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)

            # every step draws as many (random, with replacement) vertices
            # as the grid has
            batch_size = self.verts.shape[0]
            pbar = tqdm(range(pretrain_iters), desc="init dmtet mlp from sdf")
            for i in pbar:
                rand_idx = torch.randint(0, self.verts.shape[0], (batch_size,))
                p = self.verts[rand_idx]
                ref_value = sdf[rand_idx]
                output = self.mlp(self.encoder(p))
                loss = loss_fn(output[..., 0], ref_value)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=loss.item())
        else:
            self.sdf.data = sdf.squeeze()

    @torch.no_grad()
    def reset_tet_scale(self, tet_scale=1.):
        """Store ``tet_scale`` and multiply the grid vertices by it.

        A float is expanded to a 3-vector; any other value is used as given.
        NOTE(review): the factor is applied to the current vertices, so
        repeated calls compound - confirm this is intended.
        """
        if isinstance(tet_scale, float):
            tet_scale = torch.ones(3, device=self.device) * tet_scale
        self.tet_scale = tet_scale
        self.verts = self.verts * tet_scale

    @torch.no_grad()
    def generate_edges(self):
        """Collect the unique, index-sorted edges of all tetrahedra."""
        # six edges for each tetrahedron.
        edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
                             dtype=torch.long, device=self.device)
        all_edges = self.indices[:, edges].reshape(-1, 2)  # [M * 6, 2]
        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
        self.all_edges = torch.unique(all_edges_sorted, dim=0)

    def get_sdf_deform(self):
        """Return ``(sdf, deform)``; deform is ``tanh(raw) / grid_res``."""
        if self.tet_mlp:
            # predict SDF and per-vertex deformation
            pred = self.mlp(self.encoder(self.verts))
            sdf, deform = pred[:, 0], pred[:, 1:]
            return sdf, torch.tanh(deform) / (self.grid_res)
        else:
            return self.sdf, torch.tanh(self.deform) / (self.grid_res)

    def get_verts_face(self):
        """Run marching tets on the deformed grid; return ``(verts, faces)``."""
        sdf, deform = self.get_sdf_deform()
        verts, faces, _, _ = self.marching_tets(
            self.verts + deform, sdf, self.indices, return_uv=False)
        return verts, faces
|
|
|
|
# def getAABB(self):
|
|
# return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
|
|
|
# def getMesh(self, material):
|
|
|
|
# pred = self.mlp(self.encoder(self.verts)) # predict SDF and per-vertex deformation
|
|
# sdf, deform = pred[:, 0], pred[:, 1:]
|
|
|
|
# v_deformed = self.verts + torch.tanh(deform) / (self.grid_res)
|
|
|
|
# verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, sdf, self.indices)
|
|
|
|
# imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
|
|
|
# # Run mesh operations to generate tangent space
|
|
# imesh = mesh.auto_normals(imesh)
|
|
# imesh = mesh.compute_tangents(imesh)
|
|
|
|
# return imesh, sdf
|
|
|
|
# def render(self, glctx, target, lgt, opt_material, bsdf=None):
|
|
|
|
# # return rendered buffers, keys: ['shaded', 'kd_grad', 'occlusion'].
|
|
# opt_mesh, sdf = self.getMesh(opt_material)
|
|
# buffers = render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
|
# msaa=True, background=None, bsdf=bsdf)
|
|
# buffers['mesh'] = opt_mesh
|
|
# buffers['sdf'] = sdf
|
|
|
|
# return buffers
|
|
|
|
# def tick(self, glctx, target, lgt, opt_material, loss_fn, guidance_model, text_z, iteration):
|
|
|
|
# # ==============================================================================================
|
|
# # Render optimizable object with identical conditions
|
|
# # ==============================================================================================
|
|
# buffers = self.render(glctx, target, lgt, opt_material)
|
|
|
|
# mesh = buffers['mesh']
|
|
|
|
# # ==============================================================================================
|
|
# # Compute loss
|
|
# # ==============================================================================================
|
|
# t_iter = iteration / self.opt.iter
|
|
|
|
# if iteration < int(self.opt.iter * 0.2):
|
|
# # mode = 'normal_latent'
|
|
# pred_rgb = buffers['normal'][..., 0:4].permute(0, 3, 1, 2).contiguous()
|
|
# as_latent = True
|
|
# elif iteration < int(self.opt.iter * 0.6):
|
|
# # mode = 'normal'
|
|
# pred_rgb = buffers['normal'][..., 0:3].permute(0, 3, 1, 2).contiguous()
|
|
# as_latent = False
|
|
# else:
|
|
# # mode = 'rgb'
|
|
# pred_rgb = buffers['shaded'][..., 0:3].permute(0, 3, 1, 2).contiguous()
|
|
# pred_ws = buffers['shaded'][..., 3].unsqueeze(1) # [B, 1, H, W]
|
|
# pred_rgb = pred_rgb * pred_ws + (1 - pred_ws) * 1 # white bg
|
|
# as_latent = False
|
|
|
|
# # torch_vis_2d(pred_rgb[0])
|
|
# # torch_vis_2d(pred_normal[0])
|
|
# # torch_vis_2d(pred_ws[0])
|
|
|
|
# if self.opt.directional_text:
|
|
# all_pos = []
|
|
# all_neg = []
|
|
# for emb in text_z[target['direction']]: # list of [2, S, -1]
|
|
# pos, neg = emb.chunk(2) # [1, S, -1]
|
|
# all_pos.append(pos)
|
|
# all_neg.append(neg)
|
|
# text_embedding = torch.cat(all_pos + all_neg, dim=0) # [2b, S, -1]
|
|
# else:
|
|
# text_embedding = text_z
|
|
|
|
# img_loss = guidance_model.train_step(text_embedding, pred_rgb.half(), as_latent=as_latent)
|
|
|
|
# # img_loss = torch.tensor(0.0, device = self.device)
|
|
|
|
# # below are lots of regularizations...
|
|
# reg_loss = torch.tensor(0.0, device = self.device)
|
|
|
|
# if iteration < int(self.opt.iter * 0.6):
|
|
# # SDF regularizer
|
|
# sdf_weight = self.opt.sdf_regularizer - (self.opt.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter)
|
|
# sdf_loss = sdf_reg_loss(buffers['sdf'], self.all_edges).mean() * sdf_weight # Dropoff to 0.01
|
|
# reg_loss = reg_loss + sdf_loss
|
|
|
|
# # directly regularize mesh smoothness in finetuning...
|
|
# if iteration > int(self.opt.iter * 0.2):
|
|
# lap_loss = regularizer.laplace_regularizer_const(mesh.v_pos, mesh.t_pos_idx) * self.opt.laplace_scale #* min(1.0, iteration / 500)
|
|
# reg_loss = reg_loss + lap_loss
|
|
|
|
# # normal_loss = regularizer.normal_consistency(mesh.v_pos, mesh.t_pos_idx) * self.opt.laplace_scale * min(1.0, iteration / 500)
|
|
# # reg_loss = reg_loss + normal_loss
|
|
|
|
# else:
|
|
# # Albedo (k_d) smoothnesss regularizer
|
|
# # reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, (iteration - int(self.opt.iter * 0.6)) / 500)
|
|
|
|
# # # Visibility regularizer
|
|
# # reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, (iteration - int(self.opt.iter * 0.6)) / 500)
|
|
|
|
# # # Light white balance regularizer
|
|
# reg_loss += lgt.regularizer() * 0.005
|
|
|
|
# return img_loss, reg_loss
|
|
|
|
|
|
def compute_edge_to_face_mapping(attr_idx):
    """Build a table of the (up to) two triangles adjacent to every edge.

    Args:
        attr_idx: ``[F, 3]`` integer tensor of triangle vertex indices.

    Returns:
        ``[E, 2]`` int64 tensor on the device of ``attr_idx``. Row ``e``
        belongs to the e-th unique (index-sorted) edge. Column 0 holds a
        triangle that traverses the edge from the smaller to the larger
        index, column 1 one that traverses it the other way. Entries that
        are never written stay 0.
    """
    with torch.no_grad():
        # Allocate on the input's device instead of forcing CUDA so that the
        # function also works for CPU tensors and non-default GPUs.
        device = attr_idx.device

        # Create all edges, packed by triangle
        all_edges = torch.cat((
            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
        ), dim=-1).view(-1, 2)

        # Swap edge order so min index is always first
        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
        sorted_edges = torch.cat((
            torch.gather(all_edges, 1, order),
            torch.gather(all_edges, 1, 1 - order)
        ), dim=-1)

        # Eliminate duplicates and return inverse mapping
        unique_edges, idx_map = torch.unique(
            sorted_edges, dim=0, return_inverse=True)

        # owning triangle of each of the 3 * F packed edges
        tris = torch.arange(
            attr_idx.shape[0], device=device).repeat_interleave(3)

        tris_per_edge = torch.zeros(
            (unique_edges.shape[0], 2), dtype=torch.int64, device=device)

        # Compute edge to face table
        mask0 = order[:, 0] == 0
        mask1 = order[:, 0] == 1
        tris_per_edge[idx_map[mask0], 0] = tris[mask0]
        tris_per_edge[idx_map[mask1], 1] = tris[mask1]

        return tris_per_edge
|
|
|
|
|
|
@torch.cuda.amp.autocast(enabled=False)
def normal_consistency(face_normals, t_pos_idx):
    """Mean of ``1 - cos`` between the normals of faces sharing an edge.

    Args:
        face_normals: Per-face normals, indexed by face along dim 0.
        t_pos_idx: Triangle vertex indices handed to
            ``compute_edge_to_face_mapping``.
    """
    edge_faces = compute_edge_to_face_mapping(t_pos_idx)

    # Fetch normals for both faces sharing an edge
    normals_a = face_normals[edge_faces[:, 0], :]
    normals_b = face_normals[edge_faces[:, 1], :]

    # Compute error metric based on normal difference
    cosine = torch.sum(normals_a * normals_b, -1, keepdim=True)
    cosine = torch.clamp(cosine, min=-1.0, max=1.0)
    deviation = 1.0 - cosine

    return torch.mean(torch.abs(deviation))
|
|
|
|
|
|
def laplacian_uniform(verts, faces):
    """Build the uniform (combinatorial) graph Laplacian of a triangle mesh.

    Args:
        verts: ``[V, ...]`` tensor; only its length and device are used.
        faces: ``[F, 3]`` integer tensor of vertex indices.

    Returns:
        Coalesced sparse ``[V, V]`` float tensor with -1 for every pair of
        adjacent vertices and the number of neighbours on the diagonal.
    """
    V = verts.shape[0]

    # Neighbor indices
    ii = faces[:, [1, 2, 0]].flatten()
    jj = faces[:, [2, 0, 1]].flatten()
    # both directions of every edge, duplicates removed
    adj = torch.stack(
        [torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
    adj_values = torch.ones(
        adj.shape[1], device=verts.device, dtype=torch.float)

    # Diagonal indices
    diag_idx = adj[0]

    # Build the sparse matrix
    idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
    values = torch.cat((-adj_values, adj_values))

    # The coalesce operation sums the duplicate indices, resulting in the
    # correct diagonal
    return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
|
|
|
|
|
|
@torch.cuda.amp.autocast(enabled=False)
def laplacian_smooth_loss(verts, faces):
    """Mean norm of the uniform Laplacian applied to the vertex positions.

    The Laplacian matrix itself is built without gradient tracking; the
    product with ``verts`` is differentiable.
    """
    with torch.no_grad():
        laplacian = laplacian_uniform(verts, faces.long())
    per_vertex = laplacian.mm(verts).norm(dim=1)
    return per_vertex.mean()
|
|
|
|
|
|
class NeRFRenderer(nn.Module):
|
|
def __init__(self, opt):
    """Set up renderer state from the options object ``opt``.

    Reads ``bound``, ``dmtet``, ``cuda_ray``, ``taichi_ray``, ``min_near``,
    ``density_thresh`` and ``density_activation`` from ``opt``; with DMTet
    enabled also ``tet_grid_size``, ``tet_mlp``, ``device``, ``h`` and ``w``.
    """
    super().__init__()

    self.opt = opt
    self.bound = opt.bound
    self.cascade = 1 + math.ceil(math.log2(opt.bound))
    self.grid_size = 128
    self.max_level = None
    self.dmtet = opt.dmtet
    self.cuda_ray = opt.cuda_ray
    self.taichi_ray = opt.taichi_ray
    self.min_near = opt.min_near
    self.density_thresh = opt.density_thresh

    # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
    # NOTE: aabb (can be rectangular) is only used to generate points, we
    # still rely on bound (always cubic) to calculate density grid and hashing.
    lower, upper = -opt.bound, opt.bound
    aabb_train = torch.FloatTensor([lower, lower, lower, upper, upper, upper])
    aabb_infer = aabb_train.clone()
    self.register_buffer('aabb_train', aabb_train)
    self.register_buffer('aabb_infer', aabb_infer)

    self.glctx = None

    def _setup_density_state():
        # density grid [CAS, H * H * H] and its bitfield [CAS * H * H * H // 8]
        density_grid = torch.zeros([self.cascade, self.grid_size ** 3])
        density_bitfield = torch.zeros(
            self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)
        self.register_buffer('density_grid', density_grid)
        self.register_buffer('density_bitfield', density_bitfield)
        self.mean_density = 0
        self.iter_density = 0

    # extra state for cuda raymarching
    if self.cuda_ray:
        _setup_density_state()

    # load dmtet vertices
    if self.opt.dmtet:
        self.dmtet = DMTetGeometry(
            opt.tet_grid_size, opt.tet_mlp, opt).to(opt.device)
        fits_cuda_rasterizer = self.opt.h <= 2048 and self.opt.w <= 2048
        if fits_cuda_rasterizer:
            self.glctx = dr.RasterizeCudaContext()
        else:
            self.glctx = dr.RasterizeGLContext()

    if self.taichi_ray:
        # taichi is only imported when this backend is requested
        from einops import rearrange
        from taichi_modules import RayMarcherTaichi
        from taichi_modules import VolumeRendererTaichi
        from taichi_modules import RayAABBIntersector as RayAABBIntersectorTaichi
        from taichi_modules import raymarching_test as raymarching_test_taichi
        from taichi_modules import composite_test as composite_test_fw
        from taichi_modules import packbits as packbits_taichi
        self.rearrange = rearrange
        self.packbits_taichi = packbits_taichi
        self.ray_aabb_intersector = RayAABBIntersectorTaichi
        self.raymarching_test_taichi = raymarching_test_taichi
        self.composite_test_fw = composite_test_fw
        self.ray_marching = RayMarcherTaichi(
            batch_size=4096)  # TODO: hard encoded batch size
        self.volume_render = VolumeRendererTaichi(
            batch_size=4096)  # TODO: hard encoded batch size
        _setup_density_state()

    # Any other activation name leaves self.density_activation unset.
    activation_name = self.opt.density_activation
    if activation_name == 'exp':
        self.density_activation = trunc_exp
    elif activation_name == 'softplus':
        self.density_activation = F.softplus
    elif activation_name == 'relu':
        self.density_activation = F.relu
|
|
|
|
# ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
|
|
def finite_difference_normal(self, x, epsilon=1e-2):
    """Negated central-difference gradient of ``common_forward`` at ``x``.

    Args:
        x: ``[N, 3]`` query points.
        epsilon: Step taken along each axis; shifted points are clamped to
            ``[-self.bound, self.bound]``.

    Returns:
        Per-axis differences of the first value returned by
        ``self.common_forward``, stacked along the last dimension and
        negated.
    """
    lo, hi = -self.bound, self.bound
    gradient = []
    for axis in range(3):
        forward_step = [0.00, 0.00, 0.00]
        backward_step = [0.00, 0.00, 0.00]
        forward_step[axis] = epsilon
        backward_step[axis] = -epsilon

        ahead, _ = self.common_forward(
            (x + torch.tensor([forward_step], device=x.device)).clamp(lo, hi))
        behind, _ = self.common_forward(
            (x + torch.tensor([backward_step], device=x.device)).clamp(lo, hi))
        gradient.append(0.5 * (ahead - behind) / epsilon)

    return -torch.stack(gradient, dim=-1)
|
|
|
|
def normal(self, x):
    """Finite-difference normal at ``x``, normalised and with NaNs replaced."""
    raw = self.finite_difference_normal(x)
    return torch.nan_to_num(safe_normalize(raw))
|
|
|
|
@torch.no_grad()
def density_blob(self, x):
    """Radially decaying density around the origin.

    Args:
        x: ``[B, N, 3]`` points.

    Returns:
        A Gaussian of the distance to the origin when
        ``opt.density_activation`` is ``'exp'``, otherwise a value falling
        off linearly with that distance. Both are scaled by
        ``opt.blob_density`` and controlled by ``opt.blob_radius``.
    """
    opt = self.opt
    sq_dist = (x ** 2).sum(-1)

    if opt.density_activation == 'exp':
        return opt.blob_density * \
            torch.exp(- sq_dist / (2 * opt.blob_radius ** 2))

    return opt.blob_density * \
        (1 - torch.sqrt(sq_dist) / opt.blob_radius)
|
|
|
|
def forward(self, x, d):
    """Placeholder that always raises ``NotImplementedError``."""
    raise NotImplementedError()
|
|
|
|
def density(self, x):
    """Placeholder that always raises ``NotImplementedError``."""
    raise NotImplementedError()
|
|
|
|
def reset_extra_state(self):
    """Zero the density grid and its statistics.

    Does nothing unless the CUDA or the taichi ray marcher is enabled.
    """
    if self.cuda_ray or self.taichi_ray:
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
|
|
|
|
@torch.no_grad()
|
|
def export_mesh(self, path, resolution=None, decimate_target=-1, S=128):
|
|
from meshutils import decimate_mesh, clean_mesh, poisson_mesh_reconstruction
|
|
if self.opt.dmtet:
|
|
vertices, triangles = self.dmtet.get_verts_face()
|
|
vertices = vertices.detach().cpu().numpy()
|
|
triangles = triangles.detach().cpu().numpy()
|
|
|
|
else:
|
|
|
|
if resolution is None:
|
|
resolution = self.grid_size
|
|
|
|
if self.cuda_ray:
|
|
density_thresh = min(self.mean_density, self.density_thresh) \
|
|
if np.greater(self.mean_density, 0) else self.density_thresh
|
|
else:
|
|
density_thresh = self.density_thresh
|
|
|
|
sigmas = np.zeros(
|
|
[resolution, resolution, resolution], dtype=np.float32)
|
|
|
|
# query
|
|
X = torch.linspace(-1, 1, resolution).split(S)
|
|
Y = torch.linspace(-1, 1, resolution).split(S)
|
|
Z = torch.linspace(-1, 1, resolution).split(S)
|
|
|
|
for xi, xs in enumerate(X):
|
|
for yi, ys in enumerate(Y):
|
|
for zi, zs in enumerate(Z):
|
|
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
|
pts = torch.cat(
|
|
[xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
|
|
val = self.density(pts.to(self.aabb_train.device))
|
|
sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(
|
|
zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
|
|
|
|
logger.info(
|
|
f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
|
|
|
|
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|
|
vertices = vertices / (resolution - 1.0) * 2 - 1
|
|
|
|
# clean
|
|
vertices = vertices.astype(np.float32)
|
|
triangles = triangles.astype(np.int32)
|
|
vertices, triangles = clean_mesh(
|
|
vertices, triangles, remesh=True, remesh_size=0.01)
|
|
|
|
# decimation
|
|
if decimate_target > 0 and triangles.shape[0] > decimate_target:
|
|
vertices, triangles = decimate_mesh(
|
|
vertices, triangles, decimate_target)
|
|
|
|
v = torch.from_numpy(vertices).contiguous(
|
|
).float().to(self.aabb_train.device)
|
|
f = torch.from_numpy(triangles).contiguous().int().to(
|
|
self.aabb_train.device)
|
|
|
|
# mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
|
|
# mesh.export(os.path.join(path, f'mesh.ply'))
|
|
|
|
def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
    """Unwrap UVs with xatlas, bake an albedo texture and write the mesh files.

    This is a closure: it reads ``self`` (``glctx`` and ``density``) and
    ``path`` (the output directory) from the enclosing scope.

    Writes ``{name}albedo.png``, ``{name}mesh.obj`` and ``{name}mesh.mtl``
    into ``path``.

    Args:
        v: vertex positions, torch tensor [N, 3].
        f: triangle indices, torch tensor [M, 3].
        h0, w0: height and width of the texture that is written to disk.
        ssaa: supersampling factor; when > 1 the texture is baked at
            ``(h0 * ssaa, w0 * ssaa)`` and resized down to ``(h0, w0)``.
        name: prefix prepended to every output file name.
    """
    device = v.device
    v_np = v.cpu().numpy()  # [N, 3]
    f_np = f.cpu().numpy()  # [M, 3]

    logger.info(
        f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')

    # heavy / optional dependencies are only imported when a mesh is exported
    import xatlas
    import nvdiffrast.torch as dr
    from sklearn.neighbors import NearestNeighbors
    from scipy.ndimage import binary_dilation, binary_erosion

    # unwrap uvs
    atlas = xatlas.Atlas()
    atlas.add_mesh(v_np, f_np)
    chart_options = xatlas.ChartOptions()
    chart_options.max_iterations = 4  # for faster unwrap...
    atlas.generate(chart_options=chart_options)
    _, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

    vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
    ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)

    # render uv maps
    uv = vt * 2.0 - 1.0  # uvs to range [-1, 1]
    uv = torch.cat((uv, torch.zeros_like(
        uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

    if ssaa > 1:
        h = int(h0 * ssaa)
        w = int(w0 * ssaa)
    else:
        h, w = h0, w0

    # lazily create the rasterization context; the CUDA context is only
    # chosen for resolutions up to 2048 x 2048
    if self.glctx is None:
        if h <= 2048 and w <= 2048:
            self.glctx = dr.RasterizeCudaContext()
        else:
            self.glctx = dr.RasterizeGLContext()

    rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), ft, (h, w))  # [1, h, w, 4]
    xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, h, w, 3]
    mask, _ = dr.interpolate(torch.ones_like(
        v[:, :1]).unsqueeze(0), rast, f)  # [1, h, w, 1]

    # masked query
    xyzs = xyzs.view(-1, 3)
    mask = (mask > 0).view(-1)

    feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)

    if mask.any():
        xyzs = xyzs[mask]  # [M, 3]

        # batched inference to avoid OOM
        chunk = 640000
        all_feats = []
        for head in range(0, xyzs.shape[0], chunk):
            results_ = self.density(xyzs[head:head + chunk])
            all_feats.append(results_['albedo'].float())

        feats[mask] = torch.cat(all_feats, dim=0)

    feats = feats.view(h, w, -1)
    mask = mask.view(h, w)

    # quantize [0.0, 1.0] to [0, 255]
    feats = feats.cpu().numpy()
    feats = (feats * 255).astype(np.uint8)

    # NN search as an antialiasing: texels just outside the charts copy the
    # color of the nearest texel on the chart border
    mask = mask.cpu().numpy()

    inpaint_region = binary_dilation(mask, iterations=3)
    inpaint_region[mask] = 0

    search_region = mask.copy()
    not_search_region = binary_erosion(search_region, iterations=2)
    search_region[not_search_region] = 0

    search_coords = np.stack(np.nonzero(search_region), axis=-1)
    inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

    # NearestNeighbors rejects empty arrays, so skip the inpainting when the
    # mask is empty or there is nothing to fill
    if len(search_coords) > 0 and len(inpaint_coords) > 0:
        knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
        _, indices = knn.kneighbors(inpaint_coords)
        feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]

    feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)

    # do ssaa after the NN search, in numpy
    if ssaa > 1:
        feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)

    cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)

    # save obj (v, vt, f /)
    obj_file = os.path.join(path, f'{name}mesh.obj')
    mtl_file = os.path.join(path, f'{name}mesh.mtl')

    logger.info(f'[INFO] writing obj mesh to {obj_file}')
    with open(obj_file, "w") as fp:
        fp.write(f'mtllib {name}mesh.mtl \n')

        logger.info(f'[INFO] writing vertices {v_np.shape}')
        fp.writelines(f'v {p[0]} {p[1]} {p[2]} \n' for p in v_np)

        logger.info(
            f'[INFO] writing vertices texture coords {vt_np.shape}')
        # the v texture coordinate is flipped (1 - v) when written
        fp.writelines(f'vt {t[0]} {1 - t[1]} \n' for t in vt_np)

        logger.info(f'[INFO] writing faces {f_np.shape}')
        fp.write('usemtl mat0 \n')
        # OBJ indices are 1-based: "vertex/texcoord" per corner
        fp.writelines(
            f"f {face[0] + 1}/{tex[0] + 1} {face[1] + 1}/{tex[1] + 1} {face[2] + 1}/{tex[2] + 1} \n"
            for face, tex in zip(f_np, ft_np))

    with open(mtl_file, "w") as fp:
        fp.write('newmtl mat0 \n')
        fp.write('Ka 1.000000 1.000000 1.000000 \n')
        fp.write('Kd 1.000000 1.000000 1.000000 \n')
        fp.write('Ks 0.000000 0.000000 0.000000 \n')
        fp.write('Tr 1.000000 \n')
        fp.write('illum 1 \n')
        fp.write('Ns 0.000000 \n')
        fp.write(f'map_Kd {name}albedo.png \n')
|
|
|
|
_export(v, f)
|
|
|
|
def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
    """Render rays by dense sampling along each ray (pure PyTorch path).

    Args:
        rays_o, rays_d: ray origins / directions, [B, N, 3].
        light_d: optional light direction(s). When None a direction is
            sampled around the ray origins.
        ambient_ratio: forwarded to the network as ``ratio``.
        shading: shading mode forwarded to the network.
        bg_color: background color, [BN, 3] in range [0, 1]. When None it
            comes from ``self.background(rays_d)`` if ``self.opt.bg_radius > 0``,
            otherwise the constant 1 is used.
        perturb: jitter the sample positions along each ray.
        **kwargs: ignored.

    Returns:
        dict with ``image`` [B, N, 3], ``depth`` [B, N], ``weights``,
        ``weights_sum`` and, when the network returns normals,
        ``normal_image``. In training mode ``loss_orient`` /
        ``loss_normal_perturb`` are added when the matching ``self.opt``
        weights are positive.
    """
    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0]  # N = B * N, in fact
    device = rays_o.device
    num_steps = self.opt.num_steps
    upsample_steps = self.opt.upsample_steps

    results = {}

    # choose aabb
    aabb = self.aabb_train if self.training else self.aabb_infer

    # sample steps
    nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near)

    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_origin = rays_o if self.training else rays_o[0:1]
        light_d = safe_normalize(light_origin + torch.randn(3, device=rays_o.device))

    z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
    z_vals = z_vals.expand((N, num_steps))  # [N, T]
    z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

    # perturb z_vals
    sample_dist = (fars - nears) / num_steps
    if perturb:
        z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist

    def _composite_weights(z, sigma):
        # Per-sample interval lengths and volume-rendering weights;
        # the last interval is padded with ``sample_dist``.
        deltas = z[..., 1:] - z[..., :-1]
        deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
        alphas = 1 - torch.exp(-deltas * sigma.squeeze(-1))
        alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)
        return deltas, alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]

    # generate xyzs
    xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
    xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

    # query density at the coarse samples
    density_outputs = self.density(xyzs.reshape(-1, 3))
    for k, v in density_outputs.items():
        density_outputs[k] = v.view(N, num_steps, -1)

    # upsample z_vals (nerf-like)
    if upsample_steps > 0:
        with torch.no_grad():
            deltas, weights = _composite_weights(z_vals, density_outputs['sigma'])  # [N, T]

            # sample new z_vals
            z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
            new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

            new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
            new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        # only forward new points to save computation
        new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
        for k, v in new_density_outputs.items():
            new_density_outputs[k] = v.view(N, upsample_steps, -1)

        # re-order: merge coarse and fine samples sorted by depth
        z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
        z_vals, z_index = torch.sort(z_vals, dim=1)

        xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
        xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

        for k in density_outputs:
            tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
            density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

    _, weights = _composite_weights(z_vals, density_outputs['sigma'])  # [N, T+t]

    dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
    light_d = light_d.view(-1, 1, 3).expand_as(xyzs)

    dirs = safe_normalize(dirs)
    # every per-sample input is flattened to [N * (T+t), 3]; light_d must be
    # flattened too so it lines up with the per-sample normals
    sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d.reshape(-1, 3), ratio=ambient_ratio, shading=shading)
    rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]
    if normals is not None:
        normals = normals.view(N, -1, 3)

    # calculate weight_sum (mask)
    weights_sum = weights.sum(dim=-1)  # [N]

    # calculate depth
    depth = torch.sum(weights * z_vals, dim=-1)

    # calculate color
    image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d)  # [N, 3]
        else:
            bg_color = 1

    image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

    image = image.view(*prefix, 3)
    depth = depth.view(*prefix)
    weights_sum = weights_sum.reshape(*prefix)

    if self.training:
        if self.opt.lambda_orient > 0 and normals is not None:
            # orientation loss
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.sum(-1).mean()

        if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
            normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
            results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()

    if normals is not None:
        normal_image = torch.sum(
            weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2)  # [N, 3], in [0, 1]
        results['normal_image'] = normal_image

    results['image'] = image
    results['depth'] = depth
    results['weights'] = weights
    results['weights_sum'] = weights_sum

    return results
|
|
|
|
|
|
def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, binarize=False, **kwargs):
    """Render rays with the ``raymarching`` extension (occupancy-grid marching).

    Args:
        rays_o, rays_d: ray origins / directions, [B, N, 3].
        light_d: optional light direction(s); sampled around the first ray
            origin when None.
        ambient_ratio, shading: forwarded to the network.
        bg_color: background color; when None it comes from
            ``self.background(rays_d)`` if ``self.opt.bg_radius > 0``, else 1.
        perturb: forwarded to the ray marcher.
        T_thresh, binarize: forwarded to the compositing kernels.

    Returns:
        dict with ``image`` [B, N, 3], ``depth`` [B, N] and ``weights_sum``;
        in training mode also ``weights`` and, when available,
        ``normal_image`` and the normal regularization losses.
    """
    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0]  # B * N, in fact
    device = rays_o.device

    # pre-calculate near far
    aabb = self.aabb_train if self.training else self.aabb_infer
    nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb)

    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        # (training and inference sample it the same way, around the first ray origin)
        light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device))

    results = {}

    if self.training:
        xyzs, dirs, ts, rays = raymarching.march_rays_train(
            rays_o, rays_d, self.bound, self.density_bitfield, self.cascade,
            self.grid_size, nears, fars, perturb, self.opt.dt_gamma, self.opt.max_steps)
        dirs = safe_normalize(dirs)

        # per-ray lights have to be expanded to one light per sample
        if light_d.shape[0] > 1:
            sample_to_ray = raymarching.flatten_rays(rays, xyzs.shape[0]).long()
            light_d = light_d[sample_to_ray]

        sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
        weights, weights_sum, depth, image = raymarching.composite_rays_train(
            sigmas, rgbs, ts, rays, T_thresh, binarize)

        # normals related regularizations
        has_normals = normals is not None
        if self.opt.lambda_orient > 0 and has_normals:
            # orientation loss
            facing = (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = (weights.detach() * facing).mean()

        if self.opt.lambda_3d_normal_smooth > 0 and has_normals:
            jittered_normals = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
            results['loss_normal_perturb'] = (normals - jittered_normals).abs().mean()

        if has_normals:
            # composite the normals (mapped to [0, 1]) like colors, with detached sigmas
            normal_image = raymarching.composite_rays_train(
                sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize)[3]
            results['normal_image'] = normal_image

        # weights normalization
        results['weights'] = weights

    else:
        # allocate outputs
        out_dtype = torch.float32
        weights_sum = torch.zeros(N, dtype=out_dtype, device=device)
        depth = torch.zeros(N, dtype=out_dtype, device=device)
        image = torch.zeros(N, 3, dtype=out_dtype, device=device)

        rays_alive = torch.arange(N, dtype=torch.int32, device=device)  # [N]
        rays_t = nears.clone()  # [N]

        step = 0
        # march until the step budget is spent or no ray is alive
        while step < self.opt.max_steps and rays_alive.shape[0] > 0:
            n_alive = rays_alive.shape[0]

            # decide compact_steps: more steps per launch as rays die, clamped to [1, 8]
            n_step = min(max(N // n_alive, 1), 8)

            xyzs, dirs, ts = raymarching.march_rays(
                n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound,
                self.density_bitfield, self.cascade, self.grid_size, nears, fars,
                perturb if step == 0 else False, self.opt.dt_gamma, self.opt.max_steps)
            dirs = safe_normalize(dirs)
            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
            raymarching.composite_rays(
                n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts,
                weights_sum, depth, image, T_thresh, binarize)

            # keep only the rays with a non-negative index
            rays_alive = rays_alive[rays_alive >= 0]

            step += n_step

    # mix background color
    if bg_color is None:
        # use the bg model to calculate bg_color when enabled
        bg_color = self.background(rays_d) if self.opt.bg_radius > 0 else 1

    image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

    results.update(
        image=image.view(*prefix, 3),
        depth=depth.view(*prefix),
        weights_sum=weights_sum.reshape(*prefix),
    )

    return results
|
|
|
|
def get_sdf_albedo_for_init(self, points=None):
    """Return ``(sigma - density_thresh, albedo)`` queried from ``self.density``.

    The query locations are ``points`` when given, otherwise the DMTet
    vertices ``self.dmtet.verts``.
    """
    query = points if points is not None else self.dmtet.verts
    queried = self.density(query)
    return queried['sigma'] - self.density_thresh, queried['albedo']
|
|
|
|
def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, **kwargs):
    """Render the DMTet mesh by rasterization (nvdiffrast).

    Args:
        rays_o: ray origins; only ``rays_o[:, 0, :]`` is used, as the camera
            position for sampling a light direction.
        rays_d: ray directions, only used to query ``self.background``.
        mvp: model-view-projection matrices, [B, 4, 4].
        h, w: output resolution.
        light_d: optional light direction; sampled around the camera
            position when None.
        ambient_ratio: ambient term of the lambertian shading.
        shading: 'albedo', 'textureless', 'normal', anything else is
            treated as 'lambertian'.
        bg_color: background color; when None it comes from
            ``self.background(rays_d)`` if ``self.opt.bg_radius > 0``, else 1.

    Returns:
        dict with ``image`` [B, H, W, 3], ``depth``, ``weights_sum`` and
        ``normal_image``; in training mode also ``loss_normal`` /
        ``loss_lap`` when the matching ``self.opt`` weights are positive.
    """
    campos = rays_o[:, 0, :]  # only need one ray per batch

    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3)  # [B, 1, 1, 3]

    results = {}

    verts, faces = self.dmtet.get_verts_face()

    # get normals (indices are taken before ``faces`` is cast to int32,
    # as they are reused as scatter indices below)
    i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2]
    v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :]

    faces = faces.int()

    # dim must be explicit: without it torch.cross uses the first dimension
    # of size 3, which is the face axis when the mesh has exactly 3 faces
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
    face_normals = safe_normalize(face_normals)

    # vertex normals: sum of the normals of the incident faces
    vn = torch.zeros_like(verts)
    vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
    vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
    vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

    # fall back to +z where the accumulated normal is (near) zero
    vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))

    # rasterization
    verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1),
                           mvp.permute(0, 2, 1)).float()  # [B, N, 4]
    rast, _ = dr.rasterize(self.glctx, verts_clip, faces, (h, w))

    alpha, _ = dr.interpolate(torch.ones_like(verts[:, :1]).unsqueeze(0), rast, faces)  # [B, H, W, 1]
    xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces)  # [B, H, W, 3]
    normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces)
    normal = safe_normalize(normal)

    xyzs = xyzs.view(-1, 3)
    mask = (alpha > 0).view(-1).detach()

    # do the lighting here since we have normal from mesh now.
    albedo = torch.zeros_like(xyzs, dtype=torch.float32)
    if mask.any():
        masked_albedo = self.density(xyzs[mask])['albedo']
        albedo[mask] = masked_albedo.float()
    albedo = albedo.view(-1, h, w, 3)

    if shading == 'albedo':
        color = albedo
    elif shading == 'textureless':
        lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0)
        color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
    elif shading == 'normal':
        color = (normal + 1) / 2
    else:  # 'lambertian'
        lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0)
        color = albedo * lambertian.unsqueeze(-1)

    color = dr.antialias(color, rast, verts_clip, faces).clamp(0, 1)  # [B, H, W, 3]
    alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1)  # [B, H, W, 1]

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d)  # [N, 3]
        else:
            bg_color = 1

    if torch.is_tensor(bg_color) and len(bg_color.shape) > 1:
        bg_color = bg_color.view(-1, h, w, 3)

    depth = rast[:, :, :, [2]]  # channel 2 of the rasterizer output, kept as a trailing dim
    color = color + (1 - alpha) * bg_color

    results['depth'] = depth
    results['image'] = color
    results['weights_sum'] = alpha.squeeze(-1)

    normal_image = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1)  # [B, H, W, 3]
    results['normal_image'] = normal_image

    # regularizations
    if self.training:
        if self.opt.lambda_mesh_normal > 0:
            results['loss_normal'] = normal_consistency(
                face_normals, faces)
        if self.opt.lambda_mesh_lap > 0:
            results['loss_lap'] = laplacian_smooth_loss(verts, faces)

    return results
|
|
|
|
def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, **kwargs):
    """Render rays with the taichi ray-marching backend.

    Args:
        rays_o, rays_d: ray origins / directions, [B, N, 3], assumes B == 1.
        light_d: optional light direction; sampled around the first ray
            origin when None.
        ambient_ratio, shading: forwarded to the network.
        bg_color: background color; when None it comes from
            ``self.background(rays_d)`` if ``self.opt.bg_radius > 0``, else 1.
        perturb: unused by this backend.
        T_thresh: transmittance threshold passed to the compositing
            kernels; a ``T_threshold`` entry in ``kwargs`` takes precedence.
        **kwargs: ``exp_step_factor`` (default 0.) and ``T_threshold`` are read.

    Returns:
        dict with ``image`` [B, N, 3], ``depth`` [B, N] and ``weights_sum``;
        in training mode also ``weights`` and, when available,
        ``normal_image`` and the normal regularization losses.
    """
    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0]  # N = B * N, in fact
    device = rays_o.device

    exp_step_factor = kwargs.get('exp_step_factor', 0.)
    # honor the declared T_thresh parameter unless the legacy kwarg overrides it
    T_threshold = kwargs.get('T_threshold', T_thresh)
    MAX_SAMPLES = 1024
    NEAR_DISTANCE = 0.01

    # pre-calculate near far
    center = torch.zeros(1, 3)
    half_size = torch.ones(1, 3)
    _, hits_t, _ = self.ray_aabb_intersector.apply(rays_o, rays_d, center, half_size, 1)
    # push entry distances in [0, NEAR_DISTANCE) out to NEAR_DISTANCE
    hits_t[(hits_t[:, 0, 0] >= 0) & (hits_t[:, 0, 0] < NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE

    # TODO: should sample different light_d for each batch... but taichi end doesn't have a flatten_ray implemented currently...
    # random sample light_d if not provided
    if light_d is None:
        # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
        light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
        light_d = safe_normalize(light_d)

    results = {}

    if self.training:
        rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES)
        dirs = safe_normalize(dirs)
        sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
        _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, T_threshold)

        # normals related regularizations
        if self.opt.lambda_orient > 0 and normals is not None:
            # orientation loss
            loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
            results['loss_orient'] = loss_orient.mean()

        if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
            normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
            results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()

        if normals is not None:
            _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, T_threshold)
            results['normal_image'] = normal_image

        # weights normalization
        results['weights'] = weights

    else:
        # allocate outputs
        dtype = torch.float32

        weights_sum = torch.zeros(N, dtype=dtype, device=device)
        depth = torch.zeros(N, dtype=dtype, device=device)
        image = torch.zeros(N, 3, dtype=dtype, device=device)

        n_alive = N
        rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
        step = 0

        min_samples = 1 if exp_step_factor == 0 else 4

        while step < self.opt.max_steps:  # hard coded max step

            # count alive rays
            n_alive = rays_alive.shape[0]

            # exit loop
            if n_alive <= 0:
                break

            # decide compact_steps
            n_step = max(min(N // n_alive, 64), min_samples)

            xyzs, dirs, deltas, ts, N_eff_samples = \
                self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive,
                                             self.density_bitfield, self.cascade,
                                             self.bound, exp_step_factor,
                                             self.grid_size, MAX_SAMPLES, n_step)

            xyzs = self.rearrange(xyzs, 'n1 n2 c -> (n1 n2) c')
            dirs = self.rearrange(dirs, 'n1 n2 c -> (n1 n2) c')
            dirs = safe_normalize(dirs)
            # samples whose direction is all zeros are skipped
            valid_mask = ~torch.all(dirs == 0, dim=1)
            if valid_mask.sum() == 0:
                break

            sigmas = torch.zeros(len(xyzs), device=device)
            rgbs = torch.zeros(len(xyzs), 3, device=device)

            # Normals are not composited at inference, so they are discarded.
            # They have one row per *valid* sample, hence cannot in general be
            # reshaped to (rays, n_step).
            valid_sigmas, valid_rgbs, _ = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading)
            sigmas[valid_mask] = valid_sigmas
            rgbs[valid_mask] = valid_rgbs.float()
            sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step)
            rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step)

            self.composite_test_fw(sigmas, rgbs, deltas, ts, hits_t[:, 0], rays_alive,
                                   T_threshold, N_eff_samples,
                                   weights_sum, depth, image)

            rays_alive = rays_alive[rays_alive >= 0]

            step += n_step

    # mix background color
    if bg_color is None:
        if self.opt.bg_radius > 0:
            # use the bg model to calculate bg_color
            bg_color = self.background(rays_d)  # [N, 3]
        else:
            bg_color = 1

    image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color
    image = image.view(*prefix, 3)

    depth = depth.view(*prefix)

    weights_sum = weights_sum.reshape(*prefix)

    results['image'] = image
    results['depth'] = depth
    results['weights_sum'] = weights_sum

    return results
|
|
|
|
|
|
@torch.no_grad()
def update_extra_state(self, decay=0.95, S=128):
    """Refresh the density grid and its occupancy bitfield.

    Meant to be called before each epoch. Does nothing unless
    ``self.cuda_ray`` or ``self.taichi_ray`` is set.

    Args:
        decay: factor applied to the stored density before it is
            max-merged with the freshly sampled one.
        S: chunk size used to split each grid axis.
    """
    if not self.cuda_ray and not self.taichi_ray:
        return

    grid_size = self.grid_size

    ### update density grid
    fresh_grid = -torch.ones_like(self.density_grid)

    # the three axes share the same index chunks
    axis_chunks = torch.arange(grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)

    for xs in axis_chunks:
        for ys in axis_chunks:
            for zs in axis_chunks:
                # construct points
                xx, yy, zz = custom_meshgrid(xs, ys, zs)
                coords = torch.stack((xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)), dim=-1)  # [N, 3] integer cell coordinates
                cell_ids = raymarching.morton3D(coords).long()  # [N]
                unit_xyzs = 2 * coords.float() / (grid_size - 1) - 1  # [N, 3] in [-1, 1]

                # cascading
                for cas in range(self.cascade):
                    bound = min(2 ** cas, self.bound)
                    half_grid_size = bound / grid_size
                    # scale to current cascade's resolution
                    cas_xyzs = unit_xyzs * (bound - half_grid_size)
                    # add noise in [-hgs, hgs]
                    jitter = (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                    cas_xyzs = cas_xyzs + jitter
                    # query density and assign
                    fresh_grid[cas, cell_ids] = self.density(cas_xyzs)['sigma'].reshape(-1).detach()

    # ema update (only cells whose stored density is non-negative)
    valid_mask = self.density_grid >= 0
    decayed = self.density_grid[valid_mask] * decay
    self.density_grid[valid_mask] = torch.maximum(decayed, fresh_grid[valid_mask])
    self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
    self.iter_density += 1

    # convert to bitfield
    density_thresh = min(self.mean_density, self.density_thresh)
    if self.cuda_ray:
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
    elif self.taichi_ray:
        self.packbits_taichi(self.density_grid.reshape(-1).contiguous(), density_thresh, self.density_bitfield)
|
|
|
|
# print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f}')
|
|
|
|
|
|
def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, **kwargs):
    """Dispatch rendering to the active backend.

    Backends are tried in order: DMTet (``self.dmtet``), CUDA ray marching
    (``self.cuda_ray``), taichi (``self.taichi_ray``), then the pure
    PyTorch ``run``.

    Args:
        rays_o, rays_d: ray origins / directions, [B, N, 3].
        mvp, h, w: only forwarded to the DMTet backend.
        staged: for the PyTorch backend, render in chunks of
            ``max_ray_batch`` rays; only ``depth``, ``image`` and
            ``weights_sum`` are returned in that case.
        max_ray_batch: chunk size for staged rendering.
        **kwargs: forwarded to the selected backend.

    Returns:
        The result dict of the selected backend.
    """
    B, N = rays_o.shape[:2]
    device = rays_o.device

    if self.dmtet:
        return self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs)
    if self.cuda_ray:
        return self.run_cuda(rays_o, rays_d, **kwargs)
    if self.taichi_ray:
        return self.run_taichi(rays_o, rays_d, **kwargs)
    if not staged:
        return self.run(rays_o, rays_d, **kwargs)

    # staged: pre-allocate the outputs and fill them chunk by chunk
    staged_out = {
        'depth': torch.empty((B, N), device=device),
        'image': torch.empty((B, N, 3), device=device),
        'weights_sum': torch.empty((B, N), device=device),
    }

    for b in range(B):
        for head in range(0, N, max_ray_batch):
            tail = min(head + max_ray_batch, N)
            chunk = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
            staged_out['depth'][b:b+1, head:tail] = chunk['depth']
            staged_out['weights_sum'][b:b+1, head:tail] = chunk['weights_sum']
            staged_out['image'][b:b+1, head:tail] = chunk['image']

    return staged_out
|
|
|
|
def init_tet_from_nerf(self, reset_scale=True):
    """Initialize the DMTet grid from the SDF derived from the NeRF density.

    ``reset_scale`` is forwarded to ``get_sdf_from_nerf``.
    """
    self.dmtet.init_tet_from_sdf(self.get_sdf_from_nerf(reset_scale=reset_scale))
    logger.info('init dmtet from NeRF Done ...')
|
|
|
|
|
|
@torch.no_grad()
def get_sdf_from_nerf(self, reset_scale=True):
    """Turn the NeRF density at the DMTet vertices into a clamped SDF-like value.

    Args:
        reset_scale: when True, first rescale the tet grid to the extent of
            the vertices whose density exceeds the threshold (plus a 0.1
            margin per axis).

    Returns:
        ``(sigma - density_thresh).clamp(-1, 1)`` evaluated at
        ``self.dmtet.verts``.
    """
    # with cuda ray marching the threshold is capped by the mean grid density
    density_thresh = min(self.mean_density, self.density_thresh) if self.cuda_ray else self.density_thresh

    if reset_scale:
        # init scale (verts covers [-1, 1] now)
        occupied = self.density(self.dmtet.verts)['sigma'] > density_thresh
        extent = self.dmtet.verts[occupied].abs().amax(dim=0)
        self.dmtet.reset_tet_scale(extent + 1e-1)

    # queried again after the optional rescale
    sigma = self.density(self.dmtet.verts)['sigma']
    return (sigma - density_thresh).clamp(-1, 1)
|