first commit
This commit is contained in:
127
nerf/clip.py
Normal file
127
nerf/clip.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
# import clip
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
|
||||
from torchvision import transforms
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def spherical_dist_loss(x, y):
    """Squared great-circle (angular) distance between x and y, row-wise.

    Both inputs are L2-normalized along the last dim first; for unit vectors
    the chord length |u - v| maps to the arc angle via 2 * arcsin(|u - v| / 2),
    and the returned value is angle^2 / 2 (the form used for CLIP guidance).
    """
    u = F.normalize(x, dim=-1)
    v = F.normalize(y, dim=-1)
    half_chord = (u - v).norm(dim=-1) / 2
    return 2 * half_chord.arcsin().square()
|
||||
|
||||
|
||||
class CLIP(nn.Module):
    """HuggingFace CLIP wrapper providing text/image embeddings and guidance losses.

    Fixes vs. the original:
    - `self.aug`, `self.resize`, `self.normalize` were used by several methods
      but never defined (only commented out in __init__); they are now built
      with the standard OpenAI CLIP preprocessing statistics.
    - `train_step` called `clip_model.encode_image`, which is the openai/clip
      package API; transformers' CLIPModel exposes `get_image_features`.
    - tokenizer/processor now load the same checkpoint as the model instead of
      a hard-coded one, and the model is moved to the given `device` instead of
      an unconditional `.cuda()`.
    """

    # standard OpenAI CLIP preprocessing statistics (same values the original
    # had commented out in __init__)
    IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
    IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, device,
                 # clip_name = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
                 clip_name='openai/clip-vit-large-patch14',
                 ):
        super().__init__()

        self.device = device
        self.clip_name = clip_name

        # keep model, tokenizer and processor consistent with one checkpoint
        self.clip_model = CLIPModel.from_pretrained(clip_name).to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)

        # image preprocessing: CLIP expects 224x224 inputs normalized with its stats
        self.resize = T.Resize((224, 224))
        self.normalize = T.Normalize(self.IMAGE_MEAN, self.IMAGE_STD)
        self.aug = T.Compose([self.resize, self.normalize])

    def get_text_embeds(self, prompt, neg_prompt=None, dir=None):
        """Return L2-normalized CLIP text features for `prompt`.

        `neg_prompt` and `dir` (shadows the builtin, kept for caller
        compatibility) are accepted but unused.
        """
        clip_text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)
        text_z = self.clip_model.get_text_features(clip_text_input)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
        return text_z

    def set_epoch(self, epoch):
        # no epoch-dependent state; kept for trainer interface compatibility
        pass

    def get_img_embeds(self, img):
        """Return L2-normalized CLIP image features for a batch of RGB images."""
        img = self.aug(img)
        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        return image_z

    def train_step(self, text_z, pred_rgb):
        """Negative cosine similarity between rendered image and text embedding."""
        pred_rgb = self.aug(pred_rgb)
        # transformers' CLIPModel has get_image_features (encode_image belongs
        # to the openai/clip package and would raise AttributeError here)
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        loss = - (image_z * text_z).sum(-1).mean()
        # loss = spherical_dist_loss(image_z, text_z)
        return loss

    def text_loss(self, text_z, pred_rgb):
        """Spherical distance between rendered-image embedding and text embedding."""
        pred_rgb = self.resize(pred_rgb)
        pred_rgb = self.normalize(pred_rgb)
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        loss = spherical_dist_loss(image_z, text_z)
        return loss

    def img_loss(self, img_ref_z, pred_rgb):
        """Spherical distance between rendered-image embedding and a reference image embedding."""
        pred_rgb = self.resize(pred_rgb)
        pred_rgb = self.normalize(pred_rgb)
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        loss = spherical_dist_loss(image_z, img_ref_z)
        return loss
