first commit

This commit is contained in:
Guocheng Qian
2023-08-02 19:51:43 -07:00
parent c2891c38cc
commit 13e18567fa
202 changed files with 43362 additions and 17 deletions

127
nerf/clip.py Normal file
View File

@@ -0,0 +1,127 @@
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
# import clip
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
from torchvision import transforms
import torch.nn.functional as F
def spherical_dist_loss(x, y):
    """Squared geodesic (great-circle) distance between x and y on the unit sphere.

    Both inputs are L2-normalized along the last dim first, so the result per
    pair is 2 * arcsin(||x - y|| / 2) ** 2.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * half_chord.arcsin().pow(2)
class CLIP(nn.Module):
    """Wrapper around a HuggingFace CLIP model.

    Provides normalized text/image embeddings plus CLIP-guidance losses
    (negative cosine similarity and spherical distance) for rendered images.
    """

    def __init__(self, device,
                 # clip_name = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
                 clip_name='openai/clip-vit-large-patch14'
                 ):
        super().__init__()
        self.device = device
        # fix: original `clip_name = clip_name` was a no-op; store it, and use it
        # consistently for model, tokenizer and processor (default unchanged).
        self.clip_name = clip_name

        # fix: move to the requested device instead of hard-coded .cuda()
        self.clip_model = CLIPModel.from_pretrained(clip_name).to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)

        # fix: these preprocessing transforms were commented out but are used by
        # get_img_embeds / train_step / text_loss / img_loss below (AttributeError).
        # Constants are the standard CLIP image mean/std.
        self.resize = transforms.Resize(224)
        self.normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711))
        # image augmentation: resize + CLIP normalization in one pass
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ])

    def get_text_embeds(self, prompt, neg_prompt=None, dir=None):
        """Tokenize `prompt` and return L2-normalized CLIP text features.

        neg_prompt / dir are accepted for interface compatibility but unused.
        """
        clip_text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)
        text_z = self.clip_model.get_text_features(clip_text_input)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        return text_z

    def set_epoch(self, epoch):
        # no epoch-dependent state; kept for trainer-interface compatibility
        pass

    def get_img_embeds(self, img):
        """Return L2-normalized CLIP image features for a [0, 1] RGB batch."""
        img = self.aug(img)
        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        return image_z

    def train_step(self, text_z, pred_rgb):
        """Negative cosine similarity between the rendered image and text embedding."""
        pred_rgb = self.aug(pred_rgb)
        # fix: HF CLIPModel exposes get_image_features; encode_image is the
        # openai/clip package API and does not exist here.
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        loss = - (image_z * text_z).sum(-1).mean()
        # loss = spherical_dist_loss(image_z, text_z)
        return loss

    def text_loss(self, text_z, pred_rgb):
        """Spherical distance between image embedding and text embedding."""
        pred_rgb = self.normalize(self.resize(pred_rgb))
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        loss = spherical_dist_loss(image_z, text_z)
        # loss = - (image_z * text_z).sum(-1).mean()
        return loss

    def img_loss(self, img_ref_z, pred_rgb):
        """Spherical distance between image embedding and a reference image embedding."""
        pred_rgb = self.normalize(self.resize(pred_rgb))
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        # loss = - (image_z * img_ref_z).sum(-1).mean()
        loss = spherical_dist_loss(image_z, img_ref_z)
        return loss

485
nerf/gui.py Normal file
View File

@@ -0,0 +1,485 @@
import math
import torch
import numpy as np
import dearpygui.dearpygui as dpg
from scipy.spatial.transform import Rotation as R
from nerf.utils import *
class OrbitCamera:
    """Orbit camera: a rotation about a look-at center at a fixed radius."""

    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from the look-at center
        self.fovy = fovy  # vertical field of view, in degrees
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look-at point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # must stay normalized!
        self.near = 0.001
        self.far = 1000

    @property
    def pose(self):
        """Camera-to-world matrix [4, 4]: push back by radius, rotate, recenter."""
        offset = np.eye(4, dtype=np.float32)
        offset[2, 3] = self.radius
        transform = np.eye(4, dtype=np.float32)
        transform[:3, :3] = self.rot.as_matrix()
        transform = transform @ offset
        transform[:3, 3] -= self.center
        return transform

    @property
    def intrinsics(self):
        """[fx, fy, cx, cy] with square pixels and a centered principal point."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

    @property
    def mvp(self):
        """Model-view-projection matrix [4, 4] (OpenGL-style perspective)."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        n, f = self.near, self.far
        projection = np.array([
            [2 * focal / self.W, 0, 0, 0],
            [0, -2 * focal / self.H, 0, 0],
            [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
            [0, 0, -1, 0],
        ], dtype=np.float32)
        return projection @ np.linalg.inv(self.pose)  # [4, 4]

    def orbit(self, dx, dy):
        """Rotate about the camera's up axis (dx) and side axis (dy)."""
        side = self.rot.as_matrix()[:3, 0]  # first column is the side axis (unit length)
        yaw = R.from_rotvec(self.up * np.deg2rad(-0.1 * dx))
        pitch = R.from_rotvec(side * np.deg2rad(-0.1 * dy))
        self.rot = yaw * pitch * self.rot

    def scale(self, delta):
        """Zoom by scaling the orbit radius exponentially in the scroll delta."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Translate the look-at center in the camera frame (mind the sensitivity)."""
        self.center += 0.0005 * self.rot.as_matrix() @ np.array([dx, -dy, dz])
