333 lines
14 KiB
Python
333 lines
14 KiB
Python
import math
|
|
import numpy as np
|
|
from omegaconf import OmegaConf
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
from torchvision.utils import save_image
|
|
|
|
from diffusers import DDIMScheduler
|
|
|
|
import sys
|
|
from os import path
|
|
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
|
|
|
from ldm.util import instantiate_from_config
|
|
|
|
class SpecifyGradient(torch.autograd.Function):
|
|
@staticmethod
|
|
@custom_fwd
|
|
def forward(ctx, input_tensor, gt_grad):
|
|
ctx.save_for_backward(gt_grad)
|
|
# we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
|
|
return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
|
|
|
|
@staticmethod
|
|
@custom_bwd
|
|
def backward(ctx, grad_scale):
|
|
gt_grad, = ctx.saved_tensors
|
|
gt_grad = gt_grad * grad_scale
|
|
return gt_grad, None
|
|
|
|
# load model
|
|
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
|
|
|
|
pl_sd = torch.load(ckpt, map_location='cpu')
|
|
|
|
if 'global_step' in pl_sd and verbose:
|
|
print(f'[INFO] Global Step: {pl_sd["global_step"]}')
|
|
|
|
sd = pl_sd['state_dict']
|
|
|
|
model = instantiate_from_config(config.model)
|
|
m, u = model.load_state_dict(sd, strict=False)
|
|
|
|
if len(m) > 0 and verbose:
|
|
print('[INFO] missing keys: \n', m)
|
|
if len(u) > 0 and verbose:
|
|
print('[INFO] unexpected keys: \n', u)
|
|
|
|
# manually load ema and delete it to save GPU memory
|
|
if model.use_ema:
|
|
if verbose:
|
|
print('[INFO] loading EMA...')
|
|
model.model_ema.copy_to(model.model)
|
|
del model.model_ema
|
|
|
|
if vram_O:
|
|
# we don't need decoder
|
|
del model.first_stage_model.decoder
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
model.eval().to(device)
|
|
|
|
return model
|
|
|
|
class Zero123(nn.Module):
|
|
def __init__(self, device, fp16,
|
|
config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
|
|
ckpt='./pretrained/zero123/105000.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
|
|
super().__init__()
|
|
|
|
self.device = device
|
|
self.fp16 = fp16
|
|
self.vram_O = vram_O
|
|
self.t_range = t_range
|
|
self.opt = opt
|
|
|
|
self.config = OmegaConf.load(config)
|
|
self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)
|
|
|
|
# timesteps: use diffuser for convenience... hope it's alright.
|
|
self.num_train_timesteps = self.config.model.params.timesteps
|
|
|
|
self.scheduler = DDIMScheduler(
|
|
self.num_train_timesteps,
|
|
self.config.model.params.linear_start,
|
|
self.config.model.params.linear_end,
|
|
beta_schedule='scaled_linear',
|
|
clip_sample=False,
|
|
set_alpha_to_one=False,
|
|
steps_offset=1,
|
|
)
|
|
|
|
self.min_step = int(self.num_train_timesteps * t_range[0])
|
|
self.max_step = int(self.num_train_timesteps * t_range[1])
|
|
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
|
|
|
|
@torch.no_grad()
|
|
def get_img_embeds(self, x):
|
|
# x: image tensor [B, 3, 256, 256] in [0, 1]
|
|
x = x * 2 - 1
|
|
c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x] #.tile(n_samples, 1, 1)
|
|
v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
|
|
return c, v
|
|
|
|
def angle_between(self, sph_v1, sph_v2):
|
|
def sph2cart(sv):
|
|
r, theta, phi = sv[0], sv[1], sv[2]
|
|
return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])
|
|
def unit_vector(v):
|
|
return v / torch.linalg.norm(v)
|
|
def angle_between_2_sph(sv1, sv2):
|
|
v1, v2 = sph2cart(sv1), sph2cart(sv2)
|
|
v1_u, v2_u = unit_vector(v1), unit_vector(v2)
|
|
return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
|
|
angles = torch.empty(len(sph_v1), len(sph_v2))
|
|
for i, sv1 in enumerate(sph_v1):
|
|
for j, sv2 in enumerate(sph_v2):
|
|
angles[i][j] = angle_between_2_sph(sv1, sv2)
|
|
return angles
|
|
|
|
def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):
|
|
# pred_rgb: tensor [1, 3, H, W] in [0, 1]
|
|
|
|
# adjust SDS scale based on how far the novel view is from the known view
|
|
ref_radii = embeddings['ref_radii']
|
|
ref_polars = embeddings['ref_polars']
|
|
ref_azimuths = embeddings['ref_azimuths']
|
|
v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1) # polar,azimuth,radius are all actually delta wrt default
|
|
v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
|
|
angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
|
|
if self.opt.zero123_grad_scale == 'angle':
|
|
grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
|
|
elif self.opt.zero123_grad_scale == 'None':
|
|
grad_scale = 1.0 # claforte: I think this might converge faster...?
|
|
else:
|
|
assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'
|
|
|
|
if as_latent:
|
|
latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
|
|
else:
|
|
pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
|
|
latents = self.encode_imgs(pred_rgb_256)
|
|
|
|
t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
|
|
|
|
# Set weights acc to closeness in angle
|
|
if len(ref_azimuths) > 1:
|
|
inv_angles = 1/angles
|
|
inv_angles[inv_angles > 100] = 100
|
|
inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
|
|
inv_angles[inv_angles < 0.1] = 0
|
|
else:
|
|
inv_angles = torch.tensor([1.]).to(self.device)
|
|
|
|
# Multiply closeness-weight by user-given weights
|
|
zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
|
|
zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
|
|
zero123_ws[zero123_ws < 0.1] = 0
|
|
|
|
with torch.no_grad():
|
|
noise = torch.randn_like(latents)
|
|
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
|
|
|
x_in = torch.cat([latents_noisy] * 2)
|
|
t_in = torch.cat([t] * 2)
|
|
|
|
noise_preds = []
|
|
# Loop through each ref image
|
|
for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
|
|
embeddings['c_crossattn'], embeddings['c_concat'],
|
|
ref_polars, ref_azimuths, ref_radii):
|
|
# polar,azimuth,radius are all actually delta wrt default
|
|
p = polar + ref_polars[0] - ref_polar
|
|
a = azimuth + ref_azimuths[0] - ref_azimuth
|
|
a[a > 180] -= 360 # range in [-180, 180]
|
|
r = radius + ref_radii[0] - ref_radius
|
|
# T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r])
|
|
# T = T[None, None, :].to(self.device)
|
|
T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
|
|
cond = {}
|
|
clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
|
|
cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
|
|
cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
|
|
noise_pred = self.model.apply_model(x_in, t_in, cond)
|
|
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
|
noise_preds.append(zero123_w[:, None, None, None] * noise_pred)
|
|
|
|
noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]
|
|
|
|
w = (1 - self.alphas[t])
|
|
grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
|
|
grad = torch.nan_to_num(grad)
|
|
|
|
# import kiui
|
|
# if not as_latent:
|
|
# kiui.vis.plot_image(pred_rgb_256)
|
|
# kiui.vis.plot_matrix(latents)
|
|
# kiui.vis.plot_matrix(grad)
|
|
|
|
# import kiui
|
|
# latents = torch.randn((1, 4, 32, 32), device=self.device)
|
|
# kiui.lo(latents)
|
|
# self.scheduler.set_timesteps(30)
|
|
# with torch.no_grad():
|
|
# for i, t in enumerate(self.scheduler.timesteps):
|
|
# x_in = torch.cat([latents] * 2)
|
|
# t_in = torch.cat([t.view(1)] * 2).to(self.device)
|
|
|
|
# noise_pred = self.model.apply_model(x_in, t_in, cond)
|
|
# noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
|
# noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - noise_pred_uncond)
|
|
|
|
# latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
|
|
# imgs = self.decode_latents(latents)
|
|
# print(polar, azimuth, radius)
|
|
# kiui.vis.plot_image(pred_rgb_256, imgs)
|
|
|
|
if save_guidance_path:
|
|
with torch.no_grad():
|
|
if as_latent:
|
|
pred_rgb_256 = self.decode_latents(latents) # claforte: test!
|
|
|
|
# visualize predicted denoised image
|
|
result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))
|
|
|
|
# visualize noisier image
|
|
result_noisier_image = self.decode_latents(latents_noisy)
|
|
|
|
# all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
|
|
viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)
|
|
save_image(viz_images, save_guidance_path)
|
|
|
|
# since we omitted an item in grad, we need to use the custom function to specify the gradient
|
|
# loss = SpecifyGradient.apply(latents, grad)
|
|
latents.backward(gradient=grad, retain_graph=True)
|
|
loss = grad.abs().mean().detach()
|
|
return loss
|
|
|
|
# verification
|
|
@torch.no_grad()
|
|
def __call__(self,
|
|
image, # image tensor [1, 3, H, W] in [0, 1]
|
|
polar=0, azimuth=0, radius=0, # new view params
|
|
scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params
|
|
c_crossattn=None, c_concat=None, post_process=True,
|
|
):
|
|
|
|
if c_crossattn is None:
|
|
embeddings = self.get_img_embeds(image)
|
|
|
|
T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
|
|
T = T[None, None, :].to(self.device)
|
|
|
|
cond = {}
|
|
clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1))
|
|
cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
|
|
cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]
|
|
|
|
# produce latents loop
|
|
latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
|
|
self.scheduler.set_timesteps(ddim_steps)
|
|
|
|
for i, t in enumerate(self.scheduler.timesteps):
|
|
x_in = torch.cat([latents] * 2)
|
|
t_in = torch.cat([t.view(1)] * 2).to(self.device)
|
|
|
|
noise_pred = self.model.apply_model(x_in, t_in, cond)
|
|
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
|
noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
|
|
|
|
imgs = self.decode_latents(latents)
|
|
imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs
|
|
|
|
return imgs
|
|
|
|
def decode_latents(self, latents):
|
|
# zs: [B, 4, 32, 32] Latent space image
|
|
# with self.model.ema_scope():
|
|
imgs = self.model.decode_first_stage(latents)
|
|
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
|
|
|
return imgs # [B, 3, 256, 256] RGB space image
|
|
|
|
def encode_imgs(self, imgs):
|
|
# imgs: [B, 3, 256, 256] RGB space image
|
|
# with self.model.ema_scope():
|
|
imgs = imgs * 2 - 1
|
|
latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
|
|
return latents # [B, 4, 32, 32] Latent space image
|
|
|
|
|
|
if __name__ == '__main__':
|
|
import cv2
|
|
import argparse
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument('input', type=str)
|
|
parser.add_argument('--fp16', action='store_true', help="use float16 for training") # no use now, can only run in fp32
|
|
|
|
parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
|
|
parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
|
|
parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')
|
|
|
|
opt = parser.parse_args()
|
|
|
|
device = torch.device('cuda')
|
|
|
|
print(f'[INFO] loading image from {opt.input} ...')
|
|
image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
|
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
|
|
image = image.astype(np.float32) / 255.0
|
|
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
|
|
|
|
print(f'[INFO] loading model ...')
|
|
zero123 = Zero123(device, opt.fp16, opt=opt)
|
|
|
|
print(f'[INFO] running model ...')
|
|
outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
|
|
plt.imshow(outputs[0])
|
|
plt.show()
|