first commit
This commit is contained in:
52
guidance/clip_utils.py
Normal file
52
guidance/clip_utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
import clip
|
||||
|
||||
class CLIP(nn.Module):
    """Thin wrapper around OpenAI CLIP (ViT-B/16) providing normalized
    text/image embeddings and a similarity-based guidance loss."""

    def __init__(self, device, **kwargs):
        super().__init__()

        self.device = device
        # jit=False loads the full nn.Module so gradients can flow through encode_image
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)

        # resize + CLIP normalization; input tensors are expected in [0, 1]
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ])

    def get_text_embeds(self, prompt, **kwargs):
        """Encode `prompt` (str or list of str) into L2-normalized CLIP text features."""
        tokens = clip.tokenize(prompt).to(self.device)
        feats = self.clip_model.encode_text(tokens)
        return feats / feats.norm(dim=-1, keepdim=True)

    def get_img_embeds(self, image, **kwargs):
        """Encode an image batch into L2-normalized CLIP image features."""
        feats = self.clip_model.encode_image(self.aug(image))
        return feats / feats.norm(dim=-1, keepdim=True)

    def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
        """Negative cosine-similarity loss between `pred_rgb` and the
        reference embeddings in `clip_z` (dict with optional 'image' and
        'text' keys), scaled by `grad_scale`."""
        feats = self.clip_model.encode_image(self.aug(pred_rgb))
        feats = feats / feats.norm(dim=-1, keepdim=True)  # normalize features

        loss = 0
        for key in ('image', 'text'):
            if key in clip_z:
                # maximizing cosine similarity == minimizing its negation
                loss = loss - (feats * clip_z[key]).sum(-1).mean()

        return loss * grad_scale
|
||||
|
||||
207
guidance/if_utils.py
Normal file
207
guidance/if_utils.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from transformers import logging
|
||||
from diffusers import IFPipeline, DDPMScheduler
|
||||
|
||||
# suppress partial model loading warning
|
||||
logging.set_verbosity_error()
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class SpecifyGradient(torch.autograd.Function):
    """Autograd trampoline: forward returns a dummy scalar loss of 1 while the
    backward pass injects the precomputed gradient `gt_grad` into `input_tensor`."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (gt_grad,) = ctx.saved_tensors
        # fold amp's loss scale into the stored gradient; gt_grad itself gets no grad
        return gt_grad * grad_scale, None
|
||||
|
||||
def seed_everything(seed):
    """Seed torch's CPU and CUDA RNGs so sampling is reproducible."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # cudnn flags deliberately left at defaults (uncomment to trade speed for determinism)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
class IF(nn.Module):
    """Score-distillation (SDS) guidance using DeepFloyd IF stage I, a
    pixel-space text-conditioned diffusion model operating at 64x64."""

    def __init__(self, device, vram_O, t_range=(0.02, 0.98)):
        """
        Args:
            device: torch device for the model and sampling.
            vram_O: if True, enable low-VRAM optimizations (offload/slicing).
            t_range: (lo, hi) fraction of the diffusion schedule to sample
                timesteps from. (Default changed from a mutable list to a
                tuple; callers may still pass a list.)
        """
        super().__init__()

        self.device = device

        print('[INFO] loading DeepFloyd IF-I-XL...')

        model_key = "DeepFloyd/IF-I-XL-v1.0"

        is_torch2 = torch.__version__[0] == '2'

        # Create model
        pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
        if not is_torch2:
            # torch >= 2 ships SDPA; xformers only helps on older torch
            pipe.enable_xformers_memory_efficient_attention()

        if vram_O:
            # low-VRAM mode: channels-last + attention slicing + CPU offload
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        # NOTE: the original assigned pipe.unet twice; kept once.
        self.unet = pipe.unet
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.scheduler = pipe.scheduler

        self.pipe = pipe

        # timesteps are sampled uniformly from [min_step, max_step] during distillation
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        print('[INFO] loaded DeepFloyd IF-I-XL!')

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        """Encode a list of prompt strings with IF's T5 text encoder.

        Returns the encoder's last hidden states (per-token embeddings).
        """
        # prompt: [str]
        # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28
        prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
        inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]

        return embeddings

    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):
        """One SDS step.

        Args:
            text_embeddings: stacked [uncond, cond] text embeddings.
            pred_rgb: rendered image batch in [0, 1], any spatial size.
            guidance_scale: classifier-free guidance weight.
            grad_scale: extra multiplier on the SDS gradient.

        Returns:
            A dummy scalar loss whose backward injects w(t)*(eps_pred - eps)
            into `pred_rgb` via SpecifyGradient.
        """
        # [0, 1] to [-1, 1] and make sure shape is [64, 64]
        images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(images)
            images_noisy = self.scheduler.add_noise(images, noise, t)

            # pred noise: one batched pass for [uncond, cond]
            model_input = torch.cat([images_noisy] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            # IF's unet predicts noise and variance stacked on the channel dim
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # TODO: how to use the variance here?
        # noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # since we omitted an item in grad, we need to use the custom function to specify the gradient
        loss = SpecifyGradient.apply(images, grad)

        return loss

    @torch.no_grad()
    def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):
        """Full reverse-diffusion sampling loop in pixel space.

        Returns a (1, 3, height, width) image batch in [0, 1].
        """
        images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)
        images = images * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            model_input = torch.cat([images] * 2)
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance; keep the predicted variance for the scheduler step
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
            noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            images = self.scheduler.step(noise_pred, t, images).prev_sample

        # [-1, 1] -> [0, 1]
        images = (images + 1) / 2

        return images

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Convenience text -> uint8 numpy images.

        NOTE(review): the 512 defaults look copy-pasted from the SD wrapper;
        IF stage I natively generates 64x64 and callers here pass 64
        explicitly. `latents` is accepted for signature parity but unused.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds, stacked [neg, pos] for classifier-free guidance
        pos_embeds = self.get_text_embeds(prompts)
        neg_embeds = self.get_text_embeds(negative_prompts)
        text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)

        # Text embeds -> img
        imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs
|
||||
|
||||
|
||||
if __name__ == '__main__':

    import argparse

    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    parser.add_argument('-H', type=int, default=64)
    parser.add_argument('-W', type=int, default=64)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    args = parser.parse_args()

    seed_everything(args.seed)

    guidance = IF(torch.device('cuda'), args.vram_O)

    # positional order: prompts, negative_prompts, height, width, num_inference_steps
    imgs = guidance.prompt_to_img(args.prompt, args.negative, args.H, args.W, args.steps)

    # show the generated image
    plt.imshow(imgs[0])
    plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
