import os
import math
import cv2
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from encoding import get_encoder

import nvdiffrast.torch as dr
import mcubes
import raymarching
from .utils import custom_meshgrid, safe_normalize
import logging
from activation import trunc_exp, biased_softplus

logger = logging.getLogger(__name__)


class MLP(nn.Module):
    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden,
                                 self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x

    def reset_parameters(self):
        @torch.no_grad()
        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(weight_init)


def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
    # return: [B, n_samples], new_z_vals

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


@torch.cuda.amp.autocast(enabled=False)
def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):
    # rays: [B, N, 3], [B, N, 3]
    # bound: int, radius for ball or half-edge-length for cube
    # return near [B, N, 1], far [B, N, 1]

    radius = rays_o.norm(dim=-1, keepdim=True)

    if type == 'sphere':
        near = radius - bound  # [B, N, 1]
        far = radius + bound
    elif type == 'cube':
        tmin = (-bound - rays_o) / (rays_d + 1e-15)  # [B, N, 3]
        tmax = (bound - rays_o) / (rays_d + 1e-15)
        near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
        far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
        # if far < near, means no intersection, set both near and far to inf (1e9 here)
        mask = far < near
        near[mask] = 1e9
        far[mask] = 1e9
        # restrict near to a minimal value
        near = torch.clamp(near, min=min_near)

    return near, far


def plot_pointcloud(pc, color=None):
    import trimesh
    # pc: [N, 3]
    # color: [N, 3/4]
    logger.info(f'[visualize points] {pc.shape} {pc.dtype} {pc.min(0)} ~ {pc.max(0)}')
    pc = trimesh.PointCloud(pc, color)
    # 
axis axes = trimesh.creation.axis(axis_length=4) # sphere sphere = trimesh.creation.icosphere(radius=1) trimesh.Scene([pc, axes, sphere]).show() class DMTet: def __init__(self, device='cuda'): self.device = device self.triangle_table = torch.tensor([ [-1, -1, -1, -1, -1, -1], [1, 0, 2, -1, -1, -1], [4, 0, 3, -1, -1, -1], [1, 4, 2, 1, 3, 4], [3, 1, 5, -1, -1, -1], [2, 3, 0, 2, 5, 3], [1, 4, 0, 1, 5, 4], [4, 2, 5, -1, -1, -1], [4, 5, 2, -1, -1, -1], [4, 1, 0, 4, 5, 1], [3, 2, 0, 3, 5, 2], [1, 3, 5, -1, -1, -1], [4, 1, 2, 4, 3, 1], [3, 0, 4, -1, -1, -1], [2, 0, 1, -1, -1, -1], [-1, -1, -1, -1, -1, -1] ], dtype=torch.long, device=self.device) self.num_triangles_table = torch.tensor( [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=self.device) self.base_tet_edges = torch.tensor( [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) ############################################################################### # Utility functions ############################################################################### def sort_edges(self, edges_ex2): with torch.no_grad(): order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() order = order.unsqueeze(dim=1) a = torch.gather(input=edges_ex2, index=order, dim=1) b = torch.gather(input=edges_ex2, index=1-order, dim=1) return torch.stack([a, b], -1) def map_uv(self, faces, face_gidx, max_idx): N = int(np.ceil(np.sqrt((max_idx+1)//2))) tex_y, tex_x = torch.meshgrid( torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=self.device), torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=self.device), ) # indexing='ij') pad = 0.9 / N uvs = torch.stack([ tex_x, tex_y, tex_x + pad, tex_y, tex_x + pad, tex_y + pad, tex_x, tex_y + pad ], dim=-1).view(-1, 2) def _idx(tet_idx, N): x = tet_idx % N y = torch.div(tet_idx, N, rounding_mode='trunc') return y * N + x tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N) tri_idx = face_gidx % 2 uv_idx = torch.stack(( tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2 ), dim=-1). 
view(-1, 3) return uvs, uv_idx ############################################################################### # Marching tets implementation ############################################################################### def __call__(self, pos_nx3, sdf_n, tet_fx4, return_uv=True): # pos_nx3: [N, 3] # sdf_n: [N] # tet_fx4: [F, 4] with torch.no_grad(): occ_n = sdf_n > 0 occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) occ_sum = torch.sum(occ_fx4, -1) # [F,] # a valid tets not all positive (out space) and not all negative (inner) valid_tets = (occ_sum > 0) & (occ_sum < 4) occ_sum = occ_sum[valid_tets] # find all vertices all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) all_edges = self.sort_edges(all_edges) unique_edges, idx_map = torch.unique( all_edges, dim=0, return_inverse=True) # find out the edges across the surface to interpolate and refine unique_edges = unique_edges.long() mask_edges = occ_n[unique_edges.reshape(-1) ].reshape(-1, 2).sum(-1) == 1 mapping = torch.ones( (unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 mapping[mask_edges] = torch.arange( mask_edges.sum(), dtype=torch.long, device=self.device) idx_map = mapping[idx_map] # map edges to verts interp_v = unique_edges[mask_edges] edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) edges_to_interp_sdf[:, -1] *= -1 denominator = edges_to_interp_sdf.sum(1, keepdim=True) edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator # interpolate edges by sdf verts = (edges_to_interp * edges_to_interp_sdf).sum(1) idx_map = idx_map.reshape(-1, 6) v_id = torch.pow(2, torch.arange( 4, dtype=torch.long, device=self.device)) tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) num_triangles = self.num_triangles_table[tetindex] # Generate triangle indices faces = torch.cat(( torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), ), dim=0) if return_uv: # Get global face index (static, does not depend on topology) num_tets = tet_fx4.shape[0] tet_gidx = torch.arange(num_tets, dtype=torch.long, device=self.device)[ valid_tets] face_gidx = torch.cat(( tet_gidx[num_triangles == 1]*2, torch.stack( (tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1) ), dim=0) uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2) else: uvs, uv_idx = None, None return verts, faces, uvs, uv_idx ############################################################################### # Regularizer ############################################################################### def sdf_reg_loss(sdf, all_edges): sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) sdf_f1x6x2 = sdf_f1x6x2[mask] sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ torch.nn.functional.binary_cross_entropy_with_logits( sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) return sdf_diff ############################################################################### # Geometry interface ############################################################################### class DMTetGeometry(torch.nn.Module): def __init__(self, grid_res, tet_mlp, opt, device='cuda'): 
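        # Added explanation (hedged summary of the surrounding code): DMTetGeometry keeps a
        # fixed tetrahedral grid and learns a per-vertex SDF plus a bounded deformation,
        # either as raw nn.Parameters or as a hashgrid-encoded MLP (tet_mlp=True). DMTet.__call__
        # above then extracts a triangle mesh: each tet gets a 4-bit occupancy code from the
        # signs of its vertex SDF values, triangle_table maps that code to triangles over the
        # six tet edges, and every crossing edge (a, b) is split at the SDF zero level,
        #   v = (x_a * |sdf_b| + x_b * |sdf_a|) / (|sdf_a| + |sdf_b|).
        # Example: occupancy [1, 0, 0, 0] -> tetindex 1 -> one triangle on edges (0,1), (0,2), (0,3).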
super(DMTetGeometry, self).__init__() self.opt = opt self.device = device self.tet_scale = torch.ones(3, device=device) self.grid_res = grid_res self.marching_tets = DMTet() tets = np.load('data/tets/{}_tets.npz'.format(self.grid_res)) # for 64/128, [N=36562/277410, 3], in [-0.5, 0.5]^3 self.verts = torch.tensor( tets['vertices'], dtype=torch.float32, device=self.device) * 2 # for 64/128, [M=192492/1524684, 4], vert indices for each tetrahetron self.indices = torch.tensor( tets['indices'], dtype=torch.long, device=self.device) self.generate_edges() self.tet_mlp = tet_mlp if tet_mlp: self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3) self.encoder = self.encoder.to(device) self.mlp = MLP(self.in_dim, 4, 32, 3, False).to(device) self.sdf = None else: sdf = torch.nn.Parameter(torch.zeros_like( self.verts[..., 0]), requires_grad=True) self.register_parameter('sdf', sdf) deform = torch.nn.Parameter( torch.zeros_like(self.verts), requires_grad=True) self.register_parameter('deform', deform) if opt.base_mesh and os.path.exists(opt.base_mesh): self.init_tet_from_mesh(opt.base_mesh) def reset_tet(self, reset_scale=True): if self.tet_mlp: self.mlp.reset_parameters() else: self.sdf.data = torch.zeros_like(self.verts[..., 0]) self.deform.data = torch.zeros_like(self.verts) if reset_scale: self.reset_tet_scale() def get_sdf_from_mesh(self, base_mesh): logger.info(f'[INFO] init sdf from base mesh: {base_mesh}') import cubvh import trimesh mesh = trimesh.load(base_mesh, force='mesh') scale = 1.5 / np.array(mesh.bounds[1] - mesh.bounds[0]).max() center = np.array(mesh.bounds[1] + mesh.bounds[0]) / 2 mesh.vertices = (mesh.vertices - center) * scale # build with numpy.ndarray/torch.Tensor BVH = cubvh.cuBVH(mesh.vertices, mesh.faces) sdf, face_id, _ = BVH.signed_distance( self.verts, return_uvw=False, mode='watertight') sdf *= -1 # INNER is POSITIVE return sdf def init_tet_from_mesh(self, base_mesh): sdf = self.get_sdf_from_mesh(base_mesh) self.init_tet_from_sdf(sdf) # visualize # sdf_np_gt = sdf.cpu().numpy() # sdf_np = self.mlp(self.encoder(self.verts)).detach().cpu().numpy()[..., 0] # verts_np = self.verts.cpu().numpy() # color = np.zeros_like(verts_np) # color[sdf_np < 0] = [1, 0, 0] # color[sdf_np > 0] = [0, 0, 1] # color = (color * 255).astype(np.uint8) # pc = trimesh.PointCloud(verts_np, color) # axes = trimesh.creation.axis(axis_length=4) # box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() # trimesh.Scene([mesh, pc, axes, box]).show() def init_tet_from_sdf(self, sdf, pretrain_iters=5000, lr=1e-3): if self.tet_mlp: self.mlp.reset_parameters() # pretraining loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam(list(self.parameters()), lr=lr) #batch_size = min(10240, self.verts.shape[0]) batch_size = self.verts.shape[0] pbar = tqdm(range(pretrain_iters), desc="init dmtet mlp from sdf") for i in pbar: rand_idx = torch.randint(0, self.verts.shape[0], (batch_size,)) p = self.verts[rand_idx] ref_value = sdf[rand_idx] output = self.mlp(self.encoder(p)) loss = loss_fn(output[..., 0], ref_value) optimizer.zero_grad() loss.backward() optimizer.step() pbar.set_postfix(loss=loss.item()) else: self.sdf.data = sdf.squeeze() @torch.no_grad() def reset_tet_scale(self, tet_scale=1.): if isinstance(tet_scale, float): tet_scale = torch.ones(3, device=self.device) * tet_scale self.tet_scale = tet_scale self.verts = self.verts * tet_scale @torch.no_grad() def generate_edges(self): # six edges for each tetrahedron. 
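        # Added note: for a tet (a, b, c, d) the six undirected edges are
        # (a,b), (a,c), (a,d), (b,c), (b,d), (c,d); the index pairs below pick them out of
        # self.indices in exactly that order. Sorting each pair and calling torch.unique
        # drops edges shared between neighbouring tets, so sdf_reg_loss() sees every edge
        # exactly once.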
edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) all_edges = self.indices[:, edges].reshape(-1, 2) # [M * 6, 2] all_edges_sorted = torch.sort(all_edges, dim=1)[0] self.all_edges = torch.unique(all_edges_sorted, dim=0) def get_sdf_deform(self): if self.tet_mlp: # predict SDF and per-vertex deformation pred = self.mlp(self.encoder(self.verts)) sdf, deform = pred[:, 0], pred[:, 1:] return sdf, torch.tanh(deform) / (self.grid_res) else: return self.sdf, torch.tanh(self.deform) / (self.grid_res) def get_verts_face(self): sdf, deform = self.get_sdf_deform() verts, faces, _, _ = self.marching_tets( self.verts + deform, sdf, self.indices, return_uv=False) return verts, faces # def getAABB(self): # return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values # def getMesh(self, material): # pred = self.mlp(self.encoder(self.verts)) # predict SDF and per-vertex deformation # sdf, deform = pred[:, 0], pred[:, 1:] # v_deformed = self.verts + torch.tanh(deform) / (self.grid_res) # verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, sdf, self.indices) # imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material) # # Run mesh operations to generate tangent space # imesh = mesh.auto_normals(imesh) # imesh = mesh.compute_tangents(imesh) # return imesh, sdf # def render(self, glctx, target, lgt, opt_material, bsdf=None): # # return rendered buffers, keys: ['shaded', 'kd_grad', 'occlusion']. # opt_mesh, sdf = self.getMesh(opt_material) # buffers = render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], # msaa=True, background=None, bsdf=bsdf) # buffers['mesh'] = opt_mesh # buffers['sdf'] = sdf # return buffers # def tick(self, glctx, target, lgt, opt_material, loss_fn, guidance_model, text_z, iteration): # # ============================================================================================== # # Render optimizable object with identical conditions # # ============================================================================================== # buffers = self.render(glctx, target, lgt, opt_material) # mesh = buffers['mesh'] # # ============================================================================================== # # Compute loss # # ============================================================================================== # t_iter = iteration / self.opt.iter # if iteration < int(self.opt.iter * 0.2): # # mode = 'normal_latent' # pred_rgb = buffers['normal'][..., 0:4].permute(0, 3, 1, 2).contiguous() # as_latent = True # elif iteration < int(self.opt.iter * 0.6): # # mode = 'normal' # pred_rgb = buffers['normal'][..., 0:3].permute(0, 3, 1, 2).contiguous() # as_latent = False # else: # # mode = 'rgb' # pred_rgb = buffers['shaded'][..., 0:3].permute(0, 3, 1, 2).contiguous() # pred_ws = buffers['shaded'][..., 3].unsqueeze(1) # [B, 1, H, W] # pred_rgb = pred_rgb * pred_ws + (1 - pred_ws) * 1 # white bg # as_latent = False # # torch_vis_2d(pred_rgb[0]) # # torch_vis_2d(pred_normal[0]) # # torch_vis_2d(pred_ws[0]) # if self.opt.directional_text: # all_pos = [] # all_neg = [] # for emb in text_z[target['direction']]: # list of [2, S, -1] # pos, neg = emb.chunk(2) # [1, S, -1] # all_pos.append(pos) # all_neg.append(neg) # text_embedding = torch.cat(all_pos + all_neg, dim=0) # [2b, S, -1] # else: # text_embedding = text_z # img_loss = guidance_model.train_step(text_embedding, pred_rgb.half(), as_latent=as_latent) # # img_loss = 
torch.tensor(0.0, device = self.device) # # below are lots of regularizations... # reg_loss = torch.tensor(0.0, device = self.device) # if iteration < int(self.opt.iter * 0.6): # # SDF regularizer # sdf_weight = self.opt.sdf_regularizer - (self.opt.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter) # sdf_loss = sdf_reg_loss(buffers['sdf'], self.all_edges).mean() * sdf_weight # Dropoff to 0.01 # reg_loss = reg_loss + sdf_loss # # directly regularize mesh smoothness in finetuning... # if iteration > int(self.opt.iter * 0.2): # lap_loss = regularizer.laplace_regularizer_const(mesh.v_pos, mesh.t_pos_idx) * self.opt.laplace_scale #* min(1.0, iteration / 500) # reg_loss = reg_loss + lap_loss # # normal_loss = regularizer.normal_consistency(mesh.v_pos, mesh.t_pos_idx) * self.opt.laplace_scale * min(1.0, iteration / 500) # # reg_loss = reg_loss + normal_loss # else: # # Albedo (k_d) smoothnesss regularizer # # reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, (iteration - int(self.opt.iter * 0.6)) / 500) # # # Visibility regularizer # # reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, (iteration - int(self.opt.iter * 0.6)) / 500) # # # Light white balance regularizer # reg_loss += lgt.regularizer() * 0.005 # return img_loss, reg_loss def compute_edge_to_face_mapping(attr_idx): with torch.no_grad(): # Get unique edges # Create all edges, packed by triangle all_edges = torch.cat(( torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), ), dim=-1).view(-1, 2) # Swap edge order so min index is always first order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) sorted_edges = torch.cat(( torch.gather(all_edges, 1, order), torch.gather(all_edges, 1, 1 - order) ), dim=-1) # Elliminate duplicates and return inverse mapping unique_edges, idx_map = torch.unique( sorted_edges, dim=0, return_inverse=True) tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() tris_per_edge = torch.zeros( (unique_edges.shape[0], 2), dtype=torch.int64).cuda() # Compute edge to face table mask0 = order[:, 0] == 0 mask1 = order[:, 0] == 1 tris_per_edge[idx_map[mask0], 0] = tris[mask0] tris_per_edge[idx_map[mask1], 1] = tris[mask1] return tris_per_edge @torch.cuda.amp.autocast(enabled=False) def normal_consistency(face_normals, t_pos_idx): tris_per_edge = compute_edge_to_face_mapping(t_pos_idx) # Fetch normals for both faces sharind an edge n0 = face_normals[tris_per_edge[:, 0], :] n1 = face_normals[tris_per_edge[:, 1], :] # Compute error metric based on normal difference term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0) term = (1.0 - term) return torch.mean(torch.abs(term)) def laplacian_uniform(verts, faces): V = verts.shape[0] F = faces.shape[0] # Neighbor indices ii = faces[:, [1, 2, 0]].flatten() jj = faces[:, [2, 0, 1]].flatten() adj = torch.stack( [torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1) adj_values = torch.ones( adj.shape[1], device=verts.device, dtype=torch.float) # Diagonal indices diag_idx = adj[0] # Build the sparse matrix idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1) values = torch.cat((-adj_values, adj_values)) # The coalesce operation sums the duplicate indices, resulting in the # correct diagonal return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce() @torch.cuda.amp.autocast(enabled=False) def 
laplacian_smooth_loss(verts, faces): with torch.no_grad(): L = laplacian_uniform(verts, faces.long()) loss = L.mm(verts) loss = loss.norm(dim=1) loss = loss.mean() return loss class NeRFRenderer(nn.Module): def __init__(self, opt): super().__init__() self.opt = opt self.bound = opt.bound self.cascade = 1 + math.ceil(math.log2(opt.bound)) self.grid_size = 128 self.max_level = None self.dmtet = opt.dmtet self.cuda_ray = opt.cuda_ray self.taichi_ray = opt.taichi_ray self.min_near = opt.min_near self.density_thresh = opt.density_thresh # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. aabb_train = torch.FloatTensor( [-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound]) aabb_infer = aabb_train.clone() self.register_buffer('aabb_train', aabb_train) self.register_buffer('aabb_infer', aabb_infer) self.glctx = None # extra state for cuda raymarching if self.cuda_ray: # density grid density_grid = torch.zeros( [self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] density_bitfield = torch.zeros( self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] self.register_buffer('density_grid', density_grid) self.register_buffer('density_bitfield', density_bitfield) self.mean_density = 0 self.iter_density = 0 # load dmtet vertices if self.opt.dmtet: self.dmtet = DMTetGeometry(opt.tet_grid_size, opt.tet_mlp, opt).to(opt.device) if self.opt.h <= 2048 and self.opt.w <= 2048: self.glctx = dr.RasterizeCudaContext() else: self.glctx = dr.RasterizeGLContext() if self.taichi_ray: from einops import rearrange from taichi_modules import RayMarcherTaichi from taichi_modules import VolumeRendererTaichi from taichi_modules import RayAABBIntersector as RayAABBIntersectorTaichi from taichi_modules import raymarching_test as raymarching_test_taichi from taichi_modules import composite_test as composite_test_fw from taichi_modules import packbits as packbits_taichi self.rearrange = rearrange self.packbits_taichi = packbits_taichi self.ray_aabb_intersector = RayAABBIntersectorTaichi self.raymarching_test_taichi = raymarching_test_taichi self.composite_test_fw = composite_test_fw self.ray_marching = RayMarcherTaichi( batch_size=4096) # TODO: hard encoded batch size self.volume_render = VolumeRendererTaichi( batch_size=4096) # TODO: hard encoded batch size # density grid density_grid = torch.zeros( [self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] density_bitfield = torch.zeros( self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] self.register_buffer('density_grid', density_grid) self.register_buffer('density_bitfield', density_bitfield) self.mean_density = 0 self.iter_density = 0 if self.opt.density_activation == 'exp': self.density_activation = trunc_exp elif self.opt.density_activation == 'softplus': self.density_activation = F.softplus elif self.opt.density_activation == 'relu': self.density_activation = F.relu # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 def finite_difference_normal(self, x, epsilon=1e-2): # x: [N, 3] dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], 
device=x.device)).clamp(-self.bound, self.bound)) dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound)) dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound)) dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound)) normal = torch.stack([ 0.5 * (dx_pos - dx_neg) / epsilon, 0.5 * (dy_pos - dy_neg) / epsilon, 0.5 * (dz_pos - dz_neg) / epsilon ], dim=-1) return -normal def normal(self, x): normal = self.finite_difference_normal(x) normal = safe_normalize(normal) normal = torch.nan_to_num(normal) return normal @torch.no_grad() def density_blob(self, x): # x: [B, N, 3] d = (x ** 2).sum(-1) if self.opt.density_activation == 'exp': g = self.opt.blob_density * \ torch.exp(- d / (2 * self.opt.blob_radius ** 2)) else: g = self.opt.blob_density * \ (1 - torch.sqrt(d) / self.opt.blob_radius) return g def forward(self, x, d): raise NotImplementedError() def density(self, x): raise NotImplementedError() def reset_extra_state(self): if not (self.cuda_ray or self.taichi_ray): return # density grid self.density_grid.zero_() self.mean_density = 0 self.iter_density = 0 @torch.no_grad() def export_mesh(self, path, resolution=None, decimate_target=-1, S=128): from meshutils import decimate_mesh, clean_mesh, poisson_mesh_reconstruction if self.opt.dmtet: vertices, triangles = self.dmtet.get_verts_face() vertices = vertices.detach().cpu().numpy() triangles = triangles.detach().cpu().numpy() else: if resolution is None: resolution = self.grid_size if self.cuda_ray: density_thresh = min(self.mean_density, self.density_thresh) \ if np.greater(self.mean_density, 0) else self.density_thresh else: density_thresh = self.density_thresh sigmas = np.zeros( [resolution, resolution, resolution], dtype=np.float32) # query X = torch.linspace(-1, 1, resolution).split(S) Y = torch.linspace(-1, 1, resolution).split(S) Z = torch.linspace(-1, 1, resolution).split(S) for xi, xs in enumerate(X): for yi, ys in enumerate(Y): for zi, zs in enumerate(Z): xx, yy, zz = custom_meshgrid(xs, ys, zs) pts = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] val = self.density(pts.to(self.aabb_train.device)) sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len( zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] logger.info( f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})') vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh) vertices = vertices / (resolution - 1.0) * 2 - 1 # clean vertices = vertices.astype(np.float32) triangles = triangles.astype(np.int32) vertices, triangles = clean_mesh( vertices, triangles, remesh=True, remesh_size=0.01) # decimation if decimate_target > 0 and triangles.shape[0] > decimate_target: vertices, triangles = decimate_mesh( vertices, triangles, decimate_target) v = torch.from_numpy(vertices).contiguous( ).float().to(self.aabb_train.device) f = torch.from_numpy(triangles).contiguous().int().to( self.aabb_train.device) # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... 
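        # Added overview of the texture-baking helper defined below: _export() unwraps UVs
        # with xatlas, rasterizes the UV triangles into an (h0, w0) atlas via nvdiffrast,
        # interpolates 3D positions for every covered texel, queries self.density() for the
        # albedo there, and finally writes {name}albedo.png together with an .obj/.mtl pair
        # that references it.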
# mesh.export(os.path.join(path, f'mesh.ply')) def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''): # v, f: torch Tensor device = v.device v_np = v.cpu().numpy() # [N, 3] f_np = f.cpu().numpy() # [M, 3] logger.info( f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}') # unwrap uvs import xatlas import nvdiffrast.torch as dr from sklearn.neighbors import NearestNeighbors from scipy.ndimage import binary_dilation, binary_erosion atlas = xatlas.Atlas() atlas.add_mesh(v_np, f_np) chart_options = xatlas.ChartOptions() chart_options.max_iterations = 4 # for faster unwrap... atlas.generate(chart_options=chart_options) vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2] vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device) ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device) # render uv maps uv = vt * 2.0 - 1.0 # uvs to range [-1, 1] uv = torch.cat((uv, torch.zeros_like( uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] if ssaa > 1: h = int(h0 * ssaa) w = int(w0 * ssaa) else: h, w = h0, w0 if self.glctx is None: if h <= 2048 and w <= 2048: self.glctx = dr.RasterizeCudaContext() else: self.glctx = dr.RasterizeGLContext() rast, _ = dr.rasterize(self.glctx, uv.unsqueeze( 0), ft, (h, w)) # [1, h, w, 4] xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] mask, _ = dr.interpolate(torch.ones_like( v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] # masked query xyzs = xyzs.view(-1, 3) mask = (mask > 0).view(-1) feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32) if mask.any(): xyzs = xyzs[mask] # [M, 3] # batched inference to avoid OOM all_feats = [] head = 0 while head < xyzs.shape[0]: tail = min(head + 640000, xyzs.shape[0]) results_ = self.density(xyzs[head:tail]) all_feats.append(results_['albedo'].float()) head += 640000 feats[mask] = torch.cat(all_feats, dim=0) feats = feats.view(h, w, -1) mask = mask.view(h, w) # quantize [0.0, 1.0] to [0, 255] feats = feats.cpu().numpy() feats = (feats * 255).astype(np.uint8) ### NN search as an antialiasing ... 
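            # Added note: the block below fills texels just outside the UV charts by copying
            # the nearest valid texel (a kd-tree nearest-neighbour lookup in pixel space):
            # the dilated-minus-original mask is the region to inpaint, the original-minus-eroded
            # mask is the search band. This keeps bilinear filtering at chart seams from
            # bleeding in the black background.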
mask = mask.cpu().numpy() inpaint_region = binary_dilation(mask, iterations=3) inpaint_region[mask] = 0 search_region = mask.copy() not_search_region = binary_erosion(search_region, iterations=2) search_region[not_search_region] = 0 search_coords = np.stack(np.nonzero(search_region), axis=-1) inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords) _, indices = knn.kneighbors(inpaint_coords) feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)] feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR) # do ssaa after the NN search, in numpy if ssaa > 1: feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR) cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats) # save obj (v, vt, f /) obj_file = os.path.join(path, f'{name}mesh.obj') mtl_file = os.path.join(path, f'{name}mesh.mtl') logger.info(f'[INFO] writing obj mesh to {obj_file}') with open(obj_file, "w") as fp: fp.write(f'mtllib {name}mesh.mtl \n') logger.info(f'[INFO] writing vertices {v_np.shape}') for v in v_np: fp.write(f'v {v[0]} {v[1]} {v[2]} \n') logger.info( f'[INFO] writing vertices texture coords {vt_np.shape}') for v in vt_np: fp.write(f'vt {v[0]} {1 - v[1]} \n') logger.info(f'[INFO] writing faces {f_np.shape}') fp.write(f'usemtl mat0 \n') for i in range(len(f_np)): fp.write( f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n") with open(mtl_file, "w") as fp: fp.write(f'newmtl mat0 \n') fp.write(f'Ka 1.000000 1.000000 1.000000 \n') fp.write(f'Kd 1.000000 1.000000 1.000000 \n') fp.write(f'Ks 0.000000 0.000000 0.000000 \n') fp.write(f'Tr 1.000000 \n') fp.write(f'illum 1 \n') fp.write(f'Ns 0.000000 \n') fp.write(f'map_Kd {name}albedo.png \n') _export(v, f) def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs): # rays_o, rays_d: [B, N, 3] # bg_color: [BN, 3] in range [0, 1] # return: image: [B, N, 3], depth: [B, N] prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device results = {} # choose aabb aabb = self.aabb_train if self.training else self.aabb_infer # sample steps # nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near) # nears.unsqueeze_(-1) # fars.unsqueeze_(-1) nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near) # random sample light_d if not provided if light_d is None: # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) if self.training: light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3] else: light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device)) # [N, 3] #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') z_vals = torch.linspace(0.0, 1.0, self.opt.num_steps, device=device).unsqueeze(0) # [1, T] z_vals = z_vals.expand((N, self.opt.num_steps)) # [N, T] z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] # perturb z_vals sample_dist = (fars - nears) / self.opt.num_steps if perturb: z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. 
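        # Added note: the rest of run() is standard NeRF quadrature. With spacing delta_i
        # between samples, alpha_i = 1 - exp(-sigma_i * delta_i) and the per-sample weight is
        # w_i = alpha_i * prod_{j<i} (1 - alpha_j); color = sum_i w_i * rgb_i and
        # depth = sum_i w_i * z_i. sample_pdf() then draws opt.upsample_steps extra samples
        # from the piecewise-constant PDF defined by those weights.
        # Example: sigma*delta = [0.5, 1.0] -> alpha ~ [0.393, 0.632] -> w ~ [0.393, 0.383].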
# generate xyzs xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) # query SDF and RGB density_outputs = self.density(xyzs.reshape(-1, 3)) #sigmas = density_outputs['sigma'].view(N, self.opt.num_steps) # [N, T] for k, v in density_outputs.items(): density_outputs[k] = v.view(N, self.opt.num_steps, -1) # upsample z_vals (nerf-like) if self.opt.upsample_steps > 0: with torch.no_grad(): deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T] alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] # sample new z_vals z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], self.opt.upsample_steps, det=not self.training).detach() # [N, t] new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip. # only forward new points to save computation new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) #new_sigmas = new_density_outputs['sigma'].view(N, self.opt.upsample_steps) # [N, t] for k, v in new_density_outputs.items(): new_density_outputs[k] = v.view(N, self.opt.upsample_steps, -1) # re-order z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] z_vals, z_index = torch.sort(z_vals, dim=1) xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) for k in density_outputs: tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t] alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) light_d = light_d.view(-1, 1, 3).expand_as(xyzs) for k, v in density_outputs.items(): density_outputs[k] = v.view(-1, v.shape[-1]) dirs = safe_normalize(dirs) sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading) rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] if normals is not None: normals = normals.view(N, -1, 3) # calculate weight_sum (mask) weights_sum = weights.sum(dim=-1) # [N] # calculate depth depth = torch.sum(weights * z_vals, dim=-1) # calculate color image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] # mix background color if bg_color is None: if self.opt.bg_radius > 0: # use the bg model to calculate bg_color bg_color = self.background(rays_d) # [N, 3] else: bg_color = 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color image = image.view(*prefix, 3) depth = depth.view(*prefix) weights_sum = weights_sum.reshape(*prefix) if self.training: if 
self.opt.lambda_orient > 0 and normals is not None: # orientation loss loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 results['loss_orient'] = loss_orient.sum(-1).mean() if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() if normals is not None: normal_image = torch.sum( weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2) # [N, 3], in [0, 1] results['normal_image'] = normal_image results['image'] = image results['depth'] = depth results['weights'] = weights results['weights_sum'] = weights_sum return results def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, binarize=False, **kwargs): # rays_o, rays_d: [B, N, 3] # return: image: [B, N, 3], depth: [B, N] prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # B * N, in fact device = rays_o.device # pre-calculate near far nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer) # random sample light_d if not provided if light_d is None: # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) if self.training: light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device)) # [N, 3] else: light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device)) # [N, 3] results = {} if self.training: xyzs, dirs, ts, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb, self.opt.dt_gamma, self.opt.max_steps) dirs = safe_normalize(dirs) if light_d.shape[0] > 1: flatten_rays = raymarching.flatten_rays(rays, xyzs.shape[0]).long() light_d = light_d[flatten_rays] sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, rays, T_thresh, binarize) # normals related regularizations if self.opt.lambda_orient > 0 and normals is not None: # orientation loss loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 results['loss_orient'] = loss_orient.mean() if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() if normals is not None: _, _, _, normal_image = raymarching.composite_rays_train(sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize) results['normal_image'] = normal_image # weights normalization results['weights'] = weights else: # allocate outputs dtype = torch.float32 weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) n_alive = N rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] rays_t = nears.clone() # [N] step = 0 while step < self.opt.max_steps: # hard coded max step # count alive rays n_alive = rays_alive.shape[0] # exit loop if n_alive <= 0: break # decide compact_steps n_step = max(min(N // n_alive, 8), 1) xyzs, dirs, ts = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb 
if step == 0 else False, self.opt.dt_gamma, self.opt.max_steps) dirs = safe_normalize(dirs) sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh, binarize) rays_alive = rays_alive[rays_alive >= 0] #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') step += n_step # mix background color if bg_color is None: if self.opt.bg_radius > 0: # use the bg model to calculate bg_color bg_color = self.background(rays_d) # [N, 3] else: bg_color = 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color image = image.view(*prefix, 3) depth = depth.view(*prefix) weights_sum = weights_sum.reshape(*prefix) results['image'] = image results['depth'] = depth results['weights_sum'] = weights_sum return results def get_sdf_albedo_for_init(self, points=None): output = self.density(self.dmtet.verts if points is None else points) sigma, albedo = output['sigma'], output['albedo'] return sigma - self.density_thresh, albedo def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, **kwargs): # mvp: [B, 4, 4] device = mvp.device campos = rays_o[:, 0, :] # only need one ray per batch # random sample light_d if not provided if light_d is None: # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3) # [B, 1, 1, 3] results = {} verts, faces = self.dmtet.get_verts_face() # get normals i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2] v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :] faces = faces.int() face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) vn = torch.zeros_like(verts) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals) vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device)) # rasterization verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1), mvp.permute(0,2,1)).float() # [B, N, 4] rast, rast_db = dr.rasterize(self.glctx, verts_clip, faces, (h, w)) alpha, _ = dr.interpolate(torch.ones_like(verts[:, :1]).unsqueeze(0), rast, faces) # [B, H, W, 1] xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3] normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces) normal = safe_normalize(normal) xyzs = xyzs.view(-1, 3) mask = (alpha > 0).view(-1).detach() # do the lighting here since we have normal from mesh now. 
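        # Added note: shading below is a simple Lambertian model,
        #   color = albedo * (ambient_ratio + (1 - ambient_ratio) * max(0, n . l)),
        # with n the rasterized vertex normal and l = light_d; 'textureless' drops the albedo
        # term and 'normal' just visualizes (n + 1) / 2.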
albedo = torch.zeros_like(xyzs, dtype=torch.float32) if mask.any(): masked_albedo = self.density(xyzs[mask])['albedo'] albedo[mask] = masked_albedo.float() albedo = albedo.view(-1, h, w, 3) if shading == 'albedo': color = albedo elif shading == 'textureless': lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0) color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3) elif shading == 'normal': color = (normal + 1) / 2 else: # 'lambertian' lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0) color = albedo * lambertian.unsqueeze(-1) color = dr.antialias(color, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3] alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 1] # mix background color if bg_color is None: if self.opt.bg_radius > 0: # use the bg model to calculate bg_color bg_color = self.background(rays_d) # [N, 3] else: bg_color = 1 if torch.is_tensor(bg_color) and len(bg_color.shape) > 1: bg_color = bg_color.view(-1, h, w, 3) depth = rast[:, :, :, [2]] # [B, H, W] color = color + (1 - alpha) * bg_color results['depth'] = depth results['image'] = color results['weights_sum'] = alpha.squeeze(-1) normal_image = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3] results['normal_image'] = normal_image # regularizations if self.training: if self.opt.lambda_mesh_normal > 0: results['loss_normal'] = normal_consistency( face_normals, faces) if self.opt.lambda_mesh_lap > 0: results['loss_lap'] = laplacian_smooth_loss(verts, faces) return results def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, **kwargs): # rays_o, rays_d: [B, N, 3], assumes B == 1 # return: image: [B, N, 3], depth: [B, N] prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device # pre-calculate near far exp_step_factor = kwargs.get('exp_step_factor', 0.) MAX_SAMPLES = 1024 NEAR_DISTANCE = 0.01 center = torch.zeros(1, 3) half_size = torch.ones(1, 3) _, hits_t, _ = self.ray_aabb_intersector.apply(rays_o, rays_d, center, half_size, 1) hits_t[(hits_t[:, 0, 0] >= 0) & (hits_t[:, 0, 0] < NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE # TODO: should sample different light_d for each batch... but taichi end doesn't have a flatten_ray implemented currently... 
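        # Added note: hits_t[:, 0] from RayAABBIntersector holds each ray's (entry, exit)
        # distances against the [-1, 1]^3 box (center 0, half_size 1); entry distances in
        # [0, NEAR_DISTANCE) were clamped above so marching never starts closer than the
        # near plane.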
# random sample light_d if not provided if light_d is None: # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float)) light_d = safe_normalize(light_d) results = {} if self.training: rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES) dirs = safe_normalize(dirs) # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4)) # normals related regularizations if self.opt.lambda_orient > 0 and normals is not None: # orientation loss loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 results['loss_orient'] = loss_orient.mean() if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() if normals is not None: _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4)) results['normal_image'] = normal_image # weights normalization results['weights'] = weights else: # allocate outputs dtype = torch.float32 weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) n_alive = N rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] rays_t = hits_t[:, 0, 0] step = 0 min_samples = 1 if exp_step_factor == 0 else 4 while step < self.opt.max_steps: # hard coded max step # count alive rays n_alive = rays_alive.shape[0] # exit loop if n_alive <= 0: break # decide compact_steps # n_step = max(min(N // n_alive, 8), 1) n_step = max(min(N // n_alive, 64), min_samples) xyzs, dirs, deltas, ts, N_eff_samples = \ self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive, self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES, n_step) xyzs = self.rearrange(xyzs, 'n1 n2 c -> (n1 n2) c') dirs = self.rearrange(dirs, 'n1 n2 c -> (n1 n2) c') dirs = safe_normalize(dirs) valid_mask = ~torch.all(dirs == 0, dim=1) if valid_mask.sum() == 0: break sigmas = torch.zeros(len(xyzs), device=device) rgbs = torch.zeros(len(xyzs), 3, device=device) normals = torch.zeros(len(xyzs), 3, device=device) sigmas[valid_mask], _rgbs, normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading) rgbs[valid_mask] = _rgbs.float() sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step) rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step) if normals is not None: normals = self.rearrange(normals, '(n1 n2) c -> n1 n2 c', n2=n_step) self.composite_test_fw(sigmas, rgbs, deltas, ts, hits_t[:,0], rays_alive, kwargs.get('T_threshold', 1e-4), N_eff_samples, weights_sum, depth, image) rays_alive = rays_alive[rays_alive >= 0] step += n_step # mix background color if bg_color is None: if self.opt.bg_radius > 0: # use the bg model to calculate bg_color bg_color = self.background(rays_d) # [N, 3] else: bg_color = 1 image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color image = image.view(*prefix, 3) depth = depth.view(*prefix) 
weights_sum = weights_sum.reshape(*prefix) results['image'] = image results['depth'] = depth results['weights_sum'] = weights_sum return results @torch.no_grad() def update_extra_state(self, decay=0.95, S=128): # call before each epoch to update extra states. if not (self.cuda_ray or self.taichi_ray): return ### update density grid tmp_grid = - torch.ones_like(self.density_grid) X = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) for xs in X: for ys in Y: for zs in Z: # construct points xx, yy, zz = custom_meshgrid(xs, ys, zs) coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] # cascading for cas in range(self.cascade): bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_xyzs = xyzs * (bound - half_grid_size) # add noise in [-hgs, hgs] cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size # query density sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach() # assign tmp_grid[cas, indices] = sigmas # ema update valid_mask = self.density_grid >= 0 self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) self.mean_density = torch.mean(self.density_grid[valid_mask]).item() self.iter_density += 1 # convert to bitfield density_thresh = min(self.mean_density, self.density_thresh) if self.cuda_ray: self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) elif self.taichi_ray: self.packbits_taichi(self.density_grid.reshape(-1).contiguous(), density_thresh, self.density_bitfield) # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f}') def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, **kwargs): # rays_o, rays_d: [B, N, 3] # return: pred_rgb: [B, N, 3] B, N = rays_o.shape[:2] device = rays_o.device if self.dmtet: results = self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs) elif self.cuda_ray: results = self.run_cuda(rays_o, rays_d, **kwargs) elif self.taichi_ray: results = self.run_taichi(rays_o, rays_d, **kwargs) else: if staged: depth = torch.empty((B, N), device=device) image = torch.empty((B, N, 3), device=device) weights_sum = torch.empty((B, N), device=device) for b in range(B): head = 0 while head < N: tail = min(head + max_ray_batch, N) results_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs) depth[b:b+1, head:tail] = results_['depth'] weights_sum[b:b+1, head:tail] = results_['weights_sum'] image[b:b+1, head:tail] = results_['image'] head += max_ray_batch results = {} results['depth'] = depth results['image'] = image results['weights_sum'] = weights_sum else: results = self.run(rays_o, rays_d, **kwargs) return results def init_tet_from_nerf(self, reset_scale=True): sdf = self.get_sdf_from_nerf(reset_scale=reset_scale) self.dmtet.init_tet_from_sdf(sdf) logger.info(f'init dmtet from NeRF Done ...') @torch.no_grad() def get_sdf_from_nerf(self, reset_scale=True): if self.cuda_ray: density_thresh = 
min(self.mean_density, self.density_thresh)
        else:
            density_thresh = self.density_thresh

        if reset_scale:
            # init scale: keep only vertices whose density exceeds the threshold
            sigma = self.density(self.dmtet.verts)['sigma']  # verts covers [-1, 1] now
            mask = sigma > density_thresh
            valid_verts = self.dmtet.verts[mask]
            tet_scale = valid_verts.abs().amax(dim=0) + 1e-1
            self.dmtet.reset_tet_scale(tet_scale)

        sdf = (self.density(self.dmtet.verts)['sigma'] - density_thresh).clamp(-1, 1)
        return sdf
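
# Usage sketch (illustrative only; `NeRFNetwork` and the `opt` fields are assumptions,
# the concrete subclass implementing forward()/density() lives elsewhere in the repo):
#
#   model = NeRFNetwork(opt).to(opt.device)      # subclass of NeRFRenderer
#   if model.cuda_ray:
#       model.update_extra_state()               # refresh density grid / bitfield
#   out = model.render(rays_o, rays_d, mvp, H, W, staged=not model.training,
#                      perturb=True, ambient_ratio=0.1, shading='lambertian')
#   pred_rgb, pred_depth = out['image'], out['depth']   # [B, N, 3], [B, N]
#
# After volume-rendering training, init_tet_from_nerf() converts the learned density field
# into a DMTet SDF so finetuning can continue on the extracted mesh via run_dmtet().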