|
||||
485
nerf/gui.py
Normal file
485
nerf/gui.py
Normal file
@@ -0,0 +1,485 @@
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import dearpygui.dearpygui as dpg
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from nerf.utils import *
|
||||
|
||||
|
||||
class OrbitCamera:
    """Orbit-style camera: a rotation about a look-at center at a fixed radius."""

    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from the orbit center
        self.fovy = fovy  # vertical field of view, in degrees
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look-at point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # world up, unit length
        self.near = 0.001
        self.far = 1000

    @property
    def pose(self):
        """Camera-to-world matrix [4, 4]: push out to radius, rotate, recenter."""
        translate = np.eye(4, dtype=np.float32)
        translate[2, 3] = self.radius
        rotate = np.eye(4, dtype=np.float32)
        rotate[:3, :3] = self.rot.as_matrix()
        cam2world = rotate @ translate
        cam2world[:3, 3] -= self.center
        return cam2world

    @property
    def intrinsics(self):
        """Pinhole intrinsics [fx, fy, cx, cy] derived from the vertical FoV."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

    @property
    def mvp(self):
        """Model-view-projection matrix [4, 4] (OpenGL-style perspective)."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        projection = np.array([
            [2 * focal / self.W, 0, 0, 0],
            [0, -2 * focal / self.H, 0, 0],
            [0, 0, -(self.far + self.near) / (self.far - self.near), -(2 * self.far * self.near) / (self.far - self.near)],
            [0, 0, -1, 0],
        ], dtype=np.float32)
        return projection @ np.linalg.inv(self.pose)  # [4, 4]

    def orbit(self, dx, dy):
        """Rotate about world-up (horizontal drag) and the camera side axis (vertical drag)."""
        side = self.rot.as_matrix()[:3, 0]  # camera x-axis; already unit length
        yaw = R.from_rotvec(self.up * np.deg2rad(-0.1 * dx))
        pitch = R.from_rotvec(side * np.deg2rad(-0.1 * dy))
        self.rot = yaw * pitch * self.rot

    def scale(self, delta):
        """Dolly in/out: each unit of `delta` scales the radius by 1/1.1."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Shift the look-at center in camera coordinates (small drag sensitivity)."""
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz])
|
||||
|
||||
|
||||
class NeRFGUI:
    """Interactive dearpygui viewer/trainer front-end for a NeRF `trainer`.

    Owns an OrbitCamera, a float32 RGB render buffer displayed as a raw
    texture, and the dpg widget tree. Rendering accumulates samples-per-pixel
    (spp) across frames until the camera moves (`need_update`).
    """

    def __init__(self, opt, trainer, loader=None, debug=True):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32)  # default white bg
        self.training = False
        self.step = 0  # training step

        self.trainer = trainer
        self.loader = loader  # data loader passed to trainer.train_gui; may be None when only viewing
        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # camera moved, should reset accumulation
        self.spp = 1  # sample per pixel
        self.light_dir = np.array([opt.light_theta, opt.light_phi])  # spherical angles in degrees
        self.ambient_ratio = 1.0
        self.mode = 'image'  # choose from ['image', 'depth']
        self.shading = 'albedo'

        # dynamic resolution is disabled for DMTet (mesh) mode
        self.dynamic_resolution = True if not self.opt.dmtet else False
        self.downscale = 1
        self.train_steps = 16  # training iterations per GUI frame (adapted at runtime)

        dpg.create_context()
        self.register_dpg()
        self.test_step()  # render an initial frame before the loop starts

    def __del__(self):
        dpg.destroy_context()

    def train_step(self):
        """Run `self.train_steps` training iterations and adapt that count to ~500 ms/frame."""
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True  # weights changed, re-render

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        # hysteresis: only change when the new estimate differs by >20%
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps

    def prepare_buffer(self, outputs):
        """Convert trainer output to a float32 HxWx3 array for the texture."""
        if self.mode == 'image':
            return outputs['image'].astype(np.float32)
        else:
            # min-max normalize depth for display and broadcast to 3 channels
            depth = outputs['depth'].astype(np.float32)
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
            return np.expand_dims(depth, -1).repeat(3, -1)

    def test_step(self):
        """Render one frame if the camera moved or spp accumulation is unfinished."""
        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                # camera moved: restart accumulation from this frame
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                # running average of accumulated samples
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)

    def register_dpg(self):
        """Build the whole dearpygui UI: texture, windows, widgets, input handlers."""

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            if self.opt.negative != '':
                dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        # toggles self.training; render() polls that flag each frame
                        def callback_train(sender, app_data):
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        # re-initializes every submodule that exposes reset_parameters
                        def callback_reset(sender, app_data):
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state()  # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)

                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh()
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32)  # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light dir
                def callback_set_light_dir(sender, app_data, user_data):
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
                    self.need_update = True

                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)

            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    # pose
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")

        ### register camera handler

        def callback_camera_drag_rotate(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            # app_data is (button, dx, dy) for drag handlers
            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):

            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)

        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)

        # TODO: seems dearpygui doesn't support resizing texture...
        # def callback_resize(sender, app_data):
        #     self.W = app_data[0]
        #     self.H = app_data[1]
        #     # how to reload texture ???

        # dpg.set_viewport_resize_callback(callback_resize)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        #dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        """Main loop: optionally train, then render/accumulate, every GUI frame."""
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
||||
238
nerf/network.py
Normal file
238
nerf/network.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from activation import trunc_exp
|
||||
from .renderer import NeRFRenderer
|
||||
|
||||
import numpy as np
|
||||
from encoding import get_encoder
|
||||
|
||||
from .utils import safe_normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: Linear -> LayerNorm -> (+ shortcut) -> SiLU.

    When dim_in != dim_out the shortcut is a bias-free linear projection;
    otherwise the input is added directly.
    """

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.norm = nn.LayerNorm(self.dim_out)
        self.activation = nn.SiLU(inplace=True)

        # project the residual path only when the dimensions differ
        if self.dim_in != self.dim_out:
            self.skip = nn.Linear(self.dim_in, self.dim_out, bias=False)
        else:
            self.skip = None

    def forward(self, x):
        # x: [B, C]
        out = self.norm(self.dense(x))
        shortcut = x if self.skip is None else self.skip(x)
        out += shortcut
        return self.activation(out)
|
||||
|
||||
class BasicBlock(nn.Module):
    """Plain MLP block: a single Linear layer followed by in-place ReLU."""

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: [B, C]
        return self.activation(self.dense(x))
|
||||
|
||||
class MLP(nn.Module):
    """Feed-forward stack: BasicBlock input layer, `block` hidden layers, Linear head.

    Args:
        dim_in: input feature size.
        dim_out: output feature size (the final layer has no activation).
        dim_hidden: width of all hidden layers.
        num_layers: total layer count (first + hidden + final).
        bias: whether the Linear layers carry a bias term.
        block: module class used for the hidden layers (BasicBlock or ResBlock).
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            if l == 0:
                # input projection always uses the plain block
                net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias))
            elif l != num_layers - 1:
                net.append(block(self.dim_hidden, self.dim_hidden, bias=bias))
            else:
                # final layer is a bare Linear (no activation)
                net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias))

        self.net = nn.ModuleList(net)

    def reset_parameters(self):
        """Re-initialize every Linear: Xavier-uniform weights, zeroed biases."""
        @torch.no_grad()
        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                # bias is None when the layer was built with bias=False;
                # the original unconditional zeros_(m.bias) crashed in that case
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(weight_init)

    def forward(self, x):
        # x: [B, dim_in] -> [B, dim_out]
        for layer in self.net:
            x = layer(x)
        return x