707
guidance/sd_utils.py
Normal file
707
guidance/sd_utils.py
Normal file
@@ -0,0 +1,707 @@
|
||||
# stdlib
import logging
import os
import random
from dataclasses import dataclass
from os.path import isfile
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union, Mapping

# third-party
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers.utils.import_utils import is_xformers_available
from torchvision.io import read_image
from torchvision import transforms
from torchvision.transforms import functional as TVF
from torchvision.utils import make_grid, save_image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spherical_dist_loss(x, y):
    """Squared geodesic (great-circle) distance between x and y after
    projecting both onto the unit sphere; computed along the last dim."""
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    # half chord length -> half angle via arcsin -> squared arc, doubled
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * half_chord.arcsin().pow(2)
|
||||
|
||||
|
||||
def seed_everything(seed=None):
    """Seed python, numpy and torch RNGs (and set PYTHONHASHSEED).

    A seed of 0 is valid: only `None` skips seeding. The previous truthiness
    check (`if seed:`) silently ignored seed=0.
    """
    if seed is not None:
        seed = int(seed)
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # torch.backends.cudnn.deterministic = True
        # benchmark=True favors speed over bit-exact cudnn determinism
        torch.backends.cudnn.benchmark = True
|
||||
|
||||
def save_tensor2image(x: torch.Tensor, path, channel_last=True, quality=75, **kwargs):
    """Save a batch of [0, 1] images to `path` as a single grid image.

    4-D channel-last input (B, H, W, C) is transposed to (B, C, H, W) first.
    `quality` is forwarded to PIL's save (relevant for JPEG); extra kwargs go
    to torchvision's make_grid.
    """
    # assume the input x is channel last
    if channel_last and x.ndim == 4:
        x = x.permute(0, 3, 1, 2)
    grid = make_grid(x, value_range=(0, 1), **kwargs)
    TVF.to_pil_image(grid).save(path, quality=quality)
|
||||
|
||||
|
||||
def to_pil(x: torch.Tensor, **kwargs) -> Image.Image:
    """Convert a [0, 1] tensor (or batch) into one PIL image via make_grid."""
    grid = make_grid(x, value_range=(0, 1), **kwargs)
    return TVF.to_pil_image(grid)
|
||||
|
||||
|
||||
def to_np_img(x: torch.Tensor) -> np.ndarray:
    """(B, C, H, W) float tensor in [0, 1] -> (B, H, W, C) uint8 numpy array."""
    arr = x.detach().permute(0, 2, 3, 1).cpu().numpy()
    return (arr * 255).round().astype(np.uint8)
