import torch
import argparse
import sys
import os
import pandas as pd
from nerf.provider import NeRFDataset, generate_grid_points
from nerf.utils import *
import yaml
from easydict import EasyDict as edict
import dnnultis
import logging

logger = logging.getLogger(__name__)

# Two-stage argument parsing:
#   1. `config_parser` only understands -c/--config and ignores everything else.
#   2. `parser` (below) holds every real option; values found in the YAML file
#      named by --config replace its defaults before the remaining command
#      line is parsed.
config_parser = argparse.ArgumentParser(
    description='Training Config', add_help=False)
config_parser.add_argument(
    '-c', '--config', default='', type=str, metavar='FILE',
    help='YAML config file specifying default arguments')

# Main parser: general / run-mode options.
parser = argparse.ArgumentParser(description='3D AIGC Training')
parser.add_argument('--workspace', type=str, default='', help='path to log')
parser.add_argument('--text', default=None, help="text prompt")
parser.add_argument('--negative', default='', type=str,
                    help="negative text prompt")
parser.add_argument('--dir_texts_neg', action='store_true',
                    help="enable negative directional text")
parser.add_argument('--check_prompt', action='store_true',
                    help="check prompt")
parser.add_argument('-O', action='store_true',
                    help="equals --fp16 --cuda_ray")
parser.add_argument('-O2', action='store_true',
                    help="equals --backbone vanilla")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--six_views', action='store_true',
                    help="six_views mode: save the images of the six views")
parser.add_argument('--eval_interval', type=int, default=1,
                    help="evaluate on the valid set every interval epochs")
parser.add_argument('--test_interval', type=int, default=50,
                    help="test on the test set every interval epochs")
parser.add_argument('--seed', type=int, default=101)
parser.add_argument('--log_every', type=int, default=20,
                    help="log losses every X iterations")
parser.add_argument('--use_wandb', action='store_true',
                    help="log online into wandb")

# guidance
# Guidance models; the list-valued options take one entry per guidance model.
parser.add_argument('--guidance', type=str, nargs='*', default=['SD'],
                    help='guidance model')
parser.add_argument('--guidance_scale', type=float, nargs='*', default=[100],
                    help="diffusion model classifier-free guidance scale")
parser.add_argument('--gudiance_spatial_weighting', action='store_true',
                    help="add spatial weighting to guidance")
parser.add_argument('--save_train_every', type=int, default=-1,
                    help="save sds guidance")

# clip guidance
# lambda_clip, set to 1 if use clip loss outside sds
parser.add_argument('--lambda_clip', type=float, default=0,
                    help="loss scale for clip loss outside sds")
# set to 100 if use clip guidance in sds
parser.add_argument('--clip_version', type=str, default='large',
                    help="clip version, large is ued in stable diffusion")
parser.add_argument('--clip_guidance', type=float, default=0,
                    help="diffusion model classifier-free guidance scale")
parser.add_argument('--clip_t', type=float, default=0.4,
                    help="time step thresh started to use clip")
parser.add_argument('--clip_iterative', action='store_true',
                    help="use clipd iteratively with sds")
parser.add_argument('--clip_image_loss', action='store_true',
                    help="use image as reference in clip")
parser.add_argument('--save_guidance_every', type=int, default=-1,
                    help="save sds guidance")

# 3D prior: Shap-E. Does not work.
parser.add_argument('--use_shape', action='store_true',
                    help="enable shap-e initization")
parser.add_argument('--shape_guidance', type=float, default=3,
                    help="guidance scaling for shap-e prior")
parser.add_argument('--shape_radius', type=float, default=4,
                    help="camera raidus for shap-e prior")
parser.add_argument('--shape_fovy', type=float, default=40,
                    help="fov for shap-e prior")
# Stored inverted: passing --shape_no_color sets opt.shape_init_color = False.
parser.add_argument('--shape_no_color', action='store_false',
                    dest='shape_init_color',
                    help="do not use shap-E color for initization")
parser.add_argument('--shape_rpst', type=str, default='sdf',
                    help="use sdf to init NeRF/mesh by default")

# image options.
# Reference-image options.
parser.add_argument('--image', default=None, help="image prompt")
parser.add_argument('--image_config', default=None, help="image config csv")
parser.add_argument('--learned_embeds_path', type=str, default=None,
                    help="path to learned embeds of the given image")