|
||||
|
||||
|
||||
class NeRFNetwork(NeRFRenderer):
    """Pure-PyTorch NeRF: frequency encoding + residual MLP, optional background net.

    Outputs density (sigma) and albedo per 3D point; shading modes compute a
    lambertian/normal/textureless color from the density gradient.
    """

    def __init__(self,
                 opt,
                 num_layers=5,  # 5 in paper
                 hidden_dim=64,  # 128 in paper
                 num_layers_bg=2,  # 3 in paper
                 hidden_dim_bg=32,  # 64 in paper
                 encoding='frequency_torch',  # pure pytorch
                 output_dim=4,  # 7 for DMTet (sdf 1 + color 3 + deform 3), 4 for NeRF
                 ):

        super().__init__(opt)
        # opt attributes, when present, override the keyword defaults
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg

        self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=6)
        # NOTE(review): passes the local hidden_dim/num_layers, not the
        # opt-overridden self.hidden_dim/self.num_layers — confirm intended
        self.sigma_net = MLP(self.in_dim, output_dim, hidden_dim, num_layers, bias=True, block=ResBlock)

        # number of encoder levels to mask out (coarse-to-fine); 0 disables masking
        self.grid_levels_mask = 0

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def common_forward(self, x):
        # x: [N, 3], in [-bound, bound]
        # Returns (sigma, albedo): density [N] and base color [N, output_dim-1].

        # sigma
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Feature masking for coarse-to-fine training
        # NOTE(review): self.level_dim is never assigned in this class — this
        # branch would raise unless it is set elsewhere; grid_levels_mask
        # defaults to 0 so the branch is normally skipped. Verify before use.
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # density gets a spatial blob prior added before activation
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def normal(self, x):
        """Surface normal at x as the negated, normalized density gradient."""

        with torch.enable_grad():
            x.requires_grad_(True)
            sigma, albedo = self.common_forward(x)
            # query gradient
            normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]

        # normal = self.finite_difference_normal(x)
        normal = safe_normalize(normal)
        # normal = torch.nan_to_num(normal)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, nomalized in [-1, 1]
        # l: [3], plane light direction, nomalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
        # returns (sigma [N], color [N, 3], normal [N, 3] or None)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None

        else:
            # query normal

            # sigma, albedo = self.common_forward(x)
            # normal = self.normal(x)

            # inlined (rather than calling self.normal) so sigma/albedo are
            # computed in the same pass as the gradient
            with torch.enable_grad():
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]
            normal = safe_normalize(normal)
            # normal = torch.nan_to_num(normal)
            # normal = normal.detach()

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                # map normal components from [-1, 1] to [0, 1] for visualization
                color = (normal + 1) / 2
            else:  # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        # x: [N, 3], in [-bound, bound]
        # Density-only query used by the renderer's grid updates.

        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color [N, 3] predicted from view direction d."""

        h = self.encoder_bg(d)  # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups for the optimizer, with per-group learning rates."""

        params = [
            # {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
        ]

        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})

        return params
