first commit

This commit is contained in:
Guocheng Qian
2023-08-02 19:51:43 -07:00
parent c2891c38cc
commit 13e18567fa
202 changed files with 43362 additions and 17 deletions

630
main.py Normal file
View File

@@ -0,0 +1,630 @@
import torch
import argparse
import sys
import os
import pandas as pd
from nerf.provider import NeRFDataset, generate_grid_points
from nerf.utils import *
import yaml
from easydict import EasyDict as edict
import dnnultis
import logging
logger = logging.getLogger(__name__)
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(
description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='3D AIGC Training')
parser.add_argument('--workspace', type=str, default='', help='path to log')
parser.add_argument('--text', default=None, help="text prompt")
parser.add_argument('--negative', default='', type=str,
help="negative text prompt")
parser.add_argument('--dir_texts_neg', action='store_true',
help="enable negative directional text")
parser.add_argument('--check_prompt', action='store_true', help="check prompt")
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray")
parser.add_argument('-O2', action='store_true',
help="equals --backbone vanilla")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--six_views', action='store_true',
help="six_views mode: save the images of the six views")
parser.add_argument('--eval_interval', type=int, default=1,
help="evaluate on the valid set every interval epochs")
parser.add_argument('--test_interval', type=int, default=50,
help="test on the test set every interval epochs")
parser.add_argument('--seed', type=int, default=101)
parser.add_argument('--log_every', type=int, default=20,
help="log losses every X iterations")
parser.add_argument('--use_wandb', action='store_true',
help="log online into wandb")
# guidance
parser.add_argument('--guidance', type=str, nargs='*',
default=['SD'], help='guidance model')
parser.add_argument('--guidance_scale', type=float, nargs='*', default=[100],
help="diffusion model classifier-free guidance scale")
parser.add_argument('--gudiance_spatial_weighting',
action='store_true', help="add spatial weighting to guidance")
parser.add_argument('--save_train_every', type=int,
default=-1, help="save sds guidance")
# clip guidance
# lambda_clip, set to 1 if use clip loss outside sds
parser.add_argument('--lambda_clip', type=float, default=0,
help="loss scale for clip loss outside sds")
# set to 100 if use clip guidance in sds
parser.add_argument('--clip_version', type=str,
default='large', help="clip version, large is ued in stable diffusion")
parser.add_argument('--clip_guidance', type=float, default=0,
help="diffusion model classifier-free guidance scale")
parser.add_argument('--clip_t', type=float, default=0.4,
help="time step thresh started to use clip")
parser.add_argument('--clip_iterative', action='store_true',
help="use clipd iteratively with sds")
parser.add_argument('--clip_image_loss', action='store_true',
help="use image as reference in clip")
parser.add_argument('--save_guidance_every', type=int,
default=-1, help="save sds guidance")
# 3D prior: Shap-E. Does not work.
parser.add_argument('--use_shape', action='store_true',
help="enable shap-e initization")
parser.add_argument('--shape_guidance', type=float, default=3,
help="guidance scaling for shap-e prior")
parser.add_argument('--shape_radius', type=float, default=4,
help="camera raidus for shap-e prior")
parser.add_argument('--shape_fovy', type=float, default=40,
help="fov for shap-e prior")
parser.add_argument('--shape_no_color', action='store_false',
dest='shape_init_color', help="do not use shap-E color for initization")
parser.add_argument('--shape_rpst', type=str, default='sdf',
help="use sdf to init NeRF/mesh by default")
# image options.
parser.add_argument('--image', default=None, help="image prompt")
parser.add_argument('--image_config', default=None, help="image config csv")
parser.add_argument('--learned_embeds_path', type=str,
default=None, help="path to learned embeds of the given image")
parser.add_argument('--known_iters', type=int, default=100,
help="loss scale for alpha entropy")
parser.add_argument('--known_view_interval', type=int, default=4,
help="do reconstruction every X iterations to save on compute")
parser.add_argument('--bg_color_known', type=str,
default=None, help='pixelnoise, noise, None') # pixelnoise
parser.add_argument('--known_shading', type=str, default='lambertian')
# DMTet and Mesh options
parser.add_argument('--save_mesh', action='store_true',
help="export an obj mesh with texture")
parser.add_argument('--mcubes_resolution', type=int, default=256,
help="mcubes resolution for extracting mesh")
parser.add_argument('--decimate_target', type=int, default=5e4,
help="target face number for mesh decimation")
parser.add_argument('--dmtet', action='store_true',
help="use dmtet finetuning")
parser.add_argument('--tet_mlp', action='store_true',
help="use tet_mlp finetuning")
parser.add_argument('--base_mesh', default=None,
help="base mesh for dmtet init")
parser.add_argument('--tet_grid_size', type=int,
default=256, help="tet grid size")
parser.add_argument('--init_ckpt', type=str, default='',
help="ckpt to init dmtet")
parser.add_argument('--lock_geo', action='store_true',
help="disable dmtet to learn geometry")
# training options
parser.add_argument('--iters', type=int, default=5000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
parser.add_argument('--lr_scale_nerf', type=float,
default=1, help="max learning rate")
parser.add_argument('--lr_scale_texture', type=float,
default=1, help="max learning rate")
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--cuda_ray', action='store_true',
help="use CUDA raymarching instead of pytorch")
parser.add_argument('--taichi_ray', action='store_true',
help="use taichi raymarching")
parser.add_argument('--max_steps', type=int, default=1024,
help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64,
help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=32,
help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16,
help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096,
help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
parser.add_argument('--latent_iter_ratio', type=float, default=0.0,
help="training iters that only use latent normal shading")
parser.add_argument('--normal_iter_ratio', type=float, default=0.0,
help="training iters that only use normal shading")
parser.add_argument('--textureless_iter_ratio', type=float, default=0.0,
help="training iters that only use textureless shading")
parser.add_argument('--albedo_iter_ratio', type=float, default=0,
help="training iters that only use albedo shading")
parser.add_argument('--warmup_bg_color', type=str, default=None,
help="bg color [None | pixelnoise | noise | white]")
parser.add_argument('--bg_color', type=str, default=None)
parser.add_argument('--bg_color_test', default='white')
parser.add_argument('--ema_decay', type=float, default=0.95,
help="exponential moving average of model weights")
parser.add_argument('--jitter_pose', action='store_true',
help="add jitters to the randomly sampled camera poses")
parser.add_argument('--jitter_center', type=float, default=0.2,
help="amount of jitter to add to sampled camera pose's center (camera location)")
parser.add_argument('--jitter_target', type=float, default=0.2,
help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
parser.add_argument('--jitter_up', type=float, default=0.02,
help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
parser.add_argument('--uniform_sphere_rate', type=float, default=0.5,
help="likelihood of sampling camera location uniformly on the sphere surface area")
parser.add_argument('--grad_clip', type=float, default=-1,
help="clip grad of all grad to this limit, negative value disables it")
parser.add_argument('--grad_clip_rgb', type=float, default=-1,
help="clip grad of rgb space grad to this limit, negative value disables it")
parser.add_argument('--grid_levels_mask', type=int, default=8,
help="the number of levels in the feature grid to mask (to disable use 0)")
parser.add_argument('--grid_levels_mask_iters', type=int, default=3000,
help="the number of iterations for feature grid masking (to disable use 0)")
# model options
parser.add_argument('--bg_radius', type=float, default=1.4,
help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_activation', type=str, default='exp',
choices=['softplus', 'exp', 'relu'], help="density activation function")
parser.add_argument('--density_thresh', type=float, default=10,
help="threshold for density grid to be occupied")
# add more strength to the center, believe the center is more likely to have objects.
parser.add_argument('--blob_density', type=float, default=10,
help="max (center) density for the density blob")
parser.add_argument('--blob_radius', type=float, default=0.2,
help="control the radius for the density blob")
# network backbone
parser.add_argument('--backbone', type=str, default='grid',
choices=['grid', 'vanilla', 'grid_taichi'], help="nerf backbone")
parser.add_argument('--grid_type', type=str,
default='hashgrid', help="grid type")
parser.add_argument('--hidden_dim_bg', type=int, default=32,
help="channels for background network")
parser.add_argument('--optim', type=str, default='adam',
choices=['adan', 'adam'], help="optimizer")
parser.add_argument('--sd_version', type=str, default='1.5',
choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
parser.add_argument('--hf_key', type=str, default=None,
help="hugging face Stable diffusion model key")
# try this if CUDA OOM
parser.add_argument('--fp16', action='store_true',
help="use float16 for training")
parser.add_argument('--vram_O', action='store_true',
help="optimization for low VRAM usage")
# rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
parser.add_argument('--w', type=int, default=128,
help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=128,
help="render height for NeRF in training")
parser.add_argument('--known_view_scale', type=float, default=1.5,
help="multiply --h/w by this for known view rendering")
parser.add_argument('--known_view_noise_scale', type=float, default=1e-3,
help="random camera noise added to rays_o and rays_d")
parser.add_argument('--noise_known_camera_annealing', action='store_true',
help="anneal the noise to zero over the coarse of training")
parser.add_argument('--dmtet_reso_scale', type=float, default=8,
help="multiply --h/w by this for dmtet finetuning")
parser.add_argument('--rm_edge', action='store_true',
help="remove edge (ideally only enale for high resolution cases)")
parser.add_argument('--edge_threshold', type=float, default=0.1,
help="remove edges with value > threshold")
parser.add_argument('--edge_width', type=float, default=5,
help="edge width")
parser.add_argument('--batch_size', type=int, default=1,
help="images to render per batch using NeRF")
# dataset options
parser.add_argument('--bound', type=float, default=1.0,
help="assume the scene is bounded in box(-bound, bound)")
parser.add_argument('--dt_gamma', type=float, default=0,
help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.1,
help="minimum near distance for camera")
parser.add_argument('--radius_range', type=float, nargs='*',
default=[1.8, 1.8], help="training camera radius range")
parser.add_argument('--theta_range', type=float, nargs='*',
default=[45, 135], help="training camera elevation/polar range, 90 is front")
parser.add_argument('--phi_range', type=float, nargs='*',
default=[-180, 180], help="training camera azimuth range")
parser.add_argument('--fovy_range', type=float, nargs='*',
default=[40, 40], help="training camera fovy range")
parser.add_argument('--default_radius', type=float, default=1.8,
help="radius for the default view")
parser.add_argument('--default_polar', type=float,
default=90, help="polar for the default view")
parser.add_argument('--default_azimuth', type=float,
default=0, help="azimuth for the default view")
parser.add_argument('--default_fovy', type=float, default=40,
help="fovy for the default view")
parser.add_argument('--progressive_view', action='store_true',
help="progressively expand view sampling range from default to full")
parser.add_argument('--progressive_level', action='store_true',
help="progressively increase gridencoder's max_level")
parser.add_argument('--angle_overhead', type=float, default=30,
help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60,
help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
parser.add_argument('--t_range', type=float, nargs='*',
default=[0.2, 0.6], help="stable diffusion time steps range")
# regularizations
parser.add_argument('--lambda_entropy', type=float, default=1e-3,
help="loss scale for alpha entropy, favors 0 or 1")
# Try increasing/decreasing lambda_opacity if your scene is stuffed with floaters/becoming empty.
parser.add_argument('--lambda_opacity', type=float, default=0.,
help="loss scale for alpha value, avoid uncessary filling")
# Try increasing/decreasing lambda_orient if you object is foggy/over-smoothed.
parser.add_argument('--lambda_orient', type=float,
default=1e-2, help="loss scale for orientation")
parser.add_argument('--lambda_tv', type=float, default=0,
help="loss scale for total variation of grad")
parser.add_argument('--lambda_wd', type=float, default=0,
help="loss scale for weight decay of grad")
parser.add_argument('--lambda_normal_smooth', type=float, default=0.5,
help="loss scale for first-order 2D normal image smoothness")
parser.add_argument('--lambda_normal_smooth2d', type=float, default=0.5,
help="loss scale for second-order 2D normal image smoothness")
parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0.0,
help="loss scale for second-order 2D normal image smoothness")
parser.add_argument('--lambda_guidance', type=float, nargs='*',
default=[1], help="loss scale for SDS")
parser.add_argument('--lambda_rgb', type=float,
default=5, help="loss scale for RGB")
parser.add_argument('--lambda_mask', type=float, default=0.5,
help="loss scale for mask (alpha)")
parser.add_argument('--lambda_depth', type=float, default=0.01,
help="loss scale for relative depth of the known view")
parser.add_argument('--lambda_normal', type=float,
default=0.0, help="loss scale for normals of the known view")
parser.add_argument('--lambda_depth_mse', type=float, default=0.0,
help="loss scale for depth of the known view")
parser.add_argument('--no_normalize_depth', action='store_false', dest='normalize_depth', help="normalize depth")
# for DMTet
parser.add_argument('--lambda_mesh_normal', type=float,
default=0.1, help="loss scale for mesh normal smoothness")
parser.add_argument('--lambda_mesh_lap', type=float,
default=0.1, help="loss scale for mesh laplacian")
# GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=800, help="GUI width")
parser.add_argument('--H', type=int, default=800, help="GUI height")
parser.add_argument('--radius', type=float, default=1.8,
help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=40,
help="default GUI camera fovy")
parser.add_argument('--light_theta', type=float, default=60,
help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
parser.add_argument('--light_phi', type=float, default=0,
help="default GUI light direction in [0, 360), azimuth")
parser.add_argument('--max_spp', type=int, default=1,
help="GUI rendering max sample per pixel")
parser.add_argument('--zero123_config', type=str,
default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123")
parser.add_argument('--zero123_ckpt', type=str,
default='./pretrained/zero123/105000.ckpt', help="ckpt for zero123")
parser.add_argument('--zero123_grad_scale', type=str, default='angle',
help="whether to scale the gradients based on 'angle' or 'None'")
parser.add_argument('--dataset_size_train', type=int, default=100,
help="Length of train dataset i.e. # of iterations per epoch")
parser.add_argument('--dataset_size_valid', type=int, default=8,
help="# of frames to render in the turntable video in validation")
parser.add_argument('--dataset_size_test', type=int, default=100,
help="# of frames to render in the turntable video at test time")
def _parse_args():
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
if __name__ == '__main__':
args, args_text = _parse_args()
opt = edict(vars(args))
if opt.O:
opt.fp16 = True
opt.cuda_ray = True
elif opt.O2:
opt.fp16 = True
opt.backbone = 'vanilla'
opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], []
opt.default_zero123_w = 1
# parameters for image-conditioned generation
if opt.image is not None or opt.image_config is not None:
if 'zero123' in opt.guidance:
# fix fov as zero123 doesn't support changing fov
opt.fovy_range = [opt.default_fovy, opt.default_fovy]
else:
opt.known_view_interval = 2
if 'SD' in opt.guidance:
opt.t_range = [0.2, 0.6]
opt.bg_radius = -1
# latent warmup is not needed
opt.latent_iter_ratio = 0
opt.albedo_iter_ratio = 0
if opt.image is not None:
opt.images += [opt.image]
opt.ref_radii += [opt.default_radius]
opt.ref_polars += [opt.default_polar]
opt.ref_azimuths += [opt.default_azimuth]
opt.zero123_ws += [opt.default_zero123_w]
if opt.image_config is not None:
# for multiview (zero123)
conf = pd.read_csv(opt.image_config, skipinitialspace=True)
opt.images += list(conf.image)
opt.ref_radii += list(conf.radius)
opt.ref_polars += list(conf.polar)
opt.ref_azimuths += list(conf.azimuth)
opt.zero123_ws += list(conf.zero123_weight)
if opt.image is None:
opt.default_radius = opt.ref_radii[0]
opt.default_polar = opt.ref_polars[0]
opt.default_azimuth = opt.ref_azimuths[0]
opt.default_zero123_w = opt.zero123_ws[0]
# reset to None
if len(opt.images) == 0:
opt.images = None
# default parameters for finetuning
if opt.dmtet:
opt.h = int(opt.h * opt.dmtet_reso_scale)
opt.w = int(opt.w * opt.dmtet_reso_scale)
opt.known_view_scale = 1
opt.grid_levels_mask = -1 # disable corse nerf (fine to keep, not necesary)
opt.t_range = [0.02, 0.50] # ref: magic3D
if opt.images is not None:
opt.lambda_normal = 0
opt.lambda_depth = 0
# assume finetuning
opt.latent_iter_ratio = 0
opt.textureless_iter_ratio = 0
opt.albedo_iter_ratio = 0
opt.normal_iter_ratio = 0
opt.progressive_view = False
opt.progressive_level = False
# record full range for progressive view expansion
if opt.progressive_view:
# disable as they disturb progressive view
opt.jitter_pose = False
opt.uniform_sphere_rate = 0
# back up full range
opt.full_radius_range = opt.radius_range
opt.full_theta_range = opt.theta_range
opt.full_phi_range = opt.phi_range
opt.full_fovy_range = opt.fovy_range
opt.use_clip = opt.clip_guidance > 0 or opt.lambda_clip > 0
# Do not support Shap-E for NeRF yet.
opt.use_shape = False if not opt.dmtet else opt.use_shape
# workspace prepare
setup_workspace(opt)
dnnultis.setup_logging(opt.log_path)
if opt.seed < 0:
opt.seed = random.randint(0, 10000)
seed_everything(int(opt.seed))
if opt.backbone == 'vanilla':
from nerf.network import NeRFNetwork
elif opt.backbone == 'grid':
from nerf.network_grid import NeRFNetwork
elif opt.backbone == 'grid_tcnn':
from nerf.network_grid_tcnn import NeRFNetwork
elif opt.backbone == 'grid_taichi':
opt.cuda_ray = False
opt.taichi_ray = True
import taichi as ti
from nerf.network_grid_taichi import NeRFNetwork
taichi_half2_opt = True
taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0}
if taichi_half2_opt:
taichi_init_args["half2_vectorization"] = True
ti.init(**taichi_init_args)
else:
raise NotImplementedError(
f'--backbone {opt.backbone} is not implemented!')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
opt.device = device
model = NeRFNetwork(opt).to(device)
if opt.init_ckpt != '':
if not os.path.exists(opt.init_ckpt):
logger.warning(f'ckpt {opt.init_ckpt} is not found')
else:
# load pretrained weights to init dmtet
state_dict = torch.load(opt.init_ckpt, map_location=device)
model.load_state_dict(state_dict['model'], strict=False)
if opt.cuda_ray:
model.mean_density = state_dict['mean_density']
logger.info(f'init from {opt.init_ckpt}...')
# if init ckpt is provided, we assume the color network is well learned and do not need base_mesh init
opt.shape_init_color = False
opt.base_mesh = None
if opt.use_shape and opt.dmtet:
# now only supports shape for dmtet init
from guidance.shape_utils import get_shape_from_image
opt.points = generate_grid_points(
128, device=device) if not opt.dmtet else model.dmtet.verts
opt.rpsts, opt.colors = get_shape_from_image(
opt.image.replace('rgba', 'rgb'),
opt.points,
rpst_type=opt.shape_rpst,
get_color=opt.shape_init_color,
shape_guidance=opt.shape_guidance, device=device)
scale = opt.default_radius / opt.shape_radius * \
np.tan(np.deg2rad(opt.default_fovy / 2)) / \
np.tan(np.deg2rad(opt.shape_fovy / 2))
if opt.dmtet:
model.dmtet.reset_tet_scale(scale)
else:
opt.points *= scale
logger.info(f'Got sdf from Shap-E init...')
logger.info(model)
if opt.six_views:
guidance = None # no need to load guidance model at test
trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device,
workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
test_loader = NeRFDataset(
opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1)
trainer.test(test_loader, write_video=False)
if opt.save_mesh:
trainer.save_mesh()
elif opt.test:
guidance = None # no need to load guidance model at test
trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace), opt, model, guidance,
device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
if opt.gui:
from nerf.gui import NeRFGUI
gui = NeRFGUI(opt, trainer)
gui.render()
else:
test_loader = NeRFDataset(
opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
trainer.test(test_loader)
trainer.test(test_loader, shading='normal') # save normal
if opt.save_mesh:
try:
trainer.save_mesh()
except:
pass
else:
train_loader = NeRFDataset(
opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader()
if opt.optim == 'adan':
from optimizer import Adan
# Adan usually requires a larger LR
def optimizer(model): return Adan(model.get_params(
5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
else: # adam
def optimizer(model): return torch.optim.Adam(
model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
if opt.backbone == 'vanilla':
def scheduler(optimizer): return optim.lr_scheduler.LambdaLR(
optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
else:
def scheduler(optimizer): return optim.lr_scheduler.LambdaLR(
optimizer, lambda iter: 1) # fixed
guidance = nn.ModuleDict()
lambda_guidance, guidance_scale = {}, {}
for idx, guidance_type in enumerate(opt.guidance):
lambda_guidance[guidance_type] = opt.lambda_guidance[idx] if idx < len(
opt.lambda_guidance) else opt.lambda_guidance[-1]
guidance_scale[guidance_type] = opt.guidance_scale[idx] if idx < len(
opt.guidance_scale) else opt.guidance_scale[-1]
if 'SD' == guidance_type:
from guidance.sd_utils import StableDiffusion, token_replace
guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range,
learned_embeds_path=opt.learned_embeds_path,
use_clip=opt.use_clip, clip_t=opt.clip_t, clip_iterative=opt.clip_iterative, clip_version=opt.clip_version,
)
if opt.learned_embeds_path is not None and os.path.exists(opt.learned_embeds_path): # add textual inversion tokens to model
opt.text, opt.negative = token_replace(
opt.text, opt.negative, opt.learned_embeds_path)
logger.info(
f'prompt: {opt.text}, negative: {opt.negative}')
if opt.check_prompt:
guidance['SD'].check_prompt(opt)
else:
opt.text = opt.text.replace('<token>', os.path.basename(os.path.dirname(opt.image)))
logger.warning('No learned_embeds_path provided, using the folowing pure text prompt with degraded performance: ' + opt.text)
if 'IF' == guidance_type:
from guidance.if_utils import IF
guidance['IF'] = IF(device, opt.vram_O, opt.t_range)
if 'zero123' == guidance_type:
from guidance.zero123_utils import Zero123
guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config,
ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt)
if 'clip' == guidance_type:
from guidance.clip_utils import CLIP
guidance['clip'] = CLIP(device)
opt.lambda_guidance = lambda_guidance
opt.guidance_scale = guidance_scale
logger.info(opt)
trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace), opt, model,
guidance,
device=device, workspace=opt.workspace, optimizer=optimizer,
ema_decay=opt.ema_decay, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True)
trainer.default_view_data = train_loader._data.get_default_view_data()
if opt.gui:
from nerf.gui import NeRFGUI
gui = NeRFGUI(opt, trainer, train_loader)
gui.render()
else:
valid_loader = NeRFDataset(
opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader()
test_loader = NeRFDataset(
opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
trainer.train(train_loader, valid_loader, test_loader, max_epoch)
trainer.test(test_loader)
trainer.test(test_loader, shading='normal') # save normal
if opt.save_mesh:
try:
trainer.save_mesh()
except:
pass