class NeRFGUI:
    """Interactive dearpygui viewer/trainer for a NeRF Trainer.

    Wraps an OrbitCamera plus a trainer (and optional data loader); renders
    into a raw float-RGB texture every frame, and exposes controls for live
    training, checkpointing, mesh extraction and rendering options.
    """

    def __init__(self, opt, trainer, loader=None, debug=True):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32)  # default white bg
        self.training = False
        self.step = 0  # training step
        self.trainer = trainer
        self.loader = loader
        # framebuffer shared with the dpg raw texture (float RGB in [0, 1])
        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # camera moved, should reset accumulation
        self.spp = 1  # sample per pixel
        # [theta, phi] in degrees, forwarded to the renderer for shading
        self.light_dir = np.array([opt.light_theta, opt.light_phi])
        self.ambient_ratio = 1.0
        self.mode = 'image'  # choose from ['image', 'depth']
        self.shading = 'albedo'
        # dmtet renders at fixed resolution; otherwise adapt resolution to frame time
        self.dynamic_resolution = True if not self.opt.dmtet else False
        self.downscale = 1
        self.train_steps = 16

        dpg.create_context()
        self.register_dpg()
        self.test_step()

    def __del__(self):
        dpg.destroy_context()

    def train_step(self):
        """Run self.train_steps training iterations and refresh the log widgets."""
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        # only adopt the new step count when it differs by more than ~20%
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps

    def prepare_buffer(self, outputs):
        """Convert trainer outputs into a float32 RGB image for the texture."""
        if self.mode == 'image':
            return outputs['image'].astype(np.float32)
        else:
            depth = outputs['depth'].astype(np.float32)
            # normalize depth to [0, 1] for display
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
            return np.expand_dims(depth, -1).repeat(3, -1)

    def test_step(self):
        """Render one frame when the view changed, or accumulate up to max_spp."""
        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                # camera moved: restart accumulation from this frame
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                # running average over accumulated samples
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)

    def register_dpg(self):
        """Build the dearpygui UI: texture, windows, widgets and input handlers."""

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):
            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            if self.opt.negative != '':
                dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):

                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            # toggle training on/off and relabel the button
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            # re-initialize all model weights in place
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state()  # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)

                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh()
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32)  # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light dir
                def callback_set_light_dir(sender, app_data, user_data):
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
                    self.need_update = True

                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)

            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    # pose
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")

        ### register camera handler

        def callback_camera_drag_rotate(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.orbit(dx, dy)
            self.need_update = True
            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            delta = app_data
            self.cam.scale(delta)
            self.need_update = True
            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return
            dx = app_data[1]
            dy = app_data[2]
            self.cam.pan(dx, dy)
            self.need_update = True
            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)

        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)

        # TODO: seems dearpygui doesn't support resizing texture...
        # def callback_resize(sender, app_data):
        #     self.W = app_data[0]
        #     self.H = app_data[1]
        #     # how to reload texture ???

        # dpg.set_viewport_resize_callback(callback_resize)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        #dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        """Main loop: train (when enabled), render, and present each frame."""
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()

238
nerf/network.py Normal file
View File

@@ -0,0 +1,238 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
from tqdm import tqdm
class ResBlock(nn.Module):
    """Linear -> LayerNorm -> (+residual) -> SiLU block.

    The residual is linearly projected (bias-free) when dim_in != dim_out.
    """

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.norm = nn.LayerNorm(self.dim_out)
        self.activation = nn.SiLU(inplace=True)
        # project the shortcut only when dimensions differ
        self.skip = (nn.Linear(self.dim_in, self.dim_out, bias=False)
                     if self.dim_in != self.dim_out else None)

    def forward(self, x):
        # x: [B, C]
        shortcut = x if self.skip is None else self.skip(x)
        out = self.norm(self.dense(x))
        out += shortcut
        return self.activation(out)