|
||||
216
nerf/network_grid.py
Normal file
216
nerf/network_grid.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from activation import trunc_exp, biased_softplus
|
||||
from .renderer import NeRFRenderer, MLP
|
||||
|
||||
import numpy as np
|
||||
from encoding import get_encoder
|
||||
|
||||
from .utils import safe_normalize
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeRFNetwork(NeRFRenderer):
    """Grid-based NeRF backbone: a hash/tiled grid encoder + small sigma MLP.

    The sigma MLP outputs 4 channels: one density logit (passed through
    density_activation with a center blob prior) and three albedo logits
    (sigmoid).  An optional frequency-encoded MLP models the background,
    and helper routines pre-fit the field to a given SDF/occupancy + color.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 level_dim=2,
                 ):
        super().__init__(opt)

        # opt values override the constructor defaults when present
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        self.level_dim = opt.level_dim if hasattr(opt, 'level_dim') else level_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg

        if self.opt.grid_type == 'hashgrid':
            self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
        elif self.opt.grid_type == 'tilegrid':
            self.encoder, self.in_dim = get_encoder(
                'tiledgrid',
                input_dim=3,
                level_dim=self.level_dim,
                log2_hashmap_size=16,
                num_levels=16,
                desired_resolution=2048 * self.bound,
            )
        else:
            # BUG FIX: an unknown grid_type previously fell through silently,
            # leaving self.encoder undefined and failing much later with a
            # confusing AttributeError.
            raise ValueError(f'unknown grid_type: {self.opt.grid_type}')

        # BUG FIX: use the (possibly opt-overridden) self.hidden_dim /
        # self.num_layers instead of the raw constructor defaults.
        self.sigma_net = MLP(self.in_dim, 4, self.hidden_dim, self.num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)

        # number of finest grid levels masked out (coarse-to-fine training)
        self.grid_levels_mask = 0

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode positions and run the sigma MLP.

        Args:
            x: [N, 3] positions in [-bound, bound].
        Returns:
            (sigma [N,], albedo [N, 3])
        """
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Feature masking for coarse-to-fine training: zero out the features
        # belonging to the `grid_levels_mask` finest grid levels.
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # channel 0 -> density (plus a center blob prior), channels 1..3 -> albedo
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1] (unused by this model)
        # l: [3], light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            normal = None
            color = albedo
        else:
            # any shaded mode needs the surface normal
            normal = self.normal(x)
            if shading == 'normal':
                color = (normal + 1) / 2
            else:
                # lambertian term with an ambient floor of `ratio`
                lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]
                if shading == 'textureless':
                    color = lambertian.unsqueeze(-1).repeat(1, 3)
                else:  # 'lambertian'
                    color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Query density/albedo at positions x ([N, 3] in [-bound, bound])."""
        sigma, albedo = self.common_forward(x)
        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d ([N, 3])."""
        h = self.encoder_bg(d)  # [N, C]
        h = self.bg_net(h)
        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)
        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Build optimizer parameter groups from the base learning rate."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
        ]
        if self.opt.bg_radius > 0:
            params.append({'params': self.bg_net.parameters(), 'lr': lr})
        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})
        return params

    def reset_sigmanet(self):
        """Re-initialize the sigma MLP weights (used before pre-fitting)."""
        self.sigma_net.reset_parameters()

    def init_nerf_from_sdf_color(self, rpst, albedo,
                                 points=None, pretrain_iters=10000, lr=0.001, rpst_type='sdf',
                                 ):
        """Pre-fit the field so its density/SDF and albedo match given targets.

        Args:
            rpst: target representation values at `points` (clamped to [0, 1]).
            albedo: target [N, 3] colors at `points`, or None to skip color.
            points: [N, 3] query positions for the targets.
            pretrain_iters: optimization steps.
            lr: Adam learning rate (decayed x0.1 at 40% and 80% of training).
            rpst_type: 'sdf' compares sigma shifted by density_thresh,
                anything else compares raw sigma.
        """
        self.reset_sigmanet()
        # matching optimization
        self.train()
        self.grid_levels_mask = 0
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)

        milestones = [int(0.4 * pretrain_iters), int(0.8 * pretrain_iters)]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

        rpst = rpst.squeeze().clamp(0, 1)

        pbar = tqdm(range(pretrain_iters), desc="NeRF sigma optimization")
        rgb_loss = torch.tensor(0, device=rpst.device)
        for i in pbar:
            output = self.density(points)
            if rpst_type == 'sdf':
                pred_rpst = output['sigma'] - self.density_thresh
            else:
                pred_rpst = output['sigma']
            sdf_loss = loss_fn(pred_rpst, rpst)

            if albedo is not None:
                pred_albedo = output['albedo']
                rgb_loss = loss_fn(pred_albedo, albedo)
                loss = 10 * sdf_loss + rgb_loss
            else:
                loss = sdf_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.set_postfix(loss=loss.item(), rgb_loss=rgb_loss.item(), sdf_loss=sdf_loss.item())

        # BUG FIX: the original put the accuracy expression inside the string
        # literal, so the literal text was logged instead of the value.
        accuracy = (pred_rpst[rpst > 0] > 0).sum() / (rpst > 0).sum()
        logger.info(f'lr: {lr} Accuracy: {accuracy}')

    def init_tet_from_sdf_color(self, sdf, colors=None, pretrain_iters=5000, lr=0.01):
        """Reset the DMTet, fit it to `sdf`, then optionally fit vertex colors."""
        self.train()
        self.grid_levels_mask = 0

        self.dmtet.reset_tet(reset_scale=False)
        self.dmtet.init_tet_from_sdf(sdf, pretrain_iters=pretrain_iters, lr=lr)

        if colors is not None:
            self.reset_sigmanet()
            loss_fn = torch.nn.MSELoss()
            # BUG FIX: the color loop hard-coded 5000 iters and lr=0.01,
            # silently ignoring the function arguments (defaults unchanged).
            optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
            pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")
            for i in pbar:
                pred_albedo = self.density(self.dmtet.verts)['albedo']
                loss = loss_fn(pred_albedo, colors)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=loss.item())
