# %% [markdown]
# ## Copyright 2022 Google LLC. Double-click for license information.

# %%
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %% [markdown]
# # Null-text inversion + Editing with Prompt-to-Prompt

# %%
from typing import Optional, Union, Tuple, List, Callable, Dict
# from tqdm.notebook import tqdm
from tqdm import tqdm
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_utils
import seq_aligner
import shutil
from torch.optim.adam import Adam
from PIL import Image

# %% [markdown]
# For loading the Stable Diffusion using Diffusers, follow the instructions https://huggingface.co/blog/stable_diffusion and update MY_TOKEN with your token.

# %%
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
MY_TOKEN = ''
LOW_RESOURCE = False
NUM_DDIM_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# NOTE(review): the active line below loads the pipeline WITHOUT `scheduler=scheduler`,
# so `ldm_stable.scheduler` is whatever the checkpoint configuration provides, while
# NullInversion.prev_step / next_step hand-code their own update from `alphas_cumprod`.
# The commented-out line is the variant that injects the DDIM scheduler defined above.
# Confirm which of the two is intended.
# ldm_stable = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler).to(device)
ldm_stable = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(device)
# try:
#     ldm_stable.disable_xformers_memory_efficient_attention()
# except AttributeError:
#     print("Attribute disable_xformers_memory_efficient_attention() is missing")
if is_xformers_available():
    ldm_stable.enable_xformers_memory_efficient_attention()
tokenizer = ldm_stable.tokenizer