class BasicBlock(nn.Module):
    """Minimal Linear -> ReLU block."""

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: [B, C]
        return self.activation(self.dense(x))
class MLP(nn.Module):
    """MLP built from blocks: BasicBlock input layer, `block` hidden layers,
    and a plain Linear output layer (no final activation).

    Args:
        dim_in: input feature dimension.
        dim_out: output feature dimension.
        dim_hidden: hidden layer width.
        num_layers: total layer count, including input and output layers.
        bias: whether linear layers carry a bias term.
        block: module class for the intermediate (hidden) layers.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers
        net = []
        for l in range(num_layers):
            if l == 0:
                # the first layer is always a BasicBlock, regardless of `block`
                net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias))
            elif l != num_layers - 1:
                net.append(block(self.dim_hidden, self.dim_hidden, bias=bias))
            else:
                # final layer: plain Linear, no activation
                net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias))
        self.net = nn.ModuleList(net)

    def reset_parameters(self):
        """Re-initialize all Linear weights (Xavier uniform) and zero the biases."""
        @torch.no_grad()
        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                # fix: with bias=False, m.bias is None and zeros_(None) raises
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(weight_init)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
        return x
class NeRFNetwork(NeRFRenderer):
    """Pure-PyTorch NeRF: frequency positional encoding + residual MLP.

    Predicts density (sigma) and albedo; `forward` optionally shades with a
    lambertian / textureless / normal model using the analytic density gradient.
    """

    def __init__(self,
                 opt,
                 num_layers=5,  # 5 in paper
                 hidden_dim=64,  # 128 in paper
                 num_layers_bg=2,  # 3 in paper
                 hidden_dim_bg=32,  # 64 in paper
                 encoding='frequency_torch',  # pure pytorch
                 output_dim=4,  # 7 for DMTet (sdf 1 + color 3 + deform 3), 4 for NeRF
                 ):
        super().__init__(opt)
        # opt values override the constructor defaults when present
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg
        self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=6)
        # NOTE(review): sigma_net is built from the local defaults (hidden_dim,
        # num_layers), not the opt-derived self.* values read above — confirm intended.
        self.sigma_net = MLP(self.in_dim, output_dim, hidden_dim, num_layers, bias=True, block=ResBlock)
        # number of (top) encoder levels masked out for coarse-to-fine training
        self.grid_levels_mask = 0
        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode positions and predict (sigma, albedo)."""
        # x: [N, 3], in [-bound, bound]
        # sigma
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Feature masking for coarse-to-fine training
        if self.grid_levels_mask > 0:
            # NOTE(review): self.level_dim is not assigned in this class —
            # presumably provided by NeRFRenderer; confirm before enabling masking.
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # density gets an additive blob prior before the activation
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def normal(self, x):
        """Surface normal as the normalized negative density gradient at x."""
        with torch.enable_grad():
            x.requires_grad_(True)
            sigma, albedo = self.common_forward(x)
            # query gradient
            normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]

        # normal = self.finite_difference_normal(x)
        normal = safe_normalize(normal)
        # normal = torch.nan_to_num(normal)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, nomalized in [-1, 1]
        # l: [3], plane light direction, nomalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None
        else:
            # query normal
            # sigma, albedo = self.common_forward(x)
            # normal = self.normal(x)
            with torch.enable_grad():
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]
            normal = safe_normalize(normal)
            # normal = torch.nan_to_num(normal)
            # normal = normal.detach()

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else:  # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Density/albedo query used by the renderer (no shading)."""
        # x: [N, 3], in [-bound, bound]
        sigma, albedo = self.common_forward(x)
        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for ray directions d via the bg encoder + MLP."""
        h = self.encoder_bg(d)  # [N, C]
        h = self.bg_net(h)
        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)
        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups for the optimizer (sigma net scaled by lr_scale_nerf)."""
        params = [
            # {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
        ]
        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})
        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})
        return params

216
nerf/network_grid.py Normal file
View File