|
||||
161
nerf/network_grid_taichi.py
Normal file
161
nerf/network_grid_taichi.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from activation import trunc_exp
|
||||
from .renderer import NeRFRenderer
|
||||
|
||||
import numpy as np
|
||||
from encoding import get_encoder
|
||||
|
||||
from .utils import safe_normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Plain fully-connected ReLU network.

    Maps dim_in -> dim_hidden (num_layers - 1 times) -> dim_out; ReLU after
    every layer except the last.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden,
                                 self.dim_out if l == num_layers - 1 else self.dim_hidden,
                                 bias=bias))

        self.net = nn.ModuleList(net)

    def forward(self, x):
        for l in range(self.num_layers):
            x = self.net[l](x)
            if l != self.num_layers - 1:
                x = F.relu(x, inplace=True)
        return x

    def reset_parameters(self):
        """Re-initialize all Linear layers (Xavier weights, zero biases)."""
        @torch.no_grad()
        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                # BUG FIX: with bias=False, Linear.bias is None and
                # nn.init.zeros_(None) raised AttributeError.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(weight_init)
|
||||
|
||||
|
||||
class NeRFNetwork(NeRFRenderer):
    """Taichi-hashgrid NeRF backbone: taichi grid encoder + small torch MLP.

    The sigma MLP outputs 4 channels: one density logit and three albedo
    logits; an optional frequency-encoded MLP models the background.
    """

    def __init__(self,
                 opt,
                 num_layers=2,
                 hidden_dim=32,
                 num_layers_bg=2,
                 hidden_dim_bg=16,
                 ):
        super().__init__(opt)

        # opt values override the constructor defaults when present
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg

        self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')

        # BUG FIX: use the (possibly opt-overridden) self.hidden_dim /
        # self.num_layers instead of the raw constructor defaults.
        self.sigma_net = MLP(self.in_dim, 4, self.hidden_dim, self.num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)

        # BUG FIX: self.level_dim was never assigned, so the coarse-to-fine
        # mask in common_forward raised AttributeError whenever
        # grid_levels_mask > 0.  Default to 2 features per level (matching
        # network_grid.py's default) and let opt override it.
        self.level_dim = opt.level_dim if hasattr(opt, 'level_dim') else 2

        # number of finest grid levels masked out (coarse-to-fine training)
        self.grid_levels_mask = 0

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4)  # TODO: freq encoder can be replaced by a Taichi implementation
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode positions and run the sigma MLP.

        Args:
            x: [N, 3] positions in [-bound, bound].
        Returns:
            (sigma [N,], albedo [N, 3])
        """
        h = self.encoder(x, bound=self.bound)

        # Feature masking for coarse-to-fine training: zero out the features
        # of the `grid_levels_mask` finest grid levels.
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # channel 0 -> density (plus a center blob prior), channels 1..3 -> albedo
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1] (unused by this model)
        # l: [3], light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            normal = None
            color = albedo
        else:  # lambertian shading
            normal = self.normal(x)

            # lambertian term with an ambient floor of `ratio`
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else:  # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Query density/albedo at positions x ([N, 3] in [-bound, bound])."""
        sigma, albedo = self.common_forward(x)
        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d ([N, 3])."""
        h = self.encoder_bg(d)  # [N, C]
        h = self.bg_net(h)
        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)
        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Build optimizer parameter groups from the base learning rate."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
        ]
        if self.opt.bg_radius > 0:
            params.append({'params': self.bg_net.parameters(), 'lr': lr})
        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})
        return params
|
||||
178
nerf/network_grid_tcnn.py
Normal file
178
nerf/network_grid_tcnn.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from activation import trunc_exp, biased_softplus
|
||||
from .renderer import NeRFRenderer
|
||||
|
||||
import numpy as np
|
||||
from encoding import get_encoder
|
||||
|
||||
from .utils import safe_normalize
|
||||
|
||||
import tinycudann as tcnn
|
||||
|
||||
class MLP(nn.Module):
    """Simple fully-connected ReLU network.

    dim_in -> dim_hidden (repeated) -> dim_out, with ReLU between layers and
    no activation on the output.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        layers = []
        for idx in range(num_layers):
            in_features = self.dim_in if idx == 0 else self.dim_hidden
            out_features = self.dim_out if idx == num_layers - 1 else self.dim_hidden
            layers.append(nn.Linear(in_features, out_features, bias=bias))

        self.net = nn.ModuleList(layers)

    def forward(self, x):
        last = self.num_layers - 1
        for idx, layer in enumerate(self.net):
            x = layer(x)
            if idx != last:
                x = F.relu(x, inplace=True)
        return x