def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, crop it, centre-crop to a square and resize to 512x512.

    Args:
        image_path: Either a path string (opened with PIL, keeping the first
            three channels) or an already-loaded ``(H, W, C)`` array.
        left, right, top, bottom: Number of pixels to trim from each side
            before the square centre crop. Each value is clamped so that at
            least one row/column survives.

    Returns:
        The resized image as a ``numpy`` array produced by
        ``Image.resize((512, 512))``.
    """
    if isinstance(image_path, str):
        # Use a context manager so the underlying file handle is released.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)[:, :, :3]
    else:
        image = image_path
    h, w, _ = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Vertical offsets are bounded by the height only; the previous code
    # subtracted `left` here, mixing the horizontal offset into the clamp.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, _ = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image


class NullInversion:
    """Inversion of an image into the latent trajectory of a diffusion pipeline.

    The instance wraps a pipeline (``model``) and reads the module-level
    constants ``NUM_DDIM_STEPS`` and ``GUIDANCE_SCALE``.
    """

    def __init__(self, model):
        """Store the pipeline and configure its scheduler for NUM_DDIM_STEPS steps.

        Note that this mutates ``model.scheduler`` via ``set_timesteps``.
        """
        self.model = model
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
        self.prompt = None
        self.context = None

    @property
    def scheduler(self):
        """The scheduler of the wrapped pipeline."""
        return self.model.scheduler

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """Move ``sample`` one step towards less noise (``timestep`` -> previous timestep).

        Predicts the clean sample from ``model_output`` and re-noises it to
        the previous timestep's cumulative alpha. When the previous timestep
        would be negative, the scheduler's ``final_alpha_cumprod`` is used.
        """
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """Move ``sample`` one step towards more noise (previous timestep -> ``timestep``).

        Mirror image of :meth:`prev_step`: ``sample`` is treated as living at
        the timestep before ``timestep`` and is advanced to ``timestep``.
        """
        # `timestep` is rebound to the step the sample currently sits at;
        # `next_timestep` is the step it is being pushed to.
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        """Run the UNet once on ``latents`` at step ``t`` with ``context`` and return its ``"sample"``."""
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Run a guided UNet prediction and apply one step to ``latents``.

        Args:
            latents: Current latents; duplicated so both halves of ``context``
                are evaluated in a single UNet call.
            t: Timestep passed to the UNet and to the step function.
            is_forward: If true, use guidance scale 1 and :meth:`next_step`;
                otherwise use ``GUIDANCE_SCALE`` and :meth:`prev_step`.
            context: Concatenated ``[uncond, cond]`` embeddings; defaults to
                ``self.context``.

        Returns:
            The latents after the step.
        """
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else GUIDANCE_SCALE
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode latents with the pipeline's VAE.

        With ``return_type == 'np'`` the first batch element is returned as a
        ``uint8`` channel-last array in ``[0, 255]``; otherwise the raw decoder
        output is returned.
        """
        # Undo the 0.18215 scaling that image2latent applies.
        latents = 1 / 0.18215 * latents.detach()
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        """Encode an image into scaled VAE latents.

        Args:
            image: A PIL image, a channel-last array with values in
                ``[0, 255]``, or a 4-D tensor. A 4-D tensor is returned
                unchanged (it is taken to already be latents).

        Returns:
            The mean of the VAE posterior multiplied by 0.18215, or the input
            itself for a 4-D tensor.
        """
        # `Image` is the PIL module; the class to test against is Image.Image.
        if isinstance(image, Image.Image):
            image = np.array(image)
        if isinstance(image, torch.Tensor) and image.dim() == 4:
            return image
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        latents = self.model.vae.encode(image)['latent_dist'].mean
        return latents * 0.18215

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Encode ``prompt`` and the empty prompt, storing them in ``self.context``.

        ``self.context`` is the concatenation ``[uncond_embeddings,
        text_embeddings]``; ``self.prompt`` records the prompt string.
        """
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        """Push ``latent`` through NUM_DDIM_STEPS noising steps.

        Only the conditional half of ``self.context`` is used. The scheduler's
        timesteps are walked from the last entry to the first.

        Returns:
            A list with the input latent followed by the latent after each step.
        """
        _, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        timesteps = self.model.scheduler.timesteps
        for i in range(NUM_DDIM_STEPS):
            t = timesteps[len(timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @torch.no_grad()
    def ddim_inversion(self, image):
        """Encode ``image``, decode it back for reference, and run :meth:`ddim_loop`.

        Returns:
            ``(image_rec, ddim_latents)``: the VAE reconstruction and the list
            of latents produced by the loop.
        """
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        """Optimise one unconditional embedding per step against ``latents``.

        For every step the unconditional embedding is tuned with Adam so that a
        guided :meth:`prev_step` from the current latent reproduces the
        corresponding entry of ``latents``.

        Args:
            latents: Latent list as returned by :meth:`ddim_loop`.
            num_inner_steps: Maximum number of Adam updates per step.
            epsilon: Base early-stopping threshold on the MSE loss; the
                threshold grows by ``2e-5`` per outer step.

        Returns:
            A list with one detached unconditional embedding per step.
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
        for i in range(NUM_DDIM_STEPS):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            steps_done = 0
            for _ in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                steps_done += 1
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            # Account for inner steps skipped by the early stop so the bar still
            # reaches its total. Counting explicitly (instead of reusing the loop
            # variable) also works when num_inner_steps is 0.
            bar.update(num_inner_steps - steps_done)
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()
        return uncond_embeddings_list

    def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        """Invert the image at ``image_path`` under ``prompt``.

        Null-text optimisation is currently disabled (see the commented-out
        lines), so ``num_inner_steps`` and ``early_stop_epsilon`` are unused and
        the third return value is always ``None``.

        Returns:
            ``((image_gt, image_rec), x_T, uncond_embeddings)`` where
            ``image_gt`` is the loaded 512x512 image, ``image_rec`` its VAE
            reconstruction, and ``x_T`` the last latent of the inversion.
        """
        self.init_prompt(prompt)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)
        uncond_embeddings = None
        # if verbose:
        #     print("Null-text optimization...")
        # uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings


null_inversion = NullInversion(ldm_stable)

# %% [markdown]
# ## Inference Code

# %%
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='image'
):
    """Generate images for ``prompt`` with classifier-free guidance.

    Args:
        model: The diffusion pipeline (tokenizer, text encoder, VAE, scheduler).
        prompt: One prompt per generated image.
        controller: Attention controller forwarded to ``ptp_utils.diffusion_step``.
        num_inference_steps: Number of scheduler timesteps to configure.
        guidance_scale: Guidance weight forwarded to ``ptp_utils.diffusion_step``.
        generator: Optional RNG forwarded to ``ptp_utils.init_latent``.
        latent: Optional starting latent forwarded to ``ptp_utils.init_latent``.
        uncond_embeddings: Optional per-step unconditional embeddings, indexed
            by step. When ``None`` the empty prompt is encoded instead.
        start_time: Only the last ``start_time`` scheduler timesteps are run.
        return_type: ``'image'`` to decode the result with
            ``ptp_utils.latent2image``; anything else returns the latents.

    Returns:
        ``(image, latent)`` where ``latent`` is the first value returned by
        ``ptp_utils.init_latent``.
    """
    batch_size = len(prompt)
    height = width = 512

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        # Per-step embeddings were supplied; pick them inside the loop.
        uncond_embeddings_ = None

    latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings_ is None:
            context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
        else:
            context = torch.cat([uncond_embeddings_, text_embeddings])
        latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)

    if return_type == 'image':
        image = ptp_utils.latent2image(model.vae, latents)
    else:
        image = latents
    return image, latent


# def run_and_display(prompts, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True):
#     images, latent = run_and_display(prompts, latent=latent, run_baseline=False, generator=generator)
#     if verbose:
#         ptp_utils.view_images(images)
#     return images, x_t


class EmptyControl:
    """Attention controller that leaves everything untouched."""

    def step_callback(self, x_t):
        """Return the latent unchanged."""
        return x_t

    def between_steps(self):
        """Do nothing between diffusion steps."""
        return

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Return the attention map unchanged."""
        return attn


def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True):
    """Generate images for ``prompts`` with ``controller`` and optionally show them.

    When ``run_baseline`` is true, a first pass with :class:`EmptyControl` is
    run (and displayed); the latent it returns seeds the controlled pass.

    Returns:
        ``(images, x_t)`` as returned by :func:`text2image_ldm_stable`.
    """
    if run_baseline:
        print("w.o. prompt-to-prompt")
        images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
        print("with prompt-to-prompt")
    images, x_t = text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, uncond_embeddings=uncond_embeddings)
    if verbose:
        ptp_utils.view_images(images)
    return images, x_t


# %%
image_path = "../data/dragon_statue_1/image.png"
prompt = "A high-resolution DSLR image of a grey dragon statue"
offsets = (0, 0, 0, 0)
img = load_512(image_path, *offsets)
ptp_utils.view_images(img)
(image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=offsets, verbose=True)
prompts = [prompt]
image_inv, x_t = run_and_display(prompts, EmptyControl(), run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings, verbose=False)
print("showing from left to right: the ground truth image, the vq-autoencoder reconstruction, the null-text inverted image")
ptp_utils.view_images([image_gt, image_enc, image_inv[0]])