@@ -0,0 +1,216 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp, biased_softplus
from .renderer import NeRFRenderer, MLP
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)
class NeRFNetwork(NeRFRenderer):
    """Grid-encoded NeRF (hashgrid / tiledgrid + MLP) with an optional
    background net, plus utilities to initialize the NeRF or a DMTet from a
    given SDF / color field.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 level_dim=2
                 ):
        super().__init__(opt)
        # opt values override the constructor defaults when present
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        self.level_dim = opt.level_dim if hasattr(opt, 'level_dim') else level_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg
        # NOTE(review): no fallback branch — an opt.grid_type outside
        # {'hashgrid', 'tilegrid'} leaves self.encoder undefined.
        if self.opt.grid_type == 'hashgrid':
            self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
        elif self.opt.grid_type == 'tilegrid':
            self.encoder, self.in_dim = get_encoder(
                'tiledgrid',
                input_dim=3,
                level_dim=self.level_dim,
                log2_hashmap_size=16,
                num_levels=16,
                desired_resolution= 2048 * self.bound,
            )
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)
        # masking: number of (top) grid levels masked out for coarse-to-fine training
        self.grid_levels_mask = 0

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode positions x ([N, 3] in [-bound, bound]) and predict (sigma, albedo)."""
        # sigma
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Feature masking for coarse-to-fine training
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # density gets an additive blob prior before the activation
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, nomalized in [-1, 1]
        # l: [3], plane light direction, nomalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            normal = None
            color = albedo
        else:  # lambertian shading
            # self.normal is presumably provided by NeRFRenderer — not defined here
            normal = self.normal(x)
            if shading == 'normal':
                color = (normal + 1) / 2
            else:
                lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]
                if shading == 'textureless':
                    color = lambertian.unsqueeze(-1).repeat(1, 3)
                else:  # 'lambertian'
                    color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Density/albedo query used by the renderer (no shading)."""
        # x: [N, 3], in [-bound, bound]
        sigma, albedo = self.common_forward(x)
        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for ray directions d via the frequency-encoded bg net."""
        h = self.encoder_bg(d)  # [N, C]
        h = self.bg_net(h)
        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)
        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups for the optimizer; the grid encoder gets a 10x rate."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
            # {'params': self.normal_net.parameters(), 'lr': lr},
        ]
        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})
        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})
        return params

    def reset_sigmanet(self):
        # re-initialize sigma_net weights (used before the matching optimizations below)
        self.sigma_net.reset_parameters()

    def init_nerf_from_sdf_color(self, rpst, albedo,
                                 points=None, pretrain_iters=10000, lr=0.001, rpst_type='sdf',
                                 ):
        """Fit the (reset) sigma net so its density/albedo match a target field.

        rpst: target values at `points` (pseudo-SDF or density, per rpst_type).
        albedo: optional target colors at `points`; when given, adds an RGB MSE term.
        """
        self.reset_sigmanet()
        # matching optimization
        self.train()
        self.grid_levels_mask = 0
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
        # step the LR down at 40% and 80% of the schedule
        milestones = [int(0.4 * pretrain_iters), int(0.8 * pretrain_iters)]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
        rpst = rpst.squeeze().clamp(0, 1)
        # rpst = torch.ones_like(rpst) * 0.4
        pbar = tqdm(range(pretrain_iters), desc="NeRF sigma optimization")
        rgb_loss = torch.tensor(0, device=rpst.device)
        for i in pbar:
            output = self.density(points)
            if rpst_type == 'sdf':
                # interpret density relative to the surface threshold as a pseudo-SDF
                pred_rpst = output['sigma'] - self.density_thresh
            else:
                pred_rpst = output['sigma']
            sdf_loss = loss_fn(pred_rpst, rpst)
            if albedo is not None:
                pred_albedo = output['albedo']
                rgb_loss = loss_fn(pred_albedo, albedo)
                # sdf term dominates; rgb is a secondary objective
                loss = 10 * sdf_loss + rgb_loss
            else:
                loss = sdf_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.set_postfix(loss=loss.item(), rgb_loss=rgb_loss.item(), sdf_loss=sdf_loss.item())
        # NOTE(review): the accuracy expression below is literal text inside the
        # f-string (only {lr} is interpolated) — likely meant to be evaluated.
        logger.info(f'lr: {lr} Accuracy: (pred_rpst[rpst>0]>0).sum() / (rpst>0).sum()')
        # NOTE(review): this progress bar is created but never iterated — dead code?
        pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")

    def init_tet_from_sdf_color(self, sdf, colors=None, pretrain_iters=5000, lr=0.01):
        """Initialize the DMTet from an SDF, then optionally fit vertex colors."""
        self.train()
        self.grid_levels_mask = 0
        self.dmtet.reset_tet(reset_scale=False)
        self.dmtet.init_tet_from_sdf(sdf, pretrain_iters=pretrain_iters, lr=lr)
        if colors is not None:
            self.reset_sigmanet()
            loss_fn = torch.nn.MSELoss()
            pretrain_iters = 5000
            # NOTE(review): the lr parameter is ignored here; optimizer uses 0.01.
            optimizer = torch.optim.Adam(list(self.parameters()), lr=0.01)
            pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")
            for i in pbar:
                pred_albedo = self.density(self.dmtet.verts)['albedo']
                loss = loss_fn(pred_albedo, colors)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=loss.item())

161
nerf/network_grid_taichi.py Normal file
View File