|
||||
|
||||
|
||||
class SpecifyGradient(torch.autograd.Function):
    """Forward returns a dummy scalar loss of 1; backward delivers the
    precomputed gradient `gt_grad` (scaled by AMP's loss scale) to
    `input_tensor`."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
        dummy = torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
        return dummy

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        saved = ctx.saved_tensors
        # no gradient flows to gt_grad itself
        return saved[0] * grad_scale, None
|
||||
|
||||
|
||||
def token_replace(prompt, negative, learned_embeds_path):
    """Replace the '<token>' placeholder in both prompts with the single
    token stored in a textual-inversion embeddings file.

    Returns the (prompt, negative) pair, rewritten only when the placeholder
    is present.
    """
    # nothing to do when neither prompt uses the placeholder
    if '<token>' not in prompt and '<token>' not in negative:
        return prompt, negative

    if learned_embeds_path is None:
        raise ValueError(
            '--learned_embeds_path must be specified when using <token>')
    import torch
    keys = list(torch.load(learned_embeds_path, map_location='cpu').keys())
    # the learned-embeds dict must contain exactly one token -> embedding entry
    if len(keys) != 1:
        raise ValueError(
            'Something is wrong with the dict passed in for --learned_embeds_path')
    token = keys[0]
    prompt = prompt.replace('<token>', token)
    negative = negative.replace('<token>', token)
    logger.info(f'Prompt after replacing <token>: {prompt}')
    logger.info(f'Negative prompt after replacing <token>: {negative}')
    return prompt, negative
|
||||
|
||||
|
||||
@dataclass
class UNet2DConditionOutput:
    """Minimal stand-in for diffusers' UNet2DConditionOutput, used to wrap
    the raw tensor returned by a traced (torch.jit) UNet so callers can keep
    reading `.sample`."""
    # Not sure how to check what unet_traced.pt contains, and user wants. HalfTensor or FloatTensor
    sample: torch.HalfTensor
|
||||
|
||||
|
||||
def enable_vram(pipe):
    """Apply memory-saving options to a diffusers pipeline (low-VRAM GPUs).

    NOTE: sequential CPU offload moves submodules off the GPU, so the
    pipeline's device may change as a side effect.
    """
    pipe.enable_sequential_cpu_offload()
    pipe.enable_vae_slicing()
    # channels-last memory format tends to be faster for conv-heavy UNets
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.enable_attention_slicing(1)
    # pipe.enable_model_cpu_offload()
|
||||
|
||||
|
||||
def get_model_path(sd_version='2.1', clip_version='large', hf_key=None):
    """Map version strings to HuggingFace hub model ids.

    Returns (stable_diffusion_path, clip_path); `hf_key` overrides the SD id.
    Raises ValueError for an unknown `sd_version`.
    """
    sd_paths = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
    }
    if hf_key is not None:
        logger.info(f'[INFO] using hugging face custom model key: {hf_key}')
        sd_path = hf_key
    elif sd_version in sd_paths:
        sd_path = sd_paths[sd_version]
    else:
        raise ValueError(
            f'Stable-diffusion version {sd_version} not supported.')

    if clip_version == 'base':
        clip_path = "openai/clip-vit-base-patch32"
    else:
        clip_path = "openai/clip-vit-large-patch14"
    return sd_path, clip_path
|
||||
|
||||
|
||||
class StableDiffusion(nn.Module):
    """Stable Diffusion guidance module for score distillation (SDS), with
    optional CLIP guidance applied to the one-step denoised image."""

    def __init__(self, device, fp16, vram_O,
                 sd_version='2.1', hf_key=None,
                 t_range=[0.02, 0.98],
                 use_clip=False,
                 clip_version='base',
                 clip_iterative=True,
                 clip_t=0.4,
                 **kwargs
                 ):
        """
        Args:
            device: torch device to load the models on.
            fp16: load weights as float16 when True.
            vram_O: enable aggressive low-VRAM optimizations (see enable_vram).
            sd_version: '2.1' | '2.0' | '1.5'; ignored when hf_key is given.
            hf_key: optional custom HuggingFace model id.
            t_range: fraction of the noise schedule to sample timesteps from.
            use_clip: also load a CLIP model for image-space guidance.
            clip_version: 'base' or 'large' CLIP variant (see get_model_path).
            clip_iterative: alternate CLIP/SDS guidance depending on t.
            clip_t: fraction of the schedule below which CLIP guidance applies.
            kwargs: may contain 'learned_embeds_path' (textual inversion
                embeddings file) and 'overrride_token'.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version
        self.vram_O = vram_O
        self.fp16 = fp16

        logger.info(f'[INFO] loading stable diffusion...')

        sd_path, clip_path = get_model_path(sd_version, clip_version, hf_key)
        self.precision_t = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            sd_path, torch_dtype=self.precision_t, local_files_only=False)

        if isfile('./unet_traced.pt'):
            # use jitted unet: swap in a shim exposing the same .sample API
            unet_traced = torch.jit.load('./unet_traced.pt')

            class TracedUNet(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    # mirror the attributes callers read from the real unet
                    self.in_channels = pipe.unet.in_channels
                    self.device = pipe.unet.device

                def forward(self, latent_model_input, t, encoder_hidden_states):
                    sample = unet_traced(
                        latent_model_input, t, encoder_hidden_states)[0]
                    return UNet2DConditionOutput(sample=sample)
            pipe.unet = TracedUNet()

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        # optional textual-inversion embeddings
        if kwargs.get('learned_embeds_path', None) is not None:
            learned_embeds_path = kwargs['learned_embeds_path']
            if os.path.exists(learned_embeds_path):
                logger.info(
                    f'[INFO] loading learned embeddings from {kwargs["learned_embeds_path"]}')
                # NOTE(review): add_tokens_to_model_from_path is not defined in
                # this file — confirm it exists on this class elsewhere
                self.add_tokens_to_model_from_path(learned_embeds_path, kwargs.get('overrride_token', None))
            else:
                logger.warning(f'learned_embeds_path {learned_embeds_path} does not exist!')

        if vram_O:
            # this will change device from gpu to other types (meta)
            enable_vram(pipe)
        else:
            if is_xformers_available():
                pipe.enable_xformers_memory_efficient_attention()
            pipe.to(device)

        # a dedicated DDIM scheduler (independent of the pipeline's default)
        self.scheduler = DDIMScheduler.from_pretrained(
            sd_path, subfolder="scheduler", torch_dtype=self.precision_t, local_files_only=False)

        # SDS timesteps are drawn uniformly from [min_step, max_step]
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(
            self.device)  # for convenience

        logger.info(f'[INFO] loaded stable diffusion!')

        # for CLIP
        self.use_clip = use_clip
        if self.use_clip:
            #breakpoint()
            self.clip_model = CLIPModel.from_pretrained(clip_path).to(device)
            # rebuild CLIP's preprocessing as tensor transforms (resize + normalize)
            image_processor = CLIPProcessor.from_pretrained(clip_path).image_processor
            self.image_processor = transforms.Compose([
                transforms.Resize((image_processor.crop_size['height'], image_processor.crop_size['width'])),
                transforms.Normalize(image_processor.image_mean, image_processor.image_std),
            ])
            # CLIP is guidance-only: freeze it
            for p in self.clip_model.parameters():
                p.requires_grad = False

            self.clip_iterative = clip_iterative
            # clip_t as an absolute timestep threshold
            self.clip_t = int(self.num_train_timesteps * clip_t)
|
||||
|
||||
@torch.no_grad()
def get_text_embeds(self, prompt):
    """Tokenize `prompt` and return the CLIP text-encoder hidden states."""
    tokenized = self.tokenizer(
        prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
        truncation=True, return_tensors='pt')
    ids = tokenized.input_ids.to(self.device)
    return self.text_encoder(ids)[0]
|
||||
|
||||
@torch.no_grad()
def get_all_text_embeds(self, prompt):
    """Return the per-token hidden states concatenated with the pooled
    output (unsqueezed to one extra "token") along the sequence dim."""
    tokenized = self.tokenizer(
        prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
        truncation=True, return_tensors='pt')
    enc_out = self.text_encoder(tokenized.input_ids.to(self.device))
    # text_z = text_z / text_z.norm(dim=-1, keepdim=True)

    # enc_out[0]: per-token states; enc_out[1]: pooled embedding
    return torch.cat([enc_out[0], enc_out[1].unsqueeze(1)], dim=1)
|
||||
|
||||
# @torch.no_grad()  # deliberately grad-enabled: CLIP guidance backprops through this
def get_clip_img_embeds(self, img):
    """CLIP image features for `img`, L2-normalized along the last dim."""
    processed = self.image_processor(img)
    feats = self.clip_model.get_image_features(processed)
    return feats / feats.norm(dim=-1, keepdim=True)  # normalize features
|
||||
|
||||
def clip_loss(self, ref_z, pred_rgb):
    """Spherical distance between the rendered image's CLIP embedding and
    the reference embedding `ref_z`."""
    return spherical_dist_loss(self.get_clip_img_embeds(pred_rgb), ref_z)
|
||||
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_clip=None, grad_scale=1.0,
               image_ref_clip=None, text_ref_clip=None, clip_guidance=100, clip_image_loss=False,
               density=None,
               save_guidance_path=None
               ):
    """One SDS (+ optional CLIP) guidance step on a rendered image.

    Injects the guidance gradient into `pred_rgb`'s graph via
    latents.backward(gradient=grad) and returns
    (loss, {'loss_sds', 'loss_clipd'}), where all loss values are detached
    gradient magnitudes for logging only.

    Args:
        text_embeddings: stacked [uncond, cond] text embeddings.
        pred_rgb: rendered image in [0, 1] (or a latent map if as_latent).
        guidance_scale: classifier-free guidance weight.
        as_latent: treat pred_rgb as latents directly (skips the VAE).
        grad_clip: optional elementwise clamp on the final gradient.
        grad_scale: multiplier on the SDS gradient.
        image_ref_clip / text_ref_clip: reference CLIP embeddings; which one
            is used depends on clip_image_loss.
        clip_guidance: CLIP loss weight; <= 0 disables CLIP guidance.
        density: optional density map used to up-weight gradients inside the
            object's bounding box.
        save_guidance_path: if set, dump a visualization grid over timesteps.
    """
    # CLIP guidance needs a decodable RGB render, so it is off for latent input
    enable_clip = self.use_clip and clip_guidance > 0 and not as_latent
    enable_sds = True
    #breakpoint()
    if as_latent:
        # pred_rgb is already latent-like: resize to 64x64 and rescale to [-1, 1]
        latents = F.interpolate(
            pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
    else:
        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(
            pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        # encode image into latents with vae, requires grad!
        latents = self.encode_imgs(pred_rgb_512)

    # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
    # Since during the optimzation, the 3D is getting better.
    # mn = max(self.min_step, int(self.max_step - (self.max_step - self.min_step) / (self.opt.max_epoch // 3) * self.epoch + 0.5))
    t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
    # iterative mode: CLIP guidance at low noise, SDS at high noise — never both
    # NOTE(review): `t > self.clip_t` compares a 1-element tensor; assumes
    # batch size 1 — confirm with callers
    if enable_clip and self.clip_iterative:
        if t > self.clip_t:
            enable_clip = False
        else:
            enable_sds = False

    # predict the noise residual with unet, NO grad!
    with torch.no_grad():
        # add noise
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        # pred noise (one batched pass for [uncond, cond])
        latent_model_input = torch.cat([latents_noisy] * 2)
        # Save input tensors for UNet
        # torch.save(latent_model_input, "train_latent_model_input.pt")
        # torch.save(t, "train_t.pt")
        # torch.save(text_embeddings, "train_text_embeddings.pt")
        tt = torch.cat([t]*2)
        noise_pred = self.unet(latent_model_input, tt,
                               encoder_hidden_states=text_embeddings).sample

        # perform guidance (high scale from paper!)
        # NOTE(review): anchored at the conditional prediction, not the usual
        # uncond + s*(cond - uncond) form — confirm this is intentional
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_text + guidance_scale * \
            (noise_pred_text - noise_pred_uncond)

    if enable_clip:
        # one-step x0 estimate from the noisy latent (DDPM posterior mean form)
        pred_original_sample = (latents_noisy - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5)
        sample = pred_original_sample
        # re-enable grad on the x0 estimate so the CLIP loss can backprop to it
        sample = sample.detach().requires_grad_()

        sample = 1 / self.vae.config.scaling_factor * sample
        out_image = self.vae.decode(sample).sample
        out_image = (out_image / 2 + 0.5)#.clamp(0, 1)
        image_embeddings_clip = self.get_clip_img_embeds(out_image)
        ref_clip = image_ref_clip if clip_image_loss else text_ref_clip
        loss_clip = spherical_dist_loss(image_embeddings_clip, ref_clip).mean() * clip_guidance * 50 # 100
        # negated: this gradient is later *added* to grad as a descent direction
        grad_clipd = - torch.autograd.grad(loss_clip, sample, retain_graph=True)[0]
    else:
        grad_clipd = 0

    # import kiui
    # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)
    # latents_tmp = latents_tmp.detach()
    # kiui.lo(latents_tmp)
    # self.scheduler.set_timesteps(30)
    # for i, t in enumerate(self.scheduler.timesteps):
    #     latent_model_input = torch.cat([latents_tmp] * 3)
    #     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
    #     noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
    #     noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)
    #     latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']
    # imgs = self.decode_latents(latents_tmp)
    # kiui.vis.plot_image(imgs)

    if density is not None:
        # double the gradient weight inside the object's bounding box,
        # derived from the nonzero region of the density map
        with torch.no_grad():
            density = F.interpolate(density.detach(), (64, 64), mode='bilinear', align_corners=False)
            ids = torch.nonzero(density.squeeze())
            spatial_weight = torch.ones_like(density, device=density.device)
            try:
                up = ids[:, 0].min()
                down = ids[:, 0].max() + 1
                ll = ids[:, 1].min()
                rr = ids[:, 1].max() + 1
                spatial_weight[:, :, up:down, ll:rr] += 1
            except:
                # empty density map: keep uniform weights
                pass
    # breakpoint()
    # w(t), sigma_t^2
    w = (1 - self.alphas[t])[:, None, None, None]
    # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])

    if enable_sds:
        grad_sds = grad_scale * w * (noise_pred - noise)
        loss_sds = grad_sds.abs().mean().detach()
    else:
        grad_sds = 0.
        loss_sds = 0.

    if enable_clip:
        grad_clipd = w * grad_clipd.detach()
        loss_clipd = grad_clipd.abs().mean().detach()
    else:
        grad_clipd = 0.
        loss_clipd = 0.

    grad = grad_clipd + grad_sds

    if grad_clip is not None:
        grad = grad.clamp(-grad_clip, grad_clip)

    if density is not None:
        grad = grad * spatial_weight / 2

    grad = torch.nan_to_num(grad)

    # since we omitted an item in grad, we need to use the custom function to specify the gradient
    # loss = SpecifyGradient.apply(latents, grad)
    # loss = loss.abs().mean().detach()
    latents.backward(gradient=grad, retain_graph=True)
    loss = grad.abs().mean().detach()

    if not enable_clip:
        loss_sds = loss

    if save_guidance_path:
        with torch.no_grad():
            # save original input
            images = []
            os.makedirs(os.path.dirname(save_guidance_path), exist_ok=True)
            # visualize guidance across the whole schedule: t = 0, 99, ..., 999
            timesteps = torch.arange(-1, 1000, 100, dtype=torch.long, device=self.device)
            timesteps[0] *= 0
            for t in timesteps:
                if as_latent:
                    pred_rgb_512 = self.decode_latents(latents)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)

                # pred noise
                latent_model_input = torch.cat([latents_noisy] * 2)

                noise_pred = self.unet(latent_model_input, t,
                                       encoder_hidden_states=text_embeddings).sample

                # perform guidance (high scale from paper!)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_text + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)

                pred_original_sample = self.decode_latents((latents_noisy - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5))

                # visualize predicted denoised image
                # claforte: discuss this with Vikram!!
                result_hopefully_less_noisy_image = self.decode_latents(latents - w*(noise_pred - noise))

                # visualize noisier image
                result_noisier_image = self.decode_latents(latents_noisy)

                # add in the last col, w/o rendered view contraint, using random noise as latent.
                latent_model_input = torch.cat([noise] * 2)
                noise_pred = self.unet(latent_model_input, t,
                                       encoder_hidden_states=text_embeddings).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_text + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)
                noise_diffusion_out = self.decode_latents((noise - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5))
                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                image = torch.cat([pred_rgb_512, pred_original_sample, result_noisier_image, result_hopefully_less_noisy_image, noise_diffusion_out],dim=0)
                images.append(image)
            viz_images = torch.cat(images, dim=0)
            save_image(viz_images, save_guidance_path, nrow=5)

    return loss, {'loss_sds': loss_sds, 'loss_clipd': loss_clipd}
