first commit

textual-inversion/null_inversion.py (new file, 340 lines)
@@ -0,0 +1,340 @@
# %% [markdown]
# ## Copyright 2022 Google LLC. Double-click for license information.

# %%
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %% [markdown]
# # Null-text inversion + Editing with Prompt-to-Prompt

# %%
from typing import Optional, Union, Tuple, List, Callable, Dict
# from tqdm.notebook import tqdm
from tqdm import tqdm
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_utils
import seq_aligner
import shutil
from torch.optim.adam import Adam
from PIL import Image

# %% [markdown]
# For loading Stable Diffusion with Diffusers, follow the instructions at https://huggingface.co/blog/stable_diffusion and update MY_TOKEN with your token.

# %%
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
MY_TOKEN = ''
LOW_RESOURCE = False
NUM_DDIM_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# The DDIM scheduler must be attached to the pipeline: the inversion code
# below relies on its deterministic update.
ldm_stable = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler).to(device)
if is_xformers_available():
    ldm_stable.enable_xformers_memory_efficient_attention()

tokenizer = ldm_stable.tokenizer

def load_512(image_path, left=0, right=0, top=0, bottom=0):
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image
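
# Example (path and offsets are illustrative): crop 100 px from the left,
# then center square-crop and resize:
#   img = load_512("photo.png", left=100)  # -> (512, 512, 3) uint8 array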


class NullInversion:

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample
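
    # Both helpers apply the deterministic DDIM update. Writing a_t for the
    # cumulative alpha product at step t and eps for the predicted noise:
    #   x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    #   x_t'   = sqrt(a_t') * x0_hat + sqrt(1 - a_t') * eps
    # prev_step moves toward t' < t (sampling); next_step moves toward t' > t,
    # which is what turns a real image back into its initial noise.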

    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else GUIDANCE_SCALE
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        with torch.no_grad():
            if isinstance(image, Image.Image):
                image = np.array(image)
            if type(image) is torch.Tensor and image.dim() == 4:
                latents = image
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1
                image = image.permute(2, 0, 1).unsqueeze(0).to(device)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(NUM_DDIM_STEPS):
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent
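
    # Note: scheduler.timesteps is ordered from high t to low t for sampling,
    # so indexing from the end walks the chain forward in time; ddim_loop thus
    # records NUM_DDIM_STEPS + 1 latents, from z_0 up to the fully noised z_T.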

    @property
    def scheduler(self):
        return self.model.scheduler

    @torch.no_grad()
    def ddim_inversion(self, image):
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
        for i in range(NUM_DDIM_STEPS):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            for j in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            for j in range(j + 1, num_inner_steps):
                bar.update()
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()
        return uncond_embeddings_list
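
    # Per diffusion step, only the unconditional ("") embedding is optimized,
    # so that the classifier-free-guided DDIM step lands back on the inversion
    # trajectory recorded by ddim_loop, i.e. it minimizes
    #   || z*_{t-1} - prev_step(eps_cfg(z_t, t), t, z_t) ||^2 .
    # The early-stop threshold epsilon + i * 2e-5 loosens at later steps,
    # where exact reconstruction matters less.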

    def invert(self, image_path: str, prompt: str, offsets=(0, 0, 0, 0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        self.init_prompt(prompt)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)

        # Null-text optimization is disabled in this version: uncond_embeddings
        # stays None, and the inference code falls back to the plain "" embedding.
        uncond_embeddings = None
        # if verbose:
        #     print("Null-text optimization...")
        # uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings


    def __init__(self, model):
        self.model = model
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
        self.prompt = None
        self.context = None


null_inversion = NullInversion(ldm_stable)


# %% [markdown]
# ## Inference Code

# %%
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='image'
):
    batch_size = len(prompt)
    height = width = 512

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None

    latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings_ is None:
            context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
        else:
            context = torch.cat([uncond_embeddings_, text_embeddings])
        latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)

    if return_type == 'image':
        image = ptp_utils.latent2image(model.vae, latents)
    else:
        image = latents
    return image, latent
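
# When per-step uncond_embeddings from null-text optimization are supplied,
# the loop swaps in a separately optimized "" embedding at each step i;
# otherwise the single encoded "" embedding is reused, which reduces to
# standard classifier-free guidance.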


class EmptyControl:

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        return attn


def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True):
    if run_baseline:
        print("without prompt-to-prompt")
        images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
        print("with prompt-to-prompt")
    images, x_t = text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, uncond_embeddings=uncond_embeddings)
    if verbose:
        ptp_utils.view_images(images)
    return images, x_t
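
# With run_baseline=True, the function first re-runs the same prompts under an
# EmptyControl (no attention injection) as a reference, then runs the given
# controller from the same starting latent.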

# %%
image_path = "../data/dragon_statue_1/image.png"
prompt = "A high-resolution DSLR image of a grey dragon statue"
offsets = (0, 0, 0, 0)
img = load_512(image_path, *offsets)
ptp_utils.view_images(img)

(image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=offsets, verbose=True)

prompts = [prompt]
image_inv, x_t = run_and_display(prompts, EmptyControl(), run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings, verbose=False)
print("showing from left to right: the ground truth image, the vq-autoencoder reconstruction, the null-text inverted image")
ptp_utils.view_images([image_gt, image_enc, image_inv[0]])
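
# %%
# Optional: the null-text optimization pass is commented out inside invert()
# above. A sketch of running it directly, using only the methods defined in
# this file (num_inner_steps / epsilon are the defaults invert() assumes):
# null_inversion.init_prompt(prompt)
# _, ddim_latents = null_inversion.ddim_inversion(load_512(image_path, *offsets))
# uncond_embeddings = null_inversion.null_optimization(ddim_latents, num_inner_steps=10, epsilon=1e-5)
# image_inv, x_t = run_and_display(prompts, EmptyControl(), run_baseline=False,
#                                  latent=ddim_latents[-1], uncond_embeddings=uncond_embeddings, verbose=False)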