@@ -0,0 +1,161 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
from tqdm import tqdm
class MLP(nn.Module):
    """Plain fully-connected MLP with ReLU between layers (no output activation).

    Args:
        dim_in: input feature dimension.
        dim_out: output feature dimension.
        dim_hidden: hidden layer width.
        num_layers: total number of Linear layers.
        bias: whether the Linear layers carry a bias term.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers
        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden,
                                 self.dim_out if l == num_layers - 1 else self.dim_hidden,
                                 bias=bias))
        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x

    def reset_parameters(self):
        """Re-initialize Linear weights (Xavier uniform); zero biases when present."""
        @torch.no_grad()
        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                # fix: with bias=False, m.bias is None and zeros_(None) raises
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(weight_init)
class NeRFNetwork(NeRFRenderer):
    """NeRF backbone using a Taichi hash-grid encoder.

    A small MLP on top of the hash encoding predicts density and albedo; an
    optional tiny frequency-encoded MLP predicts a background color from the
    view direction.
    """

    def __init__(self,
                 opt,
                 num_layers=2,
                 hidden_dim=32,
                 num_layers_bg=2,
                 hidden_dim_bg=16,
                 ):

        super().__init__(opt)

        # fields on `opt`, when present, override the constructor defaults
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg

        self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')

        # 4 outputs: 1 raw density channel + 3 raw albedo channels
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)

        # number of finest hash-grid levels currently masked out
        # (coarse-to-fine training); 0 disables masking
        self.grid_levels_mask = 0

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4) # TODO: freq encoder can be replaced by a Taichi implementation
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode points and return (sigma, albedo).

        x: [N, 3] points in [-bound, bound].
        """
        # sigma
        h = self.encoder(x, bound=self.bound)

        # Feature masking for coarse-to-fine training: zero out the features
        # belonging to the `grid_levels_mask` finest levels.
        # NOTE(review): assumes self.level_dim (features per level) is set by
        # the parent/encoder — confirm.
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float() # (1, self.in_dim)
            h = h * h_mask # (N, self.in_dim)

        h = self.sigma_net(h)

        # density_activation and density_blob are provided by NeRFRenderer
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            normal = None
            color = albedo
        else: # lambertian shading
            # self.normal is presumably implemented by NeRFRenderer — confirm
            normal = self.normal(x)
            # ambient + diffuse term, clamped so back-facing light adds nothing
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                # visualize normals mapped from [-1, 1] to [0, 1]
                color = (normal + 1) / 2
            else: # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Return {'sigma', 'albedo'} for x: [N, 3] in [-bound, bound]."""
        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d: [N, 3]; returns [N, 3] RGB."""
        h = self.encoder_bg(d) # [N, C]
        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)
        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups; the hash encoder trains at 10x the base lr."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
            # {'params': self.normal_net.parameters(), 'lr': lr},
        ]

        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet:
            # self.dmtet is presumably created by the renderer in dmtet mode — confirm
            params.append({'params': self.dmtet.parameters(), 'lr': lr})

        return params

178
nerf/network_grid_tcnn.py Normal file
View File

@@ -0,0 +1,178 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from activation import trunc_exp, biased_softplus
from .renderer import NeRFRenderer
import numpy as np
from encoding import get_encoder
from .utils import safe_normalize
import tinycudann as tcnn
class MLP(nn.Module):
    """Plain fully-connected network: Linear layers with ReLU in between.

    The final layer has no activation.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        layers = []
        for idx in range(num_layers):
            in_features = dim_in if idx == 0 else dim_hidden
            out_features = dim_out if idx == num_layers - 1 else dim_hidden
            layers.append(nn.Linear(in_features, out_features, bias=bias))
        self.net = nn.ModuleList(layers)

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx != last:
                x = F.relu(x, inplace=True)
        return x