|
||||
|
||||
@torch.no_grad()
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
    """Run the full denoising loop and return the final latents.

    `text_embeddings` holds [uncond, cond] stacked along dim 0; latents are
    initialized from standard normal noise when not supplied.
    """
    if latents is None:
        latents = torch.randn(
            (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)

    self.scheduler.set_timesteps(num_inference_steps)

    with torch.autocast('cuda'):
        for t in self.scheduler.timesteps:
            # duplicate latents so one unet pass covers both CFG branches
            unet_in = torch.cat([latents] * 2)
            # latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                unet_in, t, encoder_hidden_states=text_embeddings)['sample']

            # classifier-free guidance (anchored at the conditional branch,
            # matching train_step)
            uncond, cond = noise_pred.chunk(2)
            noise_pred = cond + guidance_scale * (cond - uncond)

            # x_t -> x_{t-1}
            latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

    return latents
|
||||
|
||||
def decode_latents(self, latents):
    """VAE-decode latents into [0, 1] images; gradients can flow through."""
    # undo the VAE scaling factor (same arithmetic form as before: 1/s * x)
    scaled = 1 / self.vae.config.scaling_factor * latents

    # with torch.no_grad():
    decoded = self.vae.decode(scaled).sample

    # [-1, 1] -> [0, 1]
    return (decoded / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
def encode_imgs(self, imgs):
    """Encode [0, 1] images (B, 3, H, W) into scaled VAE latents.

    Uses a stochastic sample from the latent posterior; see
    encode_imgs_mean for the deterministic variant.
    """
    # the VAE expects inputs in [-1, 1]
    imgs = 2 * imgs - 1

    posterior = self.vae.encode(imgs).latent_dist
    return posterior.sample() * self.vae.config.scaling_factor
|
||||
|
||||
def encode_imgs_mean(self, imgs):
    """Deterministic variant of encode_imgs: uses the posterior mean
    instead of sampling. Input: [0, 1] images (B, 3, H, W)."""
    # the VAE expects inputs in [-1, 1]
    imgs = 2 * imgs - 1

    mean_latents = self.vae.encode(imgs).latent_dist.mean
    return mean_latents * self.vae.config.scaling_factor
|
||||
|
||||
@torch.no_grad()
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, to_numpy=True):
    """Text-to-image convenience: prompts -> embeddings -> latents -> images.

    Returns a uint8 numpy array when to_numpy is True, otherwise a [0, 1]
    image tensor.
    """
    if isinstance(prompts, str):
        prompts = [prompts]

    # broadcast a single negative prompt across the batch
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts] * len(prompts)

    prompts = tuple(prompts)
    negative_prompts = tuple(negative_prompts)

    # Prompts -> text embeds, stacked [neg, pos] for classifier-free guidance
    neg_embeds = self.get_text_embeds(negative_prompts)
    pos_embeds = self.get_text_embeds(prompts)
    text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0)

    # Text embeds -> img latents
    latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
                                   num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)

    # Img latents -> imgs (match the embeddings' dtype for the VAE)
    imgs = self.decode_latents(latents.to(text_embeds.dtype))

    # Img to Numpy
    return to_np_img(imgs) if to_numpy else imgs