parser.add_argument('--known_iters', type=int, default=100,
                    help="loss scale for alpha entropy")
parser.add_argument('--known_view_interval', type=int, default=4,
                    help="do reconstruction every X iterations to save on compute")
parser.add_argument('--bg_color_known', type=str, default=None,
                    help='pixelnoise, noise, None')  # pixelnoise
parser.add_argument('--known_shading', type=str, default='lambertian')

# DMTet and Mesh options
parser.add_argument('--save_mesh', action='store_true',
                    help="export an obj mesh with texture")
parser.add_argument('--mcubes_resolution', type=int, default=256,
                    help="mcubes resolution for extracting mesh")
# argparse applies `type` only to string defaults, so the previous default of
# 5e4 reached callers as the float 50000.0 despite type=int. Use a real int.
parser.add_argument('--decimate_target', type=int, default=50000,
                    help="target face number for mesh decimation")
parser.add_argument('--dmtet', action='store_true',
                    help="use dmtet finetuning")
parser.add_argument('--tet_mlp', action='store_true',
                    help="use tet_mlp finetuning")
parser.add_argument('--base_mesh', default=None,
                    help="base mesh for dmtet init")
parser.add_argument('--tet_grid_size', type=int, default=256,
                    help="tet grid size")
parser.add_argument('--init_ckpt', type=str, default='',
                    help="ckpt to init dmtet")
parser.add_argument('--lock_geo', action='store_true',
                    help="disable dmtet to learn geometry")

# training options
parser.add_argument('--iters', type=int, default=5000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3,
                    help="max learning rate")
parser.add_argument('--lr_scale_nerf', type=float, default=1,
                    help="max learning rate")
parser.add_argument('--lr_scale_texture', type=float, default=1,
                    help="max learning rate")
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--cuda_ray', action='store_true',
                    help="use CUDA raymarching instead of pytorch")
parser.add_argument('--taichi_ray', action='store_true',
                    help="use taichi raymarching")