class NeRFNetwork(NeRFRenderer):
    """NeRF backbone using a tiny-cuda-nn (tcnn) hash-grid encoder.

    A torch MLP on top of the hash encoding predicts density and albedo
    (torch MLP because tcnn's fused MLP lacks second-order derivatives,
    which the analytic normals below need); an optional frequency-encoded
    MLP predicts a background color from the view direction.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 ):

        super().__init__(opt)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "interpolation": "Smoothstep",
                # geometric level-scale so the finest level reaches 2048 * bound
                "per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)),
            },
            dtype=torch.float32, # ENHANCE: default float16 seems unstable...
        )
        self.in_dim = self.encoder.n_output_dims

        # use torch MLP, as tcnn MLP doesn't impl second-order derivative
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)

        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode points and return (sigma, albedo).

        x: [N, 3] in [-bound, bound]; the tcnn encoder expects [0, 1]^3,
        hence the affine rescale below.
        """
        # sigma
        enc = self.encoder((x + self.bound) / (2 * self.bound)).float()
        h = self.sigma_net(enc)

        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def normal(self, x):
        """Analytic surface normal as the negative density gradient at x.

        NOTE(review): x.requires_grad_(True) mutates the caller's tensor —
        confirm callers do not rely on it staying requires_grad=False.
        """
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]

        # normal = self.finite_difference_normal(x)
        normal = safe_normalize(normal)
        normal = torch.nan_to_num(normal)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        if shading == 'albedo':
            sigma, albedo = self.common_forward(x)
            normal = None
            color = albedo

        else: # lambertian shading
            # density gradient is recomputed inline here (not via self.normal)
            # so sigma from the same forward pass can be reused
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    x.requires_grad_(True)
                    sigma, albedo = self.common_forward(x)
                    normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
            normal = safe_normalize(normal)
            normal = torch.nan_to_num(normal)

            # ambient + diffuse term, clamped so back-facing light adds nothing
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else: # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Return {'sigma', 'albedo'} for x: [N, 3] in [-bound, bound]."""
        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d: [N, 3]; returns [N, 3] RGB."""
        h = self.encoder_bg(d) # [N, C]
        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)
        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups; the hash encoder trains at 10x the base lr."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.opt.bg_radius > 0:
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet:
            # self.sdf / self.deform are presumably created by the renderer
            # in dmtet mode — confirm
            params.append({'params': self.sdf, 'lr': lr})
            params.append({'params': self.deform, 'lr': lr})

        return params

329
nerf/provider.py Normal file
View File

@@ -0,0 +1,329 @@
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .utils import get_rays, safe_normalize
# RGBA colors (uint8) used by visualize_poses to tint camera frusta by view
# bin; the row index matches the ids produced by get_view_direction.
DIR_COLORS = np.array([
    [255, 0, 0, 255], # front
    [0, 255, 0, 255], # side
    [0, 0, 255, 255], # back
    [255, 255, 0, 255], # side
    [255, 0, 255, 255], # overhead
    [0, 255, 255, 255], # bottom
], dtype=np.uint8)
def visualize_poses(poses, dirs, size=0.1):
    """Open an interactive trimesh window showing camera poses.

    Each camera is drawn as 8 line segments (a wire frustum) colored by its
    view-direction bin via DIR_COLORS.

    poses: [B, 4, 4] camera-to-world matrices; dirs: [B] view-bin ids.
    """
    import trimesh

    scene_objects = [
        trimesh.creation.axis(axis_length=4),
        trimesh.creation.icosphere(radius=1),
    ]

    for pose, view_bin in zip(poses, dirs):
        origin = pose[:3, 3]
        right = size * pose[:3, 0]
        up = size * pose[:3, 1]
        back = size * pose[:3, 2]

        # four frustum corners, one `size` step in front of the camera
        corner_a = origin + right + up - back
        corner_b = origin - right + up - back
        corner_c = origin - right - up - back
        corner_d = origin + right - up - back

        segments = np.array([
            [origin, corner_a], [origin, corner_b], [origin, corner_c], [origin, corner_d],
            [corner_a, corner_b], [corner_b, corner_c], [corner_c, corner_d], [corner_d, corner_a],
        ])
        path = trimesh.load_path(segments)
        # tint every segment of this camera with its bin color
        path.colors = DIR_COLORS[[view_bin]].repeat(len(path.entities), 0)
        scene_objects.append(path)

    trimesh.Scene(scene_objects).show()
def get_view_direction(thetas, phis, overhead, front):
    """Bucket each camera pose into one of six view ids.

    ids: 0 = front, 1 = right side, 2 = back, 3 = left side, 4 = top,
    5 = bottom. Azimuths (phis) are taken modulo 2*pi; the front sector is
    centered on phi = 0 with angular width `front`, the back sector on
    phi = pi. Poses within `overhead` of either pole are reclassified as
    top/bottom regardless of azimuth.

    thetas, phis: [B] tensors in radians; overhead, front: scalars in radians.
    Returns a [B] long tensor of view ids.
    """
    res = torch.zeros(thetas.shape[0], dtype=torch.long)

    # classify by azimuth first
    phis = phis % (2 * np.pi)
    half = front / 2
    front_mask = (phis < half) | (phis >= 2 * np.pi - half)
    right_mask = (phis >= half) & (phis < np.pi - half)
    back_mask = (phis >= np.pi - half) & (phis < np.pi + half)
    left_mask = (phis >= np.pi + half) & (phis < 2 * np.pi - half)
    res[front_mask] = 0
    res[right_mask] = 1
    res[back_mask] = 2
    res[left_mask] = 3

    # polar angle takes precedence over azimuth
    res[thetas <= overhead] = 4
    res[thetas >= (np.pi - overhead)] = 5
    return res
def rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5):
    ''' generate random poses from an orbit camera

    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        opt: options object; reads jitter_pose / jitter_center / jitter_target / jitter_up.
        radius_range: [min, max] camera distance from the origin.
        theta_range: [min, max] polar angle in degrees, mapped into [0, pi].
        phi_range: [min, max] azimuth in degrees, mapped into [0, 2 * pi].
        return_dirs: if True, also classify each pose via get_view_direction.
        angle_overhead, angle_front: bin widths in degrees for the classifier.
        uniform_sphere_rate: probability of sampling camera directions
            uniformly on the upper hemisphere instead of from the ranges.
    Return:
        poses: [size, 4, 4] camera-to-world matrices,
        dirs: [size] view-bin ids or None,
        thetas, phis: sampled angles converted back to degrees,
        radius: [size] sampled radii.
    '''
    # degrees -> radians
    theta_range = np.array(theta_range) / 180 * np.pi
    phi_range = np.array(phi_range) / 180 * np.pi
    angle_overhead = angle_overhead / 180 * np.pi
    angle_front = angle_front / 180 * np.pi

    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]

    if random.random() < uniform_sphere_rate:
        # sample a random direction on the upper hemisphere (y >= 0),
        # then recover spherical angles from it
        unit_centers = F.normalize(
            torch.stack([
                (torch.rand(size, device=device) - 0.5) * 2.0,
                torch.rand(size, device=device),
                (torch.rand(size, device=device) - 0.5) * 2.0,
            ], dim=-1), p=2, dim=1
        )
        thetas = torch.acos(unit_centers[:,1])
        phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
        phis[phis < 0] += 2 * np.pi
        centers = unit_centers * radius.unsqueeze(-1)
    else:
        # sample angles uniformly within the requested ranges
        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
        phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
        phis[phis < 0] += 2 * np.pi

        # spherical -> cartesian (y is up)
        centers = torch.stack([
            radius * torch.sin(thetas) * torch.sin(phis),
            radius * torch.cos(thetas),
            radius * torch.sin(thetas) * torch.cos(phis),
        ], dim=-1) # [B, 3]

    targets = 0

    # jitters
    if opt.jitter_pose:
        jit_center = opt.jitter_center # 0.015 # was 0.2
        jit_target = opt.jitter_target
        centers += torch.rand_like(centers) * jit_center - jit_center/2.0
        targets += torch.randn_like(centers) * jit_target

    # lookat: build an orthonormal camera frame aimed at `targets`
    forward_vector = safe_normalize(centers - targets)
    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))

    if opt.jitter_pose:
        up_noise = torch.randn_like(up_vector) * opt.jitter_up
    else:
        up_noise = 0

    # re-orthogonalize up (with optional noise) against the other two axes
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    # back to degree
    thetas = thetas / np.pi * 180
    phis = phis / np.pi * 180

    return poses, dirs, thetas, phis, radius
def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), return_dirs=False, angle_overhead=30, angle_front=60):
    """Build deterministic look-at camera poses on an orbit around the origin.

    radius, theta, phi: [B] tensors (angles in degrees).
    Returns poses [B, 4, 4] and, when return_dirs, the view-bin ids.
    """
    # degrees -> radians
    theta = theta / 180 * np.pi
    phi = phi / 180 * np.pi
    angle_overhead = angle_overhead / 180 * np.pi
    angle_front = angle_front / 180 * np.pi

    # orbit position from spherical coordinates (y is up)
    centers = torch.stack([
        radius * torch.sin(theta) * torch.sin(phi),
        radius * torch.cos(theta),
        radius * torch.sin(theta) * torch.cos(phi),
    ], dim=-1) # [B, 3]

    batch = len(centers)

    # orthonormal look-at frame aimed at the origin
    forward_vector = safe_normalize(centers)
    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(batch, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(batch, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    dirs = get_view_direction(theta, phi, angle_overhead, angle_front) if return_dirs else None

    return poses, dirs
class NeRFDataset:
    """Synthetic camera-pose dataset for text-to-3D training/evaluation.

    During training it samples random orbit poses on the fly; otherwise it
    yields fixed circular views (or six canonical views for type
    'six_views'). Batches carry rays, MVP matrices, and pose deltas
    relative to the default view.
    """

    def __init__(self, opt, device, type='train', H=256, W=256, size=100):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type # train, val, test

        self.H = H
        self.W = W
        self.size = size

        self.training = self.type in ['train', 'all']

        # principal point at the image center
        # NOTE(review): cx derives from H and cy from W — looks swapped, but
        # harmless while H == W; confirm before using non-square images.
        self.cx = self.H / 2
        self.cy = self.W / 2

        self.near = self.opt.min_near
        self.far = 1000 # infinite

        # [debug] visualize poses
        # poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1)
        # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())

    def get_default_view_data(self):
        """Build ray/pose data for the fixed reference views configured in opt."""

        H = int(self.opt.known_view_scale * self.H)
        W = int(self.opt.known_view_scale * self.W)
        cx = H / 2
        cy = W / 2

        radii = torch.FloatTensor(self.opt.ref_radii).to(self.device)
        thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device)
        phis = torch.FloatTensor(self.opt.ref_azimuths).to(self.device)
        poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

        fov = self.opt.default_fovy
        focal = H / (2 * np.tan(np.deg2rad(fov) / 2))
        intrinsics = np.array([focal, focal, cx, cy])

        # OpenGL-style perspective projection matrix
        projection = torch.tensor([
            [2*focal/W, 0, 0, 0],
            [0, -2*focal/H, 0, 0],
            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
            [0, 0, -1, 0]
        ], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1)

        mvp = projection @ torch.inverse(poses) # [B, 4, 4]

        # sample a low-resolution but full image
        rays = get_rays(poses, intrinsics, H, W, -1)

        data = {
            'H': H,
            'W': W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
            'mvp': mvp,
            'polar': self.opt.ref_polars,
            'azimuth': self.opt.ref_azimuths,
            'radius': self.opt.ref_radii,
        }

        return data

    def collate(self, index):
        """DataLoader collate_fn: turn a list of indices into one ray batch."""

        B = len(index)

        if self.training:
            # random pose on the fly
            poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate)

            # random focal
            fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0]
        elif self.type == 'six_views':
            # six views: front/right/back/left, top (theta ~ 0), bottom
            thetas_six = [90]*4 + [1e-6] + [180]
            phis_six = [0, 90, 180, -90, 0, 0]
            thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device)
            phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device)
            radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
            poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

            # fixed focal
            fov = self.opt.default_fovy
        else:
            # circle pose: sweep azimuth evenly over the dataset size
            thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device)
            phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device)
            radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
            poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

            # fixed focal
            fov = self.opt.default_fovy

        focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
        intrinsics = np.array([focal, focal, self.cx, self.cy])

        # OpenGL-style perspective projection matrix
        projection = torch.tensor([
            [2*focal/self.W, 0, 0, 0],
            [0, -2*focal/self.H, 0, 0],
            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
            [0, 0, -1, 0]
        ], dtype=torch.float32, device=self.device).unsqueeze(0)

        mvp = projection @ torch.inverse(poses) # [1, 4, 4]

        # sample a low-resolution but full image
        rays = get_rays(poses, intrinsics, self.H, self.W, -1)

        # delta polar/azimuth/radius to default view
        delta_polar = thetas - self.opt.default_polar
        delta_azimuth = phis - self.opt.default_azimuth
        # NOTE(review): only the > 180 side is wrapped; values below -180 are
        # left as-is — confirm default_azimuth keeps deltas above -180.
        delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
        delta_radius = radius - self.opt.default_radius

        data = {
            'H': self.H,
            'W': self.W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
            'mvp': mvp,
            'polar': delta_polar,
            'azimuth': delta_azimuth,
            'radius': delta_radius,
        }

        return data

    def dataloader(self, batch_size=None):
        """Wrap this dataset in a DataLoader over indices [0, size)."""
        batch_size = batch_size or self.opt.batch_size
        loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        # presumably lets the trainer reach the dataset through the loader — confirm
        loader._data = self
        return loader
def generate_grid_points(resolution=128, device='cuda'):
    """Return a dense grid of points covering the unit cube [0, 1]^3.

    Args:
        resolution: number of samples along each axis.
        device: device the returned tensor is placed on.

    Returns:
        Tensor of shape (resolution**3, 3); each row is (x, y, z) with z
        varying fastest ('ij' meshgrid ordering).
    """
    # the three axes share the same sampling, so build the linspace once
    axis = torch.linspace(0, 1, resolution)

    # explicit indexing='ij' keeps the historical default ordering and
    # silences the torch.meshgrid deprecation warning
    grid_x, grid_y, grid_z = torch.meshgrid(axis, axis, axis, indexing='ij')

    # flatten to a (N, 3) point list on the requested device
    grid_points = torch.stack((grid_x.flatten(), grid_y.flatten(), grid_z.flatten()), dim=1).to(device)
    return grid_points

1575
nerf/renderer.py Normal file

File diff suppressed because it is too large Load Diff

1599
nerf/utils.py Normal file

File diff suppressed because it is too large Load Diff