|
||||
|
||||
@torch.no_grad()
def img_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, img=None, to_numpy=True, t=50):
    """Image-to-image: noise `img` to timestep `t`, then denoise with the
    text prompt.

    `img` may be a tensor or a file path; `t=0` skips noising entirely.
    NOTE(review): `num_inference_steps` covers the whole schedule rather
    than being truncated to `t` — likely related to the known issue below.

    Known issues:
    1. Not able to reconstruct images even with no noise.
    """
    if isinstance(prompts, str):
        prompts = [prompts]

    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]

    # Prompts -> text embeds
    pos_embeds = self.get_text_embeds(prompts)  # [1, 77, 768]
    neg_embeds = self.get_text_embeds(negative_prompts)
    text_embeds = torch.cat(
        [neg_embeds, pos_embeds], dim=0)  # [2, 77, 768]

    # image to latent
    # interp to 512x512 to be fed into vae.
    if isinstance(img, str):
        # treat `img` as a file path; keep RGB channels only
        img = TVF.to_tensor(Image.open(img))[None, :3].cuda()

    img_512 = F.interpolate(
        img.to(text_embeds.dtype), (512, 512), mode='bilinear', align_corners=False)
    # logger.info(img_512.shape, img_512, '\n', img_512.min(), img_512.max(), img_512.mean())

    # encode image into latents with vae, requires grad!
    # repeat so each prompt pair gets its own copy of the latents
    latents = self.encode_imgs(img_512).repeat(
        text_embeds.shape[0] // 2, 1, 1, 1)
    # logger.info(latents.shape, latents, '\n', latents.min(), latents.max(), latents.mean())

    noise = torch.randn_like(latents)
    if t > 0:
        # jump to timestep t of the forward process before denoising
        latents_noise = self.scheduler.add_noise(
            latents, noise, torch.tensor(t).to(torch.int32))
    else:
        latents_noise = latents

    # Text embeds -> img latents
    latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents_noise,
                                   num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)  # [1, 4, 64, 64]

    # Img latents -> imgs
    imgs = self.decode_latents(latents.to(
        text_embeds.dtype))  # [1, 3, 512, 512]

    # Img to Numpy
    if to_numpy:
        imgs = to_np_img(imgs)
    return imgs
|
||||
|
||||
def add_tokens_to_model(self, learned_embeds: Mapping[str, Tensor], override_token: Optional[Union[str, dict]] = None) -> None:
    r"""Adds tokens to the tokenizer and text encoder of a model.

    Args:
        learned_embeds: mapping from token string to its embedding vector.
        override_token: optional replacement token name (a single string, or a
            dict mapping original token -> replacement).

    Raises:
        ValueError: if a token already exists in the tokenizer.
    """
    added_tokens = []
    for token, embedding in learned_embeds.items():
        # Cast the learned vector to the encoder's embedding dtype.
        embedding = embedding.to(
            self.text_encoder.get_input_embeddings().weight.dtype)
        if override_token is not None:
            token = override_token if isinstance(
                override_token, str) else override_token[token]

        # Register the token string with the tokenizer; 0 means it already existed.
        if self.tokenizer.add_tokens(token) == 0:
            raise ValueError((f"The tokenizer already contains the token {token}. Please pass a "
                              "different `token` that is not already in the tokenizer."))

        # Grow the embedding matrix to cover the new vocabulary size.
        self.text_encoder._resize_token_embeddings(len(self.tokenizer))

        # Write the learned vector into the new token's row.
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
        added_tokens.append(token)

    logger.info(
        f'Added {len(added_tokens)} tokens to tokenizer and text embedding: {added_tokens}')