parser.add_argument('--max_steps', type=int, default=1024,
                    help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64,
                    help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=32,
                    help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16,
                    help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096,
                    help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
parser.add_argument('--latent_iter_ratio', type=float, default=0.0,
                    help="training iters that only use latent normal shading")
parser.add_argument('--normal_iter_ratio', type=float, default=0.0,
                    help="training iters that only use normal shading")
parser.add_argument('--textureless_iter_ratio', type=float, default=0.0,
                    help="training iters that only use textureless shading")
parser.add_argument('--albedo_iter_ratio', type=float, default=0,
                    help="training iters that only use albedo shading")
parser.add_argument('--warmup_bg_color', type=str, default=None,
                    help="bg color [None | pixelnoise | noise | white]")
parser.add_argument('--bg_color', type=str, default=None)
parser.add_argument('--bg_color_test', default='white')
parser.add_argument('--ema_decay', type=float, default=0.95,
                    help="exponential moving average of model weights")
parser.add_argument('--jitter_pose', action='store_true',
                    help="add jitters to the randomly sampled camera poses")
parser.add_argument('--jitter_center', type=float, default=0.2,
                    help="amount of jitter to add to sampled camera pose's center (camera location)")
parser.add_argument('--jitter_target', type=float, default=0.2,
                    help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
parser.add_argument('--jitter_up', type=float, default=0.02,
                    help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
parser.add_argument('--uniform_sphere_rate', type=float, default=0.5,
                    help="likelihood of sampling camera location uniformly on the sphere surface area")
parser.add_argument('--grad_clip', type=float, default=-1,
                    help="clip grad of all grad to this limit, negative value disables it")
parser.add_argument('--grad_clip_rgb', type=float, default=-1,
                    help="clip grad of rgb space grad to this limit, negative value disables it")
parser.add_argument('--grid_levels_mask', type=int, default=8,
                    help="the number of levels in the feature grid to mask (to disable use 0)")
parser.add_argument('--grid_levels_mask_iters', type=int, default=3000,
                    help="the number of iterations for feature grid masking (to disable use 0)")

# model options
parser.add_argument('--bg_radius', type=float, default=1.4,
                    help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_activation', type=str, default='exp',
                    choices=['softplus', 'exp', 'relu'],
                    help="density activation function")
parser.add_argument('--density_thresh', type=float, default=10,
                    help="threshold for density grid to be occupied")

# add more strength to the center, believe the center is more likely to have objects.
# Density blob added around the scene centre.
parser.add_argument('--blob_density', type=float, default=10,
                    help="max (center) density for the density blob")
parser.add_argument('--blob_radius', type=float, default=0.2,
                    help="control the radius for the density blob")

# network backbone
# NOTE(review): the __main__ block also handles a 'grid_tcnn' backbone, but it
# is not listed in `choices`, so it can only be selected through a YAML config
# (set_defaults bypasses `choices`) -- confirm whether that is intended.
parser.add_argument('--backbone', type=str, default='grid',
                    choices=['grid', 'vanilla', 'grid_taichi'],
                    help="nerf backbone")
parser.add_argument('--grid_type', type=str, default='hashgrid',
                    help="grid type")
parser.add_argument('--hidden_dim_bg', type=int, default=32,
                    help="channels for background network")
parser.add_argument('--optim', type=str, default='adam',
                    choices=['adan', 'adam'], help="optimizer")
parser.add_argument('--sd_version', type=str, default='1.5',
                    choices=['1.5', '2.0', '2.1'],
                    help="stable diffusion version")
parser.add_argument('--hf_key', type=str, default=None,
                    help="hugging face Stable diffusion model key")

# try this if CUDA OOM
parser.add_argument('--fp16', action='store_true',
                    help="use float16 for training")
parser.add_argument('--vram_O', action='store_true',
                    help="optimization for low VRAM usage")

# rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
# Training render size and known-view rendering tweaks.
parser.add_argument('--w', type=int, default=128,
                    help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=128,
                    help="render height for NeRF in training")
parser.add_argument('--known_view_scale', type=float, default=1.5,
                    help="multiply --h/w by this for known view rendering")
parser.add_argument('--known_view_noise_scale', type=float, default=1e-3,
                    help="random camera noise added to rays_o and rays_d")
parser.add_argument('--noise_known_camera_annealing', action='store_true',
                    help="anneal the noise to zero over the coarse of training")
parser.add_argument('--dmtet_reso_scale', type=float, default=8,
                    help="multiply --h/w by this for dmtet finetuning")
parser.add_argument('--rm_edge', action='store_true',
                    help="remove edge (ideally only enale for high resolution cases)")
parser.add_argument('--edge_threshold', type=float, default=0.1,
                    help="remove edges with value > threshold")
parser.add_argument('--edge_width', type=float, default=5, help="edge width")
parser.add_argument('--batch_size', type=int, default=1,
                    help="images to render per batch using NeRF")

# dataset options
parser.add_argument('--bound', type=float, default=1.0,
                    help="assume the scene is bounded in box(-bound, bound)")
parser.add_argument('--dt_gamma', type=float, default=0,
                    help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.1,
                    help="minimum near distance for camera")
# Camera sampling ranges; each range option takes a list of values.
parser.add_argument('--radius_range', type=float, nargs='*',
                    default=[1.8, 1.8], help="training camera radius range")
parser.add_argument('--theta_range', type=float, nargs='*',
                    default=[45, 135],
                    help="training camera elevation/polar range, 90 is front")
parser.add_argument('--phi_range', type=float, nargs='*',
                    default=[-180, 180], help="training camera azimuth range")
parser.add_argument('--fovy_range', type=float, nargs='*',
                    default=[40, 40], help="training camera fovy range")
parser.add_argument('--default_radius', type=float, default=1.8,
                    help="radius for the default view")
parser.add_argument('--default_polar', type=float, default=90,
                    help="polar for the default view")
parser.add_argument('--default_azimuth', type=float, default=0,
                    help="azimuth for the default view")
parser.add_argument('--default_fovy', type=float, default=40,
                    help="fovy for the default view")
parser.add_argument('--progressive_view', action='store_true',
                    help="progressively expand view sampling range from default to full")
parser.add_argument('--progressive_level', action='store_true',
                    help="progressively increase gridencoder's max_level")
parser.add_argument('--angle_overhead', type=float, default=30,
                    help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60,
                    help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
parser.add_argument('--t_range', type=float, nargs='*', default=[0.2, 0.6],
                    help="stable diffusion time steps range")

# regularizations
parser.add_argument('--lambda_entropy', type=float, default=1e-3,
                    help="loss scale for alpha entropy, favors 0 or 1")

# Try increasing/decreasing lambda_opacity if your scene is stuffed with floaters/becoming empty.
# Loss weights (regularisers and known-view reconstruction terms).
parser.add_argument('--lambda_opacity', type=float, default=0.,
                    help="loss scale for alpha value, avoid uncessary filling")
# Try increasing/decreasing lambda_orient if your object is foggy/over-smoothed.
parser.add_argument('--lambda_orient', type=float, default=1e-2,
                    help="loss scale for orientation")
parser.add_argument('--lambda_tv', type=float, default=0,
                    help="loss scale for total variation of grad")
parser.add_argument('--lambda_wd', type=float, default=0,
                    help="loss scale for weight decay of grad")
parser.add_argument('--lambda_normal_smooth', type=float, default=0.5,
                    help="loss scale for first-order 2D normal image smoothness")
parser.add_argument('--lambda_normal_smooth2d', type=float, default=0.5,
                    help="loss scale for second-order 2D normal image smoothness")
parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0.0,
                    help="loss scale for second-order 2D normal image smoothness")
parser.add_argument('--lambda_guidance', type=float, nargs='*', default=[1],
                    help="loss scale for SDS")
parser.add_argument('--lambda_rgb', type=float, default=5,
                    help="loss scale for RGB")
parser.add_argument('--lambda_mask', type=float, default=0.5,
                    help="loss scale for mask (alpha)")
parser.add_argument('--lambda_depth', type=float, default=0.01,
                    help="loss scale for relative depth of the known view")
parser.add_argument('--lambda_normal', type=float, default=0.0,
                    help="loss scale for normals of the known view")
parser.add_argument('--lambda_depth_mse', type=float, default=0.0,
                    help="loss scale for depth of the known view")
# Stored inverted: passing --no_normalize_depth sets opt.normalize_depth = False.
parser.add_argument('--no_normalize_depth', action='store_false',
                    dest='normalize_depth', help="normalize depth")

# for DMTet
parser.add_argument('--lambda_mesh_normal', type=float, default=0.1,
                    help="loss scale for mesh normal smoothness")
parser.add_argument('--lambda_mesh_lap', type=float, default=0.1,
                    help="loss scale for mesh laplacian")

# GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=800, help="GUI width")
parser.add_argument('--H', type=int, default=800, help="GUI height")
parser.add_argument('--radius', type=float, default=1.8,
                    help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=40,
                    help="default GUI camera fovy")
parser.add_argument('--light_theta', type=float, default=60,
                    help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
parser.add_argument('--light_phi', type=float, default=0,
                    help="default GUI light direction in [0, 360), azimuth")
parser.add_argument('--max_spp', type=int, default=1,
                    help="GUI rendering max sample per pixel")

parser.add_argument('--zero123_config', type=str,
                    default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
                    help="config file for zero123")
parser.add_argument('--zero123_ckpt', type=str,
                    default='./pretrained/zero123/105000.ckpt',
                    help="ckpt for zero123")
parser.add_argument('--zero123_grad_scale', type=str, default='angle',
                    help="whether to scale the gradients based on 'angle' or 'None'")

parser.add_argument('--dataset_size_train', type=int, default=100,
                    help="Length of train dataset i.e. # of iterations per epoch")
parser.add_argument('--dataset_size_valid', type=int, default=8,
                    help="# of frames to render in the turntable video in validation")
parser.add_argument('--dataset_size_test', type=int, default=100,
                    help="# of frames to render in the turntable video at test time")


def _parse_args():
    """Parse the command line, letting a YAML config file override defaults.

    ``config_parser`` first extracts only ``-c/--config``. When a file is
    given, its key/values replace the defaults of the main ``parser`` before
    the remaining arguments are parsed, so explicit CLI flags still win.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the parsed namespace
        and ``args_text`` is a YAML dump of it (kept so it can be saved next
        to the outputs).
    """
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            # An empty YAML file loads as None; treat it as "no overrides"
            # instead of crashing in set_defaults(**None).
            cfg = yaml.safe_load(f) or {}
        parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(vars(args), default_flow_style=False)
    return args, args_text


# Script entry point: resolve options, build the NeRF model, then run one of
# three modes (six-view export, test, or training).
# NOTE(review): the block nesting below was reconstructed from a flattened
# source; statements whose nesting level could not be determined from syntax
# alone are marked individually -- please confirm against the upstream file.
if __name__ == '__main__':
    args, args_text = _parse_args()
    opt = edict(vars(args))

    # Shorthand optimisation presets.
    if opt.O:
        opt.fp16 = True
        opt.cuda_ray = True
    elif opt.O2:
        opt.fp16 = True
        opt.backbone = 'vanilla'

    opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], []
    opt.default_zero123_w = 1

    # parameters for image-conditioned generation
    if opt.image is not None or opt.image_config is not None:
        if 'zero123' in opt.guidance:
            # fix fov as zero123 doesn't support changing fov
            opt.fovy_range = [opt.default_fovy, opt.default_fovy]
        else:
            opt.known_view_interval = 2
            if 'SD' in opt.guidance:
                opt.t_range = [0.2, 0.6]
            # NOTE(review): nesting of bg_radius (this branch vs. the nested
            # 'SD' check) is presumed -- confirm.
            opt.bg_radius = -1

        # latent warmup is not needed
        # NOTE(review): presumed to apply to every image-conditioned run.
        opt.latent_iter_ratio = 0
        opt.albedo_iter_ratio = 0

        if opt.image is not None:
            opt.images += [opt.image]
            opt.ref_radii += [opt.default_radius]
            opt.ref_polars += [opt.default_polar]
            opt.ref_azimuths += [opt.default_azimuth]
            opt.zero123_ws += [opt.default_zero123_w]

        if opt.image_config is not None:
            # for multiview (zero123)
            conf = pd.read_csv(opt.image_config, skipinitialspace=True)
            opt.images += list(conf.image)
            opt.ref_radii += list(conf.radius)
            opt.ref_polars += list(conf.polar)
            opt.ref_azimuths += list(conf.azimuth)
            opt.zero123_ws += list(conf.zero123_weight)
            if opt.image is None:
                # Without --image, the first CSV row defines the default view.
                opt.default_radius = opt.ref_radii[0]
                opt.default_polar = opt.ref_polars[0]
                opt.default_azimuth = opt.ref_azimuths[0]
                opt.default_zero123_w = opt.zero123_ws[0]

    # reset to None
    if len(opt.images) == 0:
        opt.images = None

    # default parameters for finetuning
    if opt.dmtet:
        opt.h = int(opt.h * opt.dmtet_reso_scale)
        opt.w = int(opt.w * opt.dmtet_reso_scale)
        opt.known_view_scale = 1
        # disable coarse nerf (fine to keep, not necessary)
        opt.grid_levels_mask = -1
        opt.t_range = [0.02, 0.50]  # ref: magic3D

        if opt.images is not None:
            opt.lambda_normal = 0
            opt.lambda_depth = 0

        # assume finetuning
        opt.latent_iter_ratio = 0
        opt.textureless_iter_ratio = 0
        opt.albedo_iter_ratio = 0
        opt.normal_iter_ratio = 0
        opt.progressive_view = False
        opt.progressive_level = False

    # record full range for progressive view expansion
    if opt.progressive_view:
        # disable as they disturb progressive view
        opt.jitter_pose = False
        opt.uniform_sphere_rate = 0
        # back up full range
        opt.full_radius_range = opt.radius_range
        opt.full_theta_range = opt.theta_range
        opt.full_phi_range = opt.phi_range
        opt.full_fovy_range = opt.fovy_range

    opt.use_clip = opt.clip_guidance > 0 or opt.lambda_clip > 0

    # Do not support Shap-E for NeRF yet.
    opt.use_shape = False if not opt.dmtet else opt.use_shape

    # workspace prepare
    setup_workspace(opt)
    dnnultis.setup_logging(opt.log_path)

    # A negative seed requests a random one.
    if opt.seed < 0:
        opt.seed = random.randint(0, 10000)
    seed_everything(int(opt.seed))

    # Pick the NeRFNetwork implementation matching the requested backbone.
    if opt.backbone == 'vanilla':
        from nerf.network import NeRFNetwork
    elif opt.backbone == 'grid':
        from nerf.network_grid import NeRFNetwork
    elif opt.backbone == 'grid_tcnn':
        from nerf.network_grid_tcnn import NeRFNetwork
    elif opt.backbone == 'grid_taichi':
        opt.cuda_ray = False
        opt.taichi_ray = True
        import taichi as ti
        from nerf.network_grid_taichi import NeRFNetwork
        taichi_half2_opt = True
        taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0}
        if taichi_half2_opt:
            taichi_init_args["half2_vectorization"] = True
        ti.init(**taichi_init_args)
    else:
        raise NotImplementedError(
            f'--backbone {opt.backbone} is not implemented!')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    opt.device = device
    model = NeRFNetwork(opt).to(device)

    if opt.init_ckpt != '':
        if not os.path.exists(opt.init_ckpt):
            logger.warning(f'ckpt {opt.init_ckpt} is not found')
        else:
            # load pretrained weights to init dmtet
            state_dict = torch.load(opt.init_ckpt, map_location=device)
            model.load_state_dict(state_dict['model'], strict=False)
            if opt.cuda_ray:
                model.mean_density = state_dict['mean_density']
            logger.info(f'init from {opt.init_ckpt}...')
            # if init ckpt is provided, we assume the color network is well
            # learned and do not need base_mesh init
            # NOTE(review): presumed to apply only when the ckpt was actually
            # loaded -- confirm nesting.
            opt.shape_init_color = False
            opt.base_mesh = None

    if opt.use_shape and opt.dmtet:
        # now only supports shape for dmtet init
        from guidance.shape_utils import get_shape_from_image
        opt.points = generate_grid_points(
            128, device=device) if not opt.dmtet else model.dmtet.verts
        opt.rpsts, opt.colors = get_shape_from_image(
            opt.image.replace('rgba', 'rgb'),
            opt.points,
            rpst_type=opt.shape_rpst,
            get_color=opt.shape_init_color,
            shape_guidance=opt.shape_guidance,
            device=device)
        # Rescale from the Shap-E camera (radius, fovy) to the default view.
        scale = opt.default_radius / opt.shape_radius * \
            np.tan(np.deg2rad(opt.default_fovy / 2)) / \
            np.tan(np.deg2rad(opt.shape_fovy / 2))
        if opt.dmtet:
            model.dmtet.reset_tet_scale(scale)
        else:
            opt.points *= scale
        logger.info('Got sdf from Shap-E init...')

    logger.info(model)

    if opt.six_views:
        guidance = None  # no need to load guidance model at test
        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance,
                          device=device, workspace=opt.workspace,
                          fp16=opt.fp16, use_checkpoint=opt.ckpt)

        test_loader = NeRFDataset(
            opt, device=device, type='six_views',
            H=opt.H, W=opt.W, size=6).dataloader(batch_size=1)
        trainer.test(test_loader, write_video=False)

        if opt.save_mesh:
            trainer.save_mesh()

    elif opt.test:
        guidance = None  # no need to load guidance model at test
        trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace),
                          opt, model, guidance,
                          device=device, workspace=opt.workspace,
                          fp16=opt.fp16, use_checkpoint=opt.ckpt)

        if opt.gui:
            from nerf.gui import NeRFGUI
            gui = NeRFGUI(opt, trainer)
            gui.render()
        else:
            test_loader = NeRFDataset(
                opt, device=device, type='test',
                H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
            trainer.test(test_loader)
            trainer.test(test_loader, shading='normal')  # save normal

            if opt.save_mesh:
                try:
                    trainer.save_mesh()
                except Exception:
                    # Mesh export stays best-effort, but the failure is logged
                    # instead of being silently swallowed (and Ctrl-C is no
                    # longer caught).
                    logger.exception('failed to save mesh')

    else:
        train_loader = NeRFDataset(
            opt, device=device, type='train', H=opt.h, W=opt.w,
            size=opt.dataset_size_train * opt.batch_size).dataloader()

        if opt.optim == 'adan':
            from optimizer import Adan

            # Adan usually requires a larger LR
            def optimizer(model):
                return Adan(model.get_params(5 * opt.lr), eps=1e-8,
                            weight_decay=2e-5, max_grad_norm=5.0,
                            foreach=False)
        else:  # adam
            def optimizer(model):
                return torch.optim.Adam(
                    model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

        if opt.backbone == 'vanilla':
            # Exponential decay of the LR factor from 1 to 0.1 over opt.iters.
            def scheduler(optimizer):
                return optim.lr_scheduler.LambdaLR(
                    optimizer, lambda step: 0.1 ** min(step / opt.iters, 1))
        else:
            def scheduler(optimizer):
                return optim.lr_scheduler.LambdaLR(
                    optimizer, lambda step: 1)  # fixed

        guidance = nn.ModuleDict()
        lambda_guidance, guidance_scale = {}, {}
        for idx, guidance_type in enumerate(opt.guidance):
            # When fewer weights/scales than guidance models are given, the
            # last provided value is reused.
            lambda_guidance[guidance_type] = opt.lambda_guidance[idx] if idx < len(
                opt.lambda_guidance) else opt.lambda_guidance[-1]
            guidance_scale[guidance_type] = opt.guidance_scale[idx] if idx < len(
                opt.guidance_scale) else opt.guidance_scale[-1]

            if 'SD' == guidance_type:
                from guidance.sd_utils import StableDiffusion, token_replace
                guidance['SD'] = StableDiffusion(
                    device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key,
                    opt.t_range,
                    learned_embeds_path=opt.learned_embeds_path,
                    use_clip=opt.use_clip,
                    clip_t=opt.clip_t,
                    clip_iterative=opt.clip_iterative,
                    clip_version=opt.clip_version,
                )
                if opt.learned_embeds_path is not None and os.path.exists(opt.learned_embeds_path):
                    # add textual inversion tokens to model
                    opt.text, opt.negative = token_replace(
                        opt.text, opt.negative, opt.learned_embeds_path)
                    logger.info(
                        f'prompt: {opt.text}, negative: {opt.negative}')
                    if opt.check_prompt:
                        guidance['SD'].check_prompt(opt)
                else:
                    # The previous code called replace('', name), which inserts
                    # `name` between every character of the prompt.
                    # NOTE(review): the placeholder literal is presumed to be
                    # '<token>' (it appears to have been stripped from the
                    # source) -- confirm against guidance.sd_utils.token_replace.
                    if opt.text is not None and opt.image is not None:
                        opt.text = opt.text.replace(
                            '<token>',
                            os.path.basename(os.path.dirname(opt.image)))
                    logger.warning(
                        'No learned_embeds_path provided, using the folowing pure text prompt with degraded performance: %s',
                        opt.text)

            if 'IF' == guidance_type:
                from guidance.if_utils import IF
                guidance['IF'] = IF(device, opt.vram_O, opt.t_range)

            if 'zero123' == guidance_type:
                from guidance.zero123_utils import Zero123
                guidance['zero123'] = Zero123(
                    device=device, fp16=opt.fp16, config=opt.zero123_config,
                    ckpt=opt.zero123_ckpt, vram_O=opt.vram_O,
                    t_range=opt.t_range, opt=opt)

            if 'clip' == guidance_type:
                from guidance.clip_utils import CLIP
                guidance['clip'] = CLIP(device)

        # From here on these options are dicts keyed by guidance type.
        opt.lambda_guidance = lambda_guidance
        opt.guidance_scale = guidance_scale
        logger.info(opt)

        trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace),
                          opt, model, guidance,
                          device=device, workspace=opt.workspace,
                          optimizer=optimizer, ema_decay=opt.ema_decay,
                          fp16=opt.fp16, lr_scheduler=scheduler,
                          use_checkpoint=opt.ckpt,
                          scheduler_update_every_step=True)
        trainer.default_view_data = train_loader._data.get_default_view_data()

        if opt.gui:
            from nerf.gui import NeRFGUI
            gui = NeRFGUI(opt, trainer, train_loader)
            gui.render()
        else:
            valid_loader = NeRFDataset(
                opt, device=device, type='val',
                H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader()
            test_loader = NeRFDataset(
                opt, device=device, type='test',
                H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()

            # Enough epochs to cover opt.iters training iterations.
            max_epoch = np.ceil(
                opt.iters / len(train_loader)).astype(np.int32)
            trainer.train(train_loader, valid_loader, test_loader, max_epoch)

            trainer.test(test_loader)
            trainer.test(test_loader, shading='normal')  # save normal

            if opt.save_mesh:
                try:
                    trainer.save_mesh()
                except Exception:
                    # Best-effort export: log the failure rather than hide it.
                    logger.exception('failed to save mesh')