Files
Magic123/guidance/zero123_utils.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

333 lines
14 KiB
Python

import math
import numpy as np
from omegaconf import OmegaConf
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision.utils import save_image
from diffusers import DDIMScheduler
import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from ldm.util import instantiate_from_config
class SpecifyGradient(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input_tensor, gt_grad):
ctx.save_for_backward(gt_grad)
# we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
@staticmethod
@custom_bwd
def backward(ctx, grad_scale):
gt_grad, = ctx.saved_tensors
gt_grad = gt_grad * grad_scale
return gt_grad, None
# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
pl_sd = torch.load(ckpt, map_location='cpu')
if 'global_step' in pl_sd and verbose:
print(f'[INFO] Global Step: {pl_sd["global_step"]}')
sd = pl_sd['state_dict']
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print('[INFO] missing keys: \n', m)
if len(u) > 0 and verbose:
print('[INFO] unexpected keys: \n', u)
# manually load ema and delete it to save GPU memory
if model.use_ema:
if verbose:
print('[INFO] loading EMA...')
model.model_ema.copy_to(model.model)
del model.model_ema
if vram_O:
# we don't need decoder
del model.first_stage_model.decoder
torch.cuda.empty_cache()
model.eval().to(device)
return model
class Zero123(nn.Module):
def __init__(self, device, fp16,
config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
ckpt='./pretrained/zero123/105000.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
super().__init__()
self.device = device
self.fp16 = fp16
self.vram_O = vram_O
self.t_range = t_range
self.opt = opt
self.config = OmegaConf.load(config)
self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)
# timesteps: use diffuser for convenience... hope it's alright.
self.num_train_timesteps = self.config.model.params.timesteps
self.scheduler = DDIMScheduler(
self.num_train_timesteps,
self.config.model.params.linear_start,
self.config.model.params.linear_end,
beta_schedule='scaled_linear',
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
@torch.no_grad()
def get_img_embeds(self, x):
# x: image tensor [B, 3, 256, 256] in [0, 1]
x = x * 2 - 1
c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x] #.tile(n_samples, 1, 1)
v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
return c, v
def angle_between(self, sph_v1, sph_v2):
def sph2cart(sv):
r, theta, phi = sv[0], sv[1], sv[2]
return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])
def unit_vector(v):
return v / torch.linalg.norm(v)
def angle_between_2_sph(sv1, sv2):
v1, v2 = sph2cart(sv1), sph2cart(sv2)
v1_u, v2_u = unit_vector(v1), unit_vector(v2)
return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
angles = torch.empty(len(sph_v1), len(sph_v2))
for i, sv1 in enumerate(sph_v1):
for j, sv2 in enumerate(sph_v2):
angles[i][j] = angle_between_2_sph(sv1, sv2)
return angles
def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):
# pred_rgb: tensor [1, 3, H, W] in [0, 1]
# adjust SDS scale based on how far the novel view is from the known view
ref_radii = embeddings['ref_radii']
ref_polars = embeddings['ref_polars']
ref_azimuths = embeddings['ref_azimuths']
v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1) # polar,azimuth,radius are all actually delta wrt default
v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
if self.opt.zero123_grad_scale == 'angle':
grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
elif self.opt.zero123_grad_scale == 'None':
grad_scale = 1.0 # claforte: I think this might converge faster...?
else:
assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'
if as_latent:
latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
else:
pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_256)
t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
# Set weights acc to closeness in angle
if len(ref_azimuths) > 1:
inv_angles = 1/angles
inv_angles[inv_angles > 100] = 100
inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
inv_angles[inv_angles < 0.1] = 0
else:
inv_angles = torch.tensor([1.]).to(self.device)
# Multiply closeness-weight by user-given weights
zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
zero123_ws[zero123_ws < 0.1] = 0
with torch.no_grad():
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
x_in = torch.cat([latents_noisy] * 2)
t_in = torch.cat([t] * 2)
noise_preds = []
# Loop through each ref image
for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
embeddings['c_crossattn'], embeddings['c_concat'],
ref_polars, ref_azimuths, ref_radii):
# polar,azimuth,radius are all actually delta wrt default
p = polar + ref_polars[0] - ref_polar
a = azimuth + ref_azimuths[0] - ref_azimuth
a[a > 180] -= 360 # range in [-180, 180]
r = radius + ref_radii[0] - ref_radius
# T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r])
# T = T[None, None, :].to(self.device)
T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
cond = {}
clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
noise_pred = self.model.apply_model(x_in, t_in, cond)
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
noise_preds.append(zero123_w[:, None, None, None] * noise_pred)
noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]
w = (1 - self.alphas[t])
grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
grad = torch.nan_to_num(grad)
# import kiui
# if not as_latent:
# kiui.vis.plot_image(pred_rgb_256)
# kiui.vis.plot_matrix(latents)
# kiui.vis.plot_matrix(grad)
# import kiui
# latents = torch.randn((1, 4, 32, 32), device=self.device)
# kiui.lo(latents)
# self.scheduler.set_timesteps(30)
# with torch.no_grad():
# for i, t in enumerate(self.scheduler.timesteps):
# x_in = torch.cat([latents] * 2)
# t_in = torch.cat([t.view(1)] * 2).to(self.device)
# noise_pred = self.model.apply_model(x_in, t_in, cond)
# noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
# noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - noise_pred_uncond)
# latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
# imgs = self.decode_latents(latents)
# print(polar, azimuth, radius)
# kiui.vis.plot_image(pred_rgb_256, imgs)
if save_guidance_path:
with torch.no_grad():
if as_latent:
pred_rgb_256 = self.decode_latents(latents) # claforte: test!
# visualize predicted denoised image
result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))
# visualize noisier image
result_noisier_image = self.decode_latents(latents_noisy)
# all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)
save_image(viz_images, save_guidance_path)
# since we omitted an item in grad, we need to use the custom function to specify the gradient
# loss = SpecifyGradient.apply(latents, grad)
latents.backward(gradient=grad, retain_graph=True)
loss = grad.abs().mean().detach()
return loss
# verification
@torch.no_grad()
def __call__(self,
image, # image tensor [1, 3, H, W] in [0, 1]
polar=0, azimuth=0, radius=0, # new view params
scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params
c_crossattn=None, c_concat=None, post_process=True,
):
if c_crossattn is None:
embeddings = self.get_img_embeds(image)
T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
T = T[None, None, :].to(self.device)
cond = {}
clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1))
cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]
# produce latents loop
latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
self.scheduler.set_timesteps(ddim_steps)
for i, t in enumerate(self.scheduler.timesteps):
x_in = torch.cat([latents] * 2)
t_in = torch.cat([t.view(1)] * 2).to(self.device)
noise_pred = self.model.apply_model(x_in, t_in, cond)
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
imgs = self.decode_latents(latents)
imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs
return imgs
def decode_latents(self, latents):
# zs: [B, 4, 32, 32] Latent space image
# with self.model.ema_scope():
imgs = self.model.decode_first_stage(latents)
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs # [B, 3, 256, 256] RGB space image
def encode_imgs(self, imgs):
# imgs: [B, 3, 256, 256] RGB space image
# with self.model.ema_scope():
imgs = imgs * 2 - 1
latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
return latents # [B, 4, 32, 32] Latent space image
if __name__ == '__main__':
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument('input', type=str)
parser.add_argument('--fp16', action='store_true', help="use float16 for training") # no use now, can only run in fp32
parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')
opt = parser.parse_args()
device = torch.device('cuda')
print(f'[INFO] loading image from {opt.input} ...')
image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
image = image.astype(np.float32) / 255.0
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
print(f'[INFO] loading model ...')
zero123 = Zero123(device, opt.fp16, opt=opt)
print(f'[INFO] running model ...')
outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
plt.imshow(outputs[0])
plt.show()