def add_tokens_to_model_from_path(self, learned_embeds_path: str, override_token: Optional[Union[str, dict]] = None) -> None:
    r"""Loads tokens from a file and adds them to the tokenizer and text encoder of a model."""
    # Load on CPU; add_tokens_to_model casts to the encoder's dtype itself.
    embeds: Mapping[str, Tensor] = torch.load(learned_embeds_path, map_location='cpu')
    self.add_tokens_to_model(embeds, override_token)
def check_prompt(self, opt):
    """Sanity-check generation quality for the prompt, with several view suffixes.

    For each suffix, generates a small batch of images and saves them (plus the
    prompt text) under `<opt.workspace>/prompt_check`.
    """
    view_suffixes = ['', ', front view', ', side view', ', back view']
    for view_text in view_suffixes:
        text = opt.text + view_text
        logger.info(f'Checking stable diffusion model with prompt: {text}')

        # Generate a batch for this prompt variant.
        image_check = self.prompt_to_img(
            prompts=[text] * opt.get('prompt_check_nums', 5), guidance_scale=7.5, to_numpy=False,
            num_inference_steps=opt.get('num_inference_steps', 50))

        # Save the generations alongside the prompt that produced them.
        check_dir = Path(opt.workspace) / 'prompt_check'
        check_dir.mkdir(exist_ok=True, parents=True)
        to_pil(image_check).save(check_dir / f'generations_{view_text}.png')
        (check_dir / 'prompt.txt').write_text(text)
if __name__ == '__main__':
    # CLI entry point: either sweep textual-inversion embeddings through
    # check_prompt, or run a one-off prompt_to_img / img_to_img generation.

    import argparse
    import matplotlib.pyplot as plt
    from easydict import EasyDict as edict
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument('--text', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--workspace', default='out/sd', type=str)
    parser.add_argument('--image_path', default=None, type=str)
    parser.add_argument('--learned_embeds_path', type=str,
                        default=None, help="path to learned embeds"
                        )
    parser.add_argument('--sd_version', type=str, default='1.5',
                        choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None,
                        help="hugging face Stable diffusion model key")
    parser.add_argument('--fp16', action='store_true',
                        help="use float16 for training")
    parser.add_argument('--vram_O', action='store_true',
                        help="optimization for low VRAM usage")
    # NOTE(review): flag name is misspelled ("gudiance") but is referenced
    # below as opt.gudiance_scale; renaming would break the CLI, so it is kept.
    parser.add_argument('--gudiance_scale', type=float, default=100)
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument('--noise_t', type=int, default=50)
    parser.add_argument('--prompt_check_nums', type=int, default=5)
    opt, unknown = parser.parse_known_args()

    # seed_everything(opt.seed)
    device = torch.device('cuda')
    opt = edict(vars(opt))
    workspace = opt.workspace

    # Keep the original prompt around: token_replace rewrites opt.text per embedding.
    opt.original_text = opt.text
    opt.original_negative = opt.negative
    if opt.learned_embeds_path is not None:
        # cml:
        # python guidance/sd_utils.py --text "A high-resolution DSLR image of <token>" --learned_embeds_path out/learned_embeds/ --workspace out/teddy_bear
        # check prompt
        if os.path.isdir(opt.learned_embeds_path):
            learned_embeds_paths = glob.glob(os.path.join(opt.learned_embeds_path, 'learned_embeds*bin'))
        else:
            learned_embeds_paths = [opt.learned_embeds_path]

        # One model instance (and one sub-workspace) per embedding file.
        for learned_embeds_path in learned_embeds_paths:
            embed_name = os.path.basename(learned_embeds_path).split('.')[0]
            opt.workspace = os.path.join(workspace, embed_name)
            sd = StableDiffusion(device, opt.fp16, opt.vram_O,
                                 opt.sd_version, opt.hf_key,
                                 learned_embeds_path=learned_embeds_path
                                 )
            # Add tokenizer
            if learned_embeds_path is not None:  # add textual inversion tokens to model
                opt.text, opt.negative = token_replace(
                    opt.original_text, opt.original_negative, learned_embeds_path)
                logger.info(opt.text, opt.negative)
            sd.check_prompt(opt)
    else:
        #breakpoint()
        # NOTE(review): `sd` is not constructed on this branch, so the calls
        # below would raise NameError; presumably a StableDiffusion instance
        # should be created here — verify against the original file.
        if opt.image_path is not None:
            save_promt = '_'.join(opt.text.split(' ')) + '_' + opt.image_path.split(
                '/')[-1].split('.')[0] + '_' + str(opt.noise_t) + '_' + str(opt.num_inference_steps)
            imgs = sd.img_to_img([opt.text]*opt.prompt_check_nums, [opt.negative]*opt.prompt_check_nums, opt.H, opt.W, opt.num_inference_steps,
                                 to_numpy=False, img=opt.image_path, t=opt.noise_t, guidance_scale=opt.gudiance_scale)
        else:
            save_promt = '_'.join(opt.text.split(' '))
            imgs = sd.prompt_to_img([opt.text]*opt.prompt_check_nums, [opt.negative]
                                    * opt.prompt_check_nums, opt.H, opt.W, opt.num_inference_steps, to_numpy=False)
        # visualize image
        output_dir_check = Path(opt.workspace)
        output_dir_check.mkdir(exist_ok=True, parents=True)

        to_pil(imgs).save(output_dir_check / f'{save_promt}.png')
81
guidance/shape_utils.py
Normal file
81
guidance/shape_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
from shap_e.models.transmitter.base import Transmitter
|
||||
from shap_e.models.query import Query
|
||||
from shap_e.models.nerstf.renderer import NeRSTFRenderer
|
||||
from shap_e.util.collections import AttrDict
|
||||
from shap_e.diffusion.sample import sample_latents
|
||||
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
||||
from shap_e.models.download import load_model, load_config
|
||||
from shap_e.util.image_util import load_image
|
||||
from shap_e.models.nn.meta import subdict
|
||||
import torch
|
||||
import gc
|
||||
|
||||
|
||||
# Orthonormal 3x3 matrices applied to query positions (`pos @ m`) to remap
# the camera frame into alternative canonical orientations before querying
# the Shap-E fields.
camera_to_shapes = [
    torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32),    # identity
    torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32),   # to bird view
    torch.tensor([[0, 1, 0], [0, 0, 1], [-1, 0, 0]], dtype=torch.float32),   # to rotated bird view
    torch.tensor([[0, -1, 0], [0, 0, 1], [-1, 0, 0]], dtype=torch.float32),  # to rotated bird view
    torch.tensor([[-1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32),  # to bird view
]
def get_density(
    render,
    query: Query,
    params: Dict[str, torch.Tensor],
    options: AttrDict[str, Any],
) -> torch.Tensor:
    """Evaluate the NeRSTF density field of `render` at the queried positions."""
    assert render.nerstf is not None
    # The NeRSTF submodule only receives its own parameter subtree.
    nerstf_params = subdict(params, "nerstf")
    output = render.nerstf(query, params=nerstf_params, options=options)
    return output.density
@torch.no_grad()
def get_shape_from_image(image_path, pos,
                         rpst_type='sdf',  # or 'density'
                         get_color=True,
                         shape_guidance=3, device='cuda'):
    """Sample a Shap-E latent from an image and query it at positions `pos`.

    Args:
        image_path: path to the conditioning image, loaded via shap_e's load_image.
        pos: tensor of query positions; multiplied on the right by each 3x3
            matrix in `camera_to_shapes`, so the last dim must be 3.
        rpst_type: 'sdf' queries signed distance; anything else queries density.
        get_color: when True, also query the texture field at each orientation.
        shape_guidance: classifier-free guidance scale passed to sample_latents.

    Returns:
        (rpsts, colors): one entry per matrix in `camera_to_shapes`; colors
        entries are None when get_color is False.
    """
    # Load the transmitter (decoder) and the image-conditioned latent model.
    xm = load_model('transmitter', device=device)
    model = load_model('image300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    # Sample a single latent conditioned on the image (Karras sampler settings).
    latent = sample_latents(
        batch_size=1,
        model=model,
        diffusion=diffusion,
        guidance_scale=shape_guidance,
        model_kwargs=dict(images=[load_image(image_path)]),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )[0]

    # Expand the latent bottleneck into the renderer's parameter dict.
    params = (xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
        latent[None]
    )

    rpsts, colors = [], []
    # Query the fields once per canonical orientation.
    for camera_to_shape in camera_to_shapes:
        query = Query(
            position=pos @ camera_to_shape.to(pos.device),
            direction=None,
        )

        if rpst_type == 'sdf':
            rpst = xm.renderer.get_signed_distance(query, params, AttrDict())
        else:
            rpst = get_density(xm.renderer, query, params, AttrDict())
        rpsts.append(rpst.squeeze())

        if get_color:
            color = xm.renderer.get_texture(query, params, AttrDict())
        else:
            color = None
        # Note: None is appended when get_color is False, keeping the two
        # returned lists aligned one-to-one.
        colors.append(color)

    return rpsts, colors
||||
332
guidance/zero123_utils.py
Normal file
332
guidance/zero123_utils.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import math
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
import sys
|
||||
from os import path
|
||||
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
class SpecifyGradient(torch.autograd.Function):
    """Autograd bridge that injects a precomputed gradient into the graph.

    Forward returns a dummy scalar (1.0); backward ignores the incoming
    gradient except as a scale factor (which AMP's GradScaler uses), and
    hands `gt_grad` back to `input_tensor`.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Dummy loss of 1; AMP's scaler multiplies it, so backward receives the scale.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (stored_grad,) = ctx.saved_tensors
        # Scale the stored gradient; no gradient flows to gt_grad itself.
        return stored_grad * grad_scale, None
# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
    """Build the zero123 latent-diffusion model from `config` and load weights from `ckpt`.

    Args:
        config: OmegaConf config whose `model` node is instantiated.
        ckpt: path to a PyTorch-Lightning checkpoint file.
        vram_O: when True, drop the first-stage decoder to save GPU memory.
        verbose: print checkpoint step and missing/unexpected keys.

    Returns:
        The model in eval mode on `device`.
    """
    checkpoint = torch.load(ckpt, map_location='cpu')

    if 'global_step' in checkpoint and verbose:
        print(f'[INFO] Global Step: {checkpoint["global_step"]}')

    state = checkpoint['state_dict']

    model = instantiate_from_config(config.model)
    # Non-strict load: report (but tolerate) key mismatches.
    missing, unexpected = model.load_state_dict(state, strict=False)

    if len(missing) > 0 and verbose:
        print('[INFO] missing keys: \n', missing)
    if len(unexpected) > 0 and verbose:
        print('[INFO] unexpected keys: \n', unexpected)

    # manually load ema and delete it to save GPU memory
    if model.use_ema:
        if verbose:
            print('[INFO] loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    return model
class Zero123(nn.Module):
    """Zero-1-to-3 novel-view guidance module.

    Wraps the zero123 latent-diffusion checkpoint and exposes SDS-style
    `train_step` guidance plus a sampling `__call__` for verification.
    """

    def __init__(self, device, fp16,
                 config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
                 ckpt='./pretrained/zero123/105000.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.vram_O = vram_O
        self.t_range = t_range
        self.opt = opt

        self.config = OmegaConf.load(config)
        self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)

        # timesteps: use diffuser for convenience... hope it's alright.
        self.num_train_timesteps = self.config.model.params.timesteps

        # DDIM scheduler mirroring the checkpoint's beta schedule.
        self.scheduler = DDIMScheduler(
            self.num_train_timesteps,
            self.config.model.params.linear_start,
            self.config.model.params.linear_end,
            beta_schedule='scaled_linear',
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        # SDS samples timesteps from [min_step, max_step].
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

    @torch.no_grad()
    def get_img_embeds(self, x):
        """Return per-image (CLIP conditioning, first-stage latent mode) lists."""
        # x: image tensor [B, 3, 256, 256] in [0, 1]
        x = x * 2 - 1
        c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x]  #.tile(n_samples, 1, 1)
        v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
        return c, v

    def angle_between(self, sph_v1, sph_v2):
        """Pairwise angles (radians) between two sets of spherical (r, theta, phi) vectors.

        Returns a [len(sph_v1), len(sph_v2)] tensor.
        """
        def sph2cart(sv):
            # Spherical (r, theta, phi) -> Cartesian (x, y, z).
            r, theta, phi = sv[0], sv[1], sv[2]
            return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])

        def unit_vector(v):
            return v / torch.linalg.norm(v)

        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            # Clip guards arccos against tiny numerical overshoot of |cos| > 1.
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))

        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i][j] = angle_between_2_sph(sv1, sv2)
        return angles

    def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path: Path = None):
        """One SDS guidance step: inject the zero123 score as a gradient on the rendered view.

        Args:
            embeddings: dict with 'c_crossattn', 'c_concat', 'ref_polars',
                'ref_azimuths', 'ref_radii', 'zero123_ws' (one entry per reference view).
            pred_rgb: rendered view; tensor [1, 3, H, W] in [0, 1].
            polar, azimuth, radius: deltas of the novel view w.r.t. the default pose.
            as_latent: when True, treat the (resized) pred_rgb itself as the latent.

        Returns:
            Detached scalar `grad.abs().mean()` for logging (the real gradient is
            injected via `latents.backward`).
        """
        # pred_rgb: tensor [1, 3, H, W] in [0, 1]

        # adjust SDS scale based on how far the novel view is from the known view
        ref_radii = embeddings['ref_radii']
        ref_polars = embeddings['ref_polars']
        ref_azimuths = embeddings['ref_azimuths']
        v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1)  # polar,azimuth,radius are all actually delta wrt default
        v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
        angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
        if self.opt.zero123_grad_scale == 'angle':
            grad_scale = (angles.min(dim=1)[0] / (180 / len(ref_azimuths))) * grad_scale  # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
        elif self.opt.zero123_grad_scale == 'None':
            grad_scale = 1.0  # claforte: I think this might converge faster...?
        else:
            assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'

        if as_latent:
            # Use the downsampled render directly as a latent in [-1, 1].
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256)

        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # Set weights acc to closeness in angle
        if len(ref_azimuths) > 1:
            inv_angles = 1 / angles
            inv_angles[inv_angles > 100] = 100
            inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
            inv_angles[inv_angles < 0.1] = 0
        else:
            inv_angles = torch.tensor([1.]).to(self.device)

        # Multiply closeness-weight by user-given weights
        zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
        zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
        zero123_ws[zero123_ws < 0.1] = 0

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # Duplicate for classifier-free guidance (unconditional + conditional).
            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            noise_preds = []
            # Loop through each ref image
            for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
                                                                                              embeddings['c_crossattn'], embeddings['c_concat'],
                                                                                              ref_polars, ref_azimuths, ref_radii):
                # polar,azimuth,radius are all actually delta wrt default
                p = polar + ref_polars[0] - ref_polar
                a = azimuth + ref_azimuths[0] - ref_azimuth
                a[a > 180] -= 360  # range in [-180, 180]
                r = radius + ref_radii[0] - ref_radius
                # T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r])
                # T = T[None, None, :].to(self.device)
                T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
                cond = {}
                clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
                cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
                cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
                noise_pred = self.model.apply_model(x_in, t_in, cond)
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                noise_preds.append(zero123_w[:, None, None, None] * noise_pred)

        # Weighted average of the per-reference predictions.
        noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]

        # SDS weighting w(t) = 1 - alpha_bar(t).
        w = (1 - self.alphas[t])
        grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        # import kiui
        # if not as_latent:
        #     kiui.vis.plot_image(pred_rgb_256)
        # kiui.vis.plot_matrix(latents)
        # kiui.vis.plot_matrix(grad)

        # import kiui
        # latents = torch.randn((1, 4, 32, 32), device=self.device)
        # kiui.lo(latents)
        # self.scheduler.set_timesteps(30)
        # with torch.no_grad():
        #     for i, t in enumerate(self.scheduler.timesteps):
        #         x_in = torch.cat([latents] * 2)
        #         t_in = torch.cat([t.view(1)] * 2).to(self.device)

        #         noise_pred = self.model.apply_model(x_in, t_in, cond)
        #         noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        #         noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - noise_pred_uncond)

        #         latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
        # imgs = self.decode_latents(latents)
        # print(polar, azimuth, radius)
        # kiui.vis.plot_image(pred_rgb_256, imgs)

        if save_guidance_path:
            with torch.no_grad():
                if as_latent:
                    pred_rgb_256 = self.decode_latents(latents)  # claforte: test!

                # visualize predicted denoised image
                result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))

                # visualize noisier image
                result_noisier_image = self.decode_latents(latents_noisy)

                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image], dim=-1)
                save_image(viz_images, save_guidance_path)

        # since we omitted an item in grad, we need to use the custom function to specify the gradient
        # loss = SpecifyGradient.apply(latents, grad)
        latents.backward(gradient=grad, retain_graph=True)
        loss = grad.abs().mean().detach()
        return loss

    # verification
    @torch.no_grad()
    def __call__(self,
                 image,  # image tensor [1, 3, H, W] in [0, 1]
                 polar=0, azimuth=0, radius=0,  # new view params
                 scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256,  # diffusion params
                 c_crossattn=None, c_concat=None, post_process=True,
                 ):
        """Sample a novel view of `image` at the given relative pose (DDIM loop)."""
        if c_crossattn is None:
            embeddings = self.get_img_embeds(image)

        # Relative-pose conditioning vector [polar, sin(-az), cos(az), radius].
        T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
        T = T[None, None, :].to(self.device)

        cond = {}
        clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1))
        cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
        cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]

        # produce latents loop
        latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
        self.scheduler.set_timesteps(ddim_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1)] * 2).to(self.device)

            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']

        imgs = self.decode_latents(latents)
        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

        return imgs

    def decode_latents(self, latents):
        """Decode first-stage latents to RGB images in [0, 1]."""
        # zs: [B, 4, 32, 32] Latent space image
        # with self.model.ema_scope():
        imgs = self.model.decode_first_stage(latents)
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs  # [B, 3, 256, 256] RGB space image

    def encode_imgs(self, imgs):
        """Encode RGB images in [0, 1] to first-stage latents."""
        # imgs: [B, 3, 256, 256] RGB space image
        # with self.model.ema_scope():
        imgs = imgs * 2 - 1
        latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
        return latents  # [B, 4, 32, 32] Latent space image
||||
if __name__ == '__main__':
    # CLI demo: load one image, synthesize a novel view at the given
    # (polar, azimuth, radius) deltas, and display it.
    import cv2
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str)
    parser.add_argument('--fp16', action='store_true', help="use float16 for training")  # no use now, can only run in fp32

    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')

    opt = parser.parse_args()

    device = torch.device('cuda')

    print(f'[INFO] loading image from {opt.input} ...')
    # Load, convert BGR->RGB, resize to the model's 256x256 input, scale to [0, 1].
    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype(np.float32) / 255.0
    # HWC -> 1CHW tensor on the GPU.
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

    print(f'[INFO] loading model ...')
    zero123 = Zero123(device, opt.fp16, opt=opt)

    print(f'[INFO] running model ...')
    outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
    plt.imshow(outputs[0])
    plt.show()
||||
Reference in New Issue
Block a user