|
||||
|
||||
|
||||
class NeRFNetwork(NeRFRenderer):
    """tinycudann-backed grid NeRF: tcnn HashGrid encoder + torch sigma MLP."""

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 ):

        super().__init__(opt)

        # NOTE(review): unlike the other grid networks in this repo, opt
        # overrides for num_layers/hidden_dim are not honored here — confirm
        # that is intentional.
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                # growth factor chosen so the finest level reaches 2048 * bound
                "per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)),
                "interpolation": "Smoothstep",
            },
            dtype=torch.float32,  # ENHANCE: default float16 seems unstable...
        )
        self.in_dim = self.encoder.n_output_dims
        # use torch MLP, as tcnn MLP doesn't impl second-order derivative
        self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)

        self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus

        # background network
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode positions and run the sigma MLP; returns (sigma, albedo)."""
        # sigma: tcnn expects inputs in [0, 1], so remap from [-bound, bound]
        enc = self.encoder((x + self.bound) / (2 * self.bound)).float()
        h = self.sigma_net(enc)

        # channel 0 -> density (plus a center blob prior), channels 1..3 -> albedo
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def normal(self, x):
        """Surface normal as the normalized negative density gradient at x."""
        # gradients must be computed in fp32 even under amp
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                # query gradient
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]

        # normal = self.finite_difference_normal(x)
        normal = safe_normalize(normal)
        normal = torch.nan_to_num(normal)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, nomalized in [-1, 1]
        # l: [3], plane light direction, nomalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        if shading == 'albedo':
            sigma, albedo = self.common_forward(x)
            normal = None
            color = albedo

        else:  # lambertian shading
            # compute normal inline (same autograd recipe as self.normal)
            # so sigma/albedo are produced by the same forward pass
            with torch.enable_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    x.requires_grad_(True)
                    sigma, albedo = self.common_forward(x)
                    normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]
            normal = safe_normalize(normal)
            normal = torch.nan_to_num(normal)

            # lambertian term with an ambient floor of `ratio`
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else:  # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Query density/albedo at positions x ([N, 3] in [-bound, bound])."""
        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d ([N, 3])."""
        h = self.encoder_bg(d)  # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Build optimizer parameter groups from the base learning rate."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr},
        ]

        if self.opt.bg_radius > 0:
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet:
            # NOTE(review): the sibling networks register
            # self.dmtet.parameters() here; self.sdf / self.deform are not
            # defined in this class — confirm they exist on NeRFRenderer,
            # otherwise this branch raises AttributeError.
            params.append({'params': self.sdf, 'lr': lr})
            params.append({'params': self.deform, 'lr': lr})

        return params
|
||||
329
nerf/provider.py
Normal file
329
nerf/provider.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import random
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Slerp, Rotation
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .utils import get_rays, safe_normalize
|
||||
|
||||
# RGBA debug color per discrete view-direction bin; row index matches the
# labels produced by get_view_direction (0..5).
DIR_COLORS = np.array([
    [255, 0, 0, 255], # front
    [0, 255, 0, 255], # side
    [0, 0, 255, 255], # back
    [255, 255, 0, 255], # side
    [255, 0, 255, 255], # overhead
    [0, 255, 255, 255], # bottom
], dtype=np.uint8)
|
||||
|
||||
def visualize_poses(poses, dirs, size=0.1):
    """Debug helper: show camera poses as wire frusta in a trimesh scene.

    Args:
        poses: [B, 4, 4] camera-to-world matrices.
        dirs: [B] view-direction labels (index into DIR_COLORS).
        size: half-extent of the drawn camera rectangle.
    """
    # poses: [B, 4, 4], dirs: [B]
    import trimesh

    world_axes = trimesh.creation.axis(axis_length=4)
    unit_sphere = trimesh.creation.icosphere(radius=1)
    objects = [world_axes, unit_sphere]

    for pose, view_dir in zip(poses, dirs):
        # a camera is visualized with 8 line segments
        origin = pose[:3, 3]
        right = size * pose[:3, 0]
        up = size * pose[:3, 1]
        back = size * pose[:3, 2]
        a = origin + right + up - back
        b = origin - right + up - back
        c = origin - right - up - back
        d = origin + right - up - back

        segs = trimesh.load_path(np.array([
            [origin, a], [origin, b], [origin, c], [origin, d],
            [a, b], [b, c], [c, d], [d, a],
        ]))

        # color-code each camera by its view-direction bin
        segs.colors = DIR_COLORS[[view_dir]].repeat(len(segs.entities), 0)

        objects.append(segs)

    trimesh.Scene(objects).show()
|
||||
|
||||
def get_view_direction(thetas, phis, overhead, front):
    """Classify camera poses into 6 discrete view bins by azimuth/polar angle.

    Args:
        thetas: [B,] polar angles in radians.
        phis: [B,] azimuth angles in radians (wrapped into [0, 2*pi)).
        overhead: polar half-width (radians) of the top/bottom bins.
        front: full azimuthal width (radians) of the front/back bins.
    Returns:
        [B,] long tensor of labels on the same device as the inputs.
    """
    # front = 0         [0, front)
    # side (right) = 1  [front, 180)
    # back = 2          [180, 180+front)
    # side (left) = 3   [180+front, 360)
    # top = 4           [0, overhead]
    # bottom = 5        [180-overhead, 180]
    # BUG FIX: allocate on the inputs' device so the boolean-mask
    # assignments below work when thetas/phis live on the GPU.
    res = torch.zeros(thetas.shape[0], dtype=torch.long, device=thetas.device)
    # first determine by phis
    phis = phis % (2 * np.pi)
    res[(phis < front / 2) | (phis >= 2 * np.pi - front / 2)] = 0
    res[(phis >= front / 2) & (phis < np.pi - front / 2)] = 1
    res[(phis >= np.pi - front / 2) & (phis < np.pi + front / 2)] = 2
    res[(phis >= np.pi + front / 2) & (phis < 2 * np.pi - front / 2)] = 3
    # override by thetas: near-vertical poses win regardless of azimuth
    res[thetas <= overhead] = 4
    res[thetas >= (np.pi - overhead)] = 5
    return res
|
||||
|
||||
|
||||
def rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5):
    ''' Generate random poses from an orbit camera.

    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        opt: options object; jitter_pose / jitter_center / jitter_target /
            jitter_up are read when pose jittering is enabled.
        radius_range: [min, max] camera distance from the origin.
        theta_range: [min, max] polar angle in DEGREES (converted below).
        phi_range: [min, max] azimuth in DEGREES (converted below).
        return_dirs: also return discrete view-direction labels.
        angle_overhead, angle_front: label bin widths in degrees.
        uniform_sphere_rate: probability of sampling camera directions
            uniformly on the upper hemisphere instead of from the ranges.
    Return:
        poses: [size, 4, 4] camera-to-world matrices (y-up look-at-origin),
        plus dirs, thetas (deg), phis (deg), radius.
    '''
    # NOTE: the order of torch/random draws below is part of the observable
    # behavior under a fixed seed — do not reorder.

    # degrees -> radians
    theta_range = np.array(theta_range) / 180 * np.pi
    phi_range = np.array(phi_range) / 180 * np.pi
    angle_overhead = angle_overhead / 180 * np.pi
    angle_front = angle_front / 180 * np.pi

    radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]

    if random.random() < uniform_sphere_rate:
        # sample directions uniformly on the upper (y >= 0) unit sphere
        unit_centers = F.normalize(
            torch.stack([
                (torch.rand(size, device=device) - 0.5) * 2.0,
                torch.rand(size, device=device),
                (torch.rand(size, device=device) - 0.5) * 2.0,
            ], dim=-1), p=2, dim=1
        )
        # recover spherical coordinates for the view-direction labels
        thetas = torch.acos(unit_centers[:,1])
        phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
        phis[phis < 0] += 2 * np.pi
        centers = unit_centers * radius.unsqueeze(-1)
    else:
        # sample spherical coordinates directly inside the requested ranges
        thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
        phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
        phis[phis < 0] += 2 * np.pi

        centers = torch.stack([
            radius * torch.sin(thetas) * torch.sin(phis),
            radius * torch.cos(thetas),
            radius * torch.sin(thetas) * torch.cos(phis),
        ], dim=-1) # [B, 3]

    targets = 0

    # jitters
    if opt.jitter_pose:
        jit_center = opt.jitter_center # 0.015 # was 0.2
        jit_target = opt.jitter_target
        centers += torch.rand_like(centers) * jit_center - jit_center/2.0
        targets += torch.randn_like(centers) * jit_target

    # lookat: build an orthonormal camera frame pointing at `targets`
    forward_vector = safe_normalize(centers - targets)
    up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))

    if opt.jitter_pose:
        up_noise = torch.randn_like(up_vector) * opt.jitter_up
    else:
        up_noise = 0

    # re-orthogonalize up (with optional noise) against right/forward
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    if return_dirs:
        dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
    else:
        dirs = None

    # back to degree
    thetas = thetas / np.pi * 180
    phis = phis / np.pi * 180

    return poses, dirs, thetas, phis, radius
|
||||
|
||||
|
||||
def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), return_dirs=False, angle_overhead=30, angle_front=60):
    """Generate deterministic look-at-origin camera poses on a sphere.

    Args:
        device: where to allocate the outputs.
        radius, theta, phi: [B,] tensors; theta/phi in degrees.
        return_dirs: also return discrete view-direction labels.
        angle_overhead, angle_front: label bin widths in degrees.
    Returns:
        (poses [B, 4, 4], dirs or None)
    """
    # degrees -> radians
    theta = theta / 180 * np.pi
    phi = phi / 180 * np.pi
    angle_overhead = angle_overhead / 180 * np.pi
    angle_front = angle_front / 180 * np.pi

    # spherical -> cartesian camera centers (y-up convention)
    sin_t = torch.sin(theta)
    centers = torch.stack([
        radius * sin_t * torch.sin(phi),
        radius * torch.cos(theta),
        radius * sin_t * torch.cos(phi),
    ], dim=-1)  # [B, 3]

    # build an orthonormal look-at frame pointing back at the origin
    n = len(centers)
    fwd = safe_normalize(centers)
    world_up = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(n, 1)
    right = safe_normalize(torch.cross(fwd, world_up, dim=-1))
    up = safe_normalize(torch.cross(right, fwd, dim=-1))

    poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(n, 1, 1)
    poses[:, :3, :3] = torch.stack((right, up, fwd), dim=-1)
    poses[:, :3, 3] = centers

    dirs = get_view_direction(theta, phi, angle_overhead, angle_front) if return_dirs else None

    return poses, dirs
|
||||
|
||||
|
||||
class NeRFDataset:
    """Pose-sampling dataset for per-view NeRF training/evaluation.

    Training yields randomly sampled orbit poses; evaluation yields either a
    circular fly-around or the six canonical axis-aligned views.  Each item is
    a dict of rays, MVP matrices and pose deltas relative to the default view.
    """

    def __init__(self, opt, device, type='train', H=256, W=256, size=100):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type # train, val, test

        self.H = H
        self.W = W
        self.size = size  # number of items per epoch

        self.training = self.type in ['train', 'all']

        # principal point at the image center
        self.cx = self.H / 2
        self.cy = self.W / 2

        self.near = self.opt.min_near
        self.far = 1000 # infinite

        # [debug] visualize poses
        # poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1)
        # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())

    def get_default_view_data(self):
        """Build rays/matrices for the fixed reference views configured in opt."""

        # reference views are rendered at a scaled-down resolution
        H = int(self.opt.known_view_scale * self.H)
        W = int(self.opt.known_view_scale * self.W)
        cx = H / 2
        cy = W / 2

        radii = torch.FloatTensor(self.opt.ref_radii).to(self.device)
        thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device)
        phis = torch.FloatTensor(self.opt.ref_azimuths).to(self.device)
        poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
        fov = self.opt.default_fovy
        focal = H / (2 * np.tan(np.deg2rad(fov) / 2))
        intrinsics = np.array([focal, focal, cx, cy])

        # OpenGL-style perspective projection matrix
        projection = torch.tensor([
            [2*focal/W, 0, 0, 0],
            [0, -2*focal/H, 0, 0],
            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
            [0, 0, -1, 0]
        ], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1)

        mvp = projection @ torch.inverse(poses) # [B, 4, 4]

        # sample a low-resolution but full image
        rays = get_rays(poses, intrinsics, H, W, -1)

        data = {
            'H': H,
            'W': W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
            'mvp': mvp,
            'polar': self.opt.ref_polars,
            'azimuth': self.opt.ref_azimuths,
            'radius': self.opt.ref_radii,
        }

        return data

    def collate(self, index):
        """DataLoader collate_fn: turn a list of indices into one batch dict."""

        B = len(index)

        if self.training:
            # random pose on the fly
            poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate)

            # random focal
            fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0]

        elif self.type == 'six_views':
            # six canonical views: front/right/back/left/top/bottom
            thetas_six = [90]*4 + [1e-6] + [180]
            phis_six = [0, 90, 180, -90, 0, 0]
            thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device)
            phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device)
            radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
            poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

            # fixed focal
            fov = self.opt.default_fovy

        else:
            # circle pose: azimuth sweeps a full turn over the dataset size
            thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device)
            phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device)
            radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
            poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)

            # fixed focal
            fov = self.opt.default_fovy

        focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
        intrinsics = np.array([focal, focal, self.cx, self.cy])

        # OpenGL-style perspective projection matrix
        projection = torch.tensor([
            [2*focal/self.W, 0, 0, 0],
            [0, -2*focal/self.H, 0, 0],
            [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
            [0, 0, -1, 0]
        ], dtype=torch.float32, device=self.device).unsqueeze(0)

        mvp = projection @ torch.inverse(poses) # [1, 4, 4]

        # sample a low-resolution but full image
        rays = get_rays(poses, intrinsics, self.H, self.W, -1)

        # delta polar/azimuth/radius to default view
        delta_polar = thetas - self.opt.default_polar
        delta_azimuth = phis - self.opt.default_azimuth
        delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
        delta_radius = radius - self.opt.default_radius

        data = {
            'H': self.H,
            'W': self.W,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'dir': dirs,
            'mvp': mvp,
            'polar': delta_polar,
            'azimuth': delta_azimuth,
            'radius': delta_radius,
        }

        return data

    def dataloader(self, batch_size=None):
        """Wrap this dataset in a DataLoader (indices only; collate builds batches)."""
        batch_size = batch_size or self.opt.batch_size
        loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        # attach self so consumers of the loader can reach the dataset object
        loader._data = self
        return loader
|
||||
|
||||
|
||||
def generate_grid_points(resolution=128, device='cuda'):
    """Return a dense (resolution**3, 3) grid of points covering [0, 1]^3.

    Points are ordered with the x axis varying slowest and z fastest
    ('ij' meshgrid order).

    Args:
        resolution: number of samples along each axis.
        device: device the returned tensor is moved to.
    """
    axis = torch.linspace(0, 1, resolution)
    # FIX: pass indexing='ij' explicitly — it matches the old implicit
    # default but silences torch's meshgrid deprecation warning.
    grid_x, grid_y, grid_z = torch.meshgrid(axis, axis, axis, indexing='ij')

    # Flatten to a [resolution**3, 3] point list
    grid_points = torch.stack(
        (grid_x.flatten(), grid_y.flatten(), grid_z.flatten()), dim=1
    ).to(device)
    return grid_points
|
||||
|
||||
1575
nerf/renderer.py
Normal file
1575
nerf/renderer.py
Normal file
File diff suppressed because it is too large
Load Diff
1599
nerf/utils.py
Normal file
1599
nerf/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user