first commit
This commit is contained in:
119
textual-inversion/autoinit.py
Normal file
119
textual-inversion/autoinit.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
It takes about 2 minutes to compute and save embeddings for all noun tokens in the CLIP tokenizer vocabulary. Examples:
|
||||
|
||||
python autoinit.py save_embeddings
|
||||
python autoinit.py get_initialization /path/to/bird.jpg
|
||||
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
DEFAULT_EMB_FILE = 'clip-vit-large-patch14-text-embeddings.pth'
|
||||
|
||||
|
||||
def get_model():
    """Load the CLIP vision/text model and its matching processor.

    Returns:
        A ``(model, processor)`` pair; the model is put in eval mode.
    """
    checkpoint = "openai/clip-vit-large-patch14"
    model: CLIPModel = CLIPModel.from_pretrained(checkpoint).eval()
    processor: CLIPProcessor = CLIPProcessor.from_pretrained(checkpoint)
    return model, processor
|
||||
|
||||
|
||||
def save_embeddings(file_name: str = DEFAULT_EMB_FILE, device: str = 'cuda'):
    """Compute and save CLIP text embeddings for every noun token in the tokenizer vocab.

    For each CLIP vocabulary token that is a word-final token ('</w>' suffix)
    and an English noun according to WordNet, the prompt "an image of a <noun>"
    is embedded with the CLIP text encoder. The L2-normalized embeddings and
    the corresponding token ids are written with ``torch.save``.

    Args:
        file_name: Output path for a dict with keys ``'idx'`` (token ids) and
            ``'emb'`` (normalized embeddings, one row per noun token).
        device: Device on which to run the CLIP text encoder.
    """
    try:
        import nltk  # noqa: F401  -- ensures the corpus machinery is importable
        from nltk.corpus import wordnet as wn
    except ImportError:
        # BUGFIX: the original message told the user to `pip install fire`,
        # but the missing dependency in this branch is nltk.
        print('Please install nltk with `pip install nltk`')
        sys.exit()

    # # The first time you run this code you will have to run this
    # nltk.download('wordnet')
    # nltk.download('omw-1.4')

    # All English nouns (synset names like 'bird.n.01' -> 'bird')
    english_nouns = {x.name().split('.', 1)[0] for x in wn.all_synsets('n')}
    print(f'Found {len(english_nouns)} English nouns')

    # Get model
    model, processor = get_model()
    model.to(device)

    # Get all tokens in the CLIP tokenizer that are nouns. Only word-final
    # tokens (ending in '</w>') are considered, so sub-word pieces are skipped.
    all_noun_ids = []
    all_token_ids = sorted(processor.tokenizer.vocab.values())
    for token_id in tqdm(all_token_ids):
        token_str = processor.tokenizer.convert_ids_to_tokens(token_id)
        # Cheap suffix check first, then the set lookup.
        if token_str.endswith('</w>') and token_str.replace('</w>', '') in english_nouns:
            all_noun_ids.append(token_id)
    print(f'Found {len(all_noun_ids)} English nouns in the CLIP tokenizer')

    # Embed "an image of a <noun>" for every noun token.
    all_text_emb = []
    all_text_str = []
    for token_id in tqdm(all_noun_ids):
        text_ids = [49406, 550, 2867, 539, 320, token_id, 49407]  # "<bos> an image of a _ <eos>"
        text_str = processor.tokenizer.decode(text_ids, skip_special_tokens=True)
        inputs = processor(text=text_str, return_tensors="pt", padding=True)
        text_emb = model.get_text_features(**inputs.to(device))
        text_emb = F.normalize(text_emb, p=2, dim=-1)  # unit norm for cosine similarity
        all_text_emb.append(text_emb.detach().cpu())
        all_text_str.append(text_str)
    all_text_emb = torch.cat(all_text_emb)

    # Save
    torch.save({
        'idx': all_noun_ids,
        'emb': all_text_emb,
    }, file_name)
    print(f'Saved embeddings to {file_name}')
|
||||
|
||||
# %%
|
||||
|
||||
def get_initialization(image_file: str, text_emb_file: str = DEFAULT_EMB_FILE, device: str = 'cuda',
                       save: bool = False, save_dir: Optional[str] = None):
    """Find the CLIP noun tokens whose text embeddings best match an image.

    Loads the precomputed noun embeddings (see ``save_embeddings``), embeds the
    image with CLIP, prints the top-5 most similar noun tokens, and optionally
    writes the best token (without the '</w>' suffix) to ``token_autoinit.txt``.
    """
    # Precomputed noun-token embeddings.
    cache = torch.load(text_emb_file)
    noun_ids = cache['idx']
    noun_emb = cache['emb']

    # CLIP model and processor.
    model, processor = get_model()
    model.to(device)

    # Embed the query image; unit-normalized like the text embeddings.
    pixels = processor(images=Image.open(image_file), return_tensors="pt", padding=True)
    img_emb = model.get_image_features(**pixels.to(device))
    img_emb = F.normalize(img_emb, p=2, dim=-1)

    # Cosine similarities text-vs-image, softmaxed over the noun vocabulary.
    scores = noun_emb.to(device) @ img_emb.to(device).squeeze()  # (V, )
    scores = F.softmax(scores, dim=-1)  # (V, )
    top5 = scores.topk(k=5, largest=True, sorted=True)
    top5_ids = [noun_ids[i] for i in top5.indices.cpu()]

    # Report the winners.
    top5_tokens = processor.tokenizer.convert_ids_to_tokens(top5_ids)
    best_token = top5_tokens[0].replace('</w>', '')
    print('Top tokens:')
    print(top5_tokens)
    if save:
        out_dir = Path(image_file).parent if save_dir is None else Path(save_dir)
        out_file = out_dir / 'token_autoinit.txt'
        out_file.write_text(best_token)
|
||||
|
||||
if __name__ == "__main__":
    # Expose both commands on the command line; fire is an optional dependency.
    try:
        import fire
    except ImportError:
        print('Please install google fire with `pip install fire`')
        sys.exit()
    commands = {
        'get_initialization': get_initialization,
        'save_embeddings': save_embeddings,
    }
    fire.Fire(commands)
|
||||
340
textual-inversion/null_inversion.py
Normal file
340
textual-inversion/null_inversion.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# %% [markdown]
|
||||
# ## Copyright 2022 Google LLC. Double-click for license information.
|
||||
|
||||
# %%
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# %% [markdown]
|
||||
# # Null-text inversion + Editing with Prompt-to-Prompt
|
||||
|
||||
# %%
|
||||
from typing import Optional, Union, Tuple, List, Callable, Dict
|
||||
# from tqdm.notebook import tqdm
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
import torch.nn.functional as nnf
|
||||
import numpy as np
|
||||
import abc
|
||||
import ptp_utils
|
||||
import seq_aligner
|
||||
import shutil
|
||||
from torch.optim.adam import Adam
|
||||
from PIL import Image
|
||||
|
||||
# %% [markdown]
|
||||
# For loading the Stable Diffusion using Diffusers, follow the instructions https://huggingface.co/blog/stable_diffusion and update MY_TOKEN with your token.
|
||||
|
||||
# %%
# Global configuration for null-text inversion / prompt-to-prompt editing.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
MY_TOKEN = ''  # HuggingFace access token (unused for public checkpoints)
LOW_RESOURCE = False  # if True, run cond/uncond UNet passes separately to save memory
NUM_DDIM_STEPS = 50  # number of DDIM inversion / sampling steps
GUIDANCE_SCALE = 7.5  # classifier-free guidance scale used when denoising
MAX_NUM_WORDS = 77  # CLIP tokenizer context length
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# ldm_stable = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler).to(device)
# NOTE(review): the pipeline below is loaded with its default scheduler; the
# DDIMScheduler constructed above is NOT passed in here -- confirm intentional.
ldm_stable = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(device)
# try:
#     ldm_stable.disable_xformers_memory_efficient_attention()
# except AttributeError:
#     print("Attribute disable_xformers_memory_efficient_attention() is missing")
if is_xformers_available():
    ldm_stable.enable_xformers_memory_efficient_attention()

tokenizer = ldm_stable.tokenizer
|
||||
def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, crop the given margins, center-crop to square, resize to 512x512.

    Args:
        image_path: Path to an image file, or an already-loaded HxWxC array.
        left, right, top, bottom: Number of pixels to crop from each side.

    Returns:
        A 512x512 RGB uint8 numpy array.
    """
    if type(image_path) is str:
        image = np.array(Image.open(image_path))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    # Clamp the requested margins so the crop can never become empty.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # BUGFIX: `top` was clamped against the *left* margin (h - left - 1)
    # instead of the image height.
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    # Center-crop to a square along the longer side.
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    return image
|
||||
|
||||
|
||||
class NullInversion:
    """DDIM inversion for a Stable Diffusion pipeline, with (optional) null-text optimization.

    Typical use: ``invert(image_path, prompt)`` encodes the image to a VAE
    latent, runs the deterministic DDIM noising loop conditioned on ``prompt``,
    and returns the final noise latent plus (when the commented-out
    optimization is enabled) per-step unconditional embeddings.
    """

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """One reverse (denoising) DDIM step: map x_t at `timestep` to x_{t-1}."""
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        # Predict x_0 from the noise estimate, then re-noise it to t-1
        # (deterministic DDIM update, eta = 0).
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """One forward (inversion) DDIM step: map x_t at `timestep` to x_{t+1}."""
        # Shift the pair of timesteps one step forward; 999 caps the index at
        # the last training timestep.
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample

    def get_noise_pred_single(self, latents, t, context):
        """UNet noise prediction for a single (uncond OR cond) context."""
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Classifier-free-guided noise prediction followed by one DDIM step.

        With `is_forward=True` this inverts (guidance scale 1); otherwise it
        denoises with GUIDANCE_SCALE. `context` is the stacked
        [uncond; cond] embeddings, defaulting to the pair from `init_prompt`.
        """
        latents_input = torch.cat([latents] * 2)  # one copy per context half
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else GUIDANCE_SCALE
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode VAE latents; returns a uint8 HxWxC numpy image when return_type == 'np'."""
        latents = 1 / 0.18215 * latents.detach()  # undo SD latent scaling
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        """Encode a PIL image / HxWxC uint8 array (or pass through a 4D latent tensor)."""
        with torch.no_grad():
            if type(image) is Image:
                image = np.array(image)
            if type(image) is torch.Tensor and image.dim() == 4:
                # Already a latent batch; pass through unchanged.
                latents = image
            else:
                image = torch.from_numpy(image).float() / 127.5 - 1  # uint8 -> [-1, 1]
                image = image.permute(2, 0, 1).unsqueeze(0).to(device)
                latents = self.model.vae.encode(image)['latent_dist'].mean
                latents = latents * 0.18215  # SD latent scaling factor
        return latents

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Precompute and cache the stacked [uncond; cond] text context for `prompt`."""
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        """Run NUM_DDIM_STEPS of conditional DDIM inversion.

        Returns the whole trajectory [x_0, ..., x_T] (clean latent first).
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(NUM_DDIM_STEPS):
            # Walk the scheduler's timesteps backwards (stored high -> low).
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @property
    def scheduler(self):
        # Convenience alias for the pipeline's scheduler.
        return self.model.scheduler

    @torch.no_grad()
    def ddim_inversion(self, image):
        """Encode `image`, decode it back (round-trip reference), and invert via DDIM."""
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)  # encode/decode reconstruction
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon):
        """Optimize per-step unconditional embeddings so guided denoising retraces `latents`.

        Args:
            latents: DDIM trajectory from `ddim_loop` (x_0 ... x_T).
            num_inner_steps: Max optimizer iterations per diffusion step.
            epsilon: Early-stop threshold on the per-step reconstruction MSE.

        Returns:
            One optimized unconditional embedding (batch slice [:1]) per step.
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
        for i in range(NUM_DDIM_STEPS):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            # Learning rate decays linearly with the diffusion step index.
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            for j in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                # Looser tolerance for later (less influential) steps.
                if loss_item < epsilon + i * 2e-5:
                    break
            # Keep the progress bar consistent after an early stop.
            for j in range(j + 1, num_inner_steps):
                bar.update()
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()
        return uncond_embeddings_list

    def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        """Invert an image: returns ((ground truth, reconstruction), final latent, uncond embeddings).

        NOTE(review): the null-text optimization call is commented out below,
        so `uncond_embeddings` is currently always None.
        """
        self.init_prompt(prompt)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)

        uncond_embeddings = None
        # if verbose:
        #     print("Null-text optimization...")
        # uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings

    def __init__(self, model):
        # NOTE(review): this local scheduler is built but never used; the
        # pipeline keeps whatever scheduler it was loaded with.
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
                                  set_alpha_to_one=False)
        self.model = model
        self.tokenizer = self.model.tokenizer
        self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
        self.prompt = None   # set by init_prompt
        self.context = None  # [uncond; cond] embeddings, set by init_prompt
|
||||
|
||||
# Build the inverter once, bound to the globally loaded Stable Diffusion pipeline.
null_inversion = NullInversion(ldm_stable)
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## Inference Code
|
||||
|
||||
# %%
|
||||
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='image'
):
    """Stable Diffusion sampling loop, optionally using per-step null-text embeddings.

    When `uncond_embeddings` is given (one tensor per step, e.g. from null-text
    optimization), the step-specific embedding replaces the encoded empty
    prompt. Returns (image or latents, initial latent).
    """
    batch_size = len(prompt)
    height = width = 512

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        # Fixed unconditional context: the encoded empty prompt.
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None  # signals "use the per-step embeddings" below

    latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    # Only the last `start_time` timesteps are run (allows partial denoising).
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings_ is None:
            context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
        else:
            context = torch.cat([uncond_embeddings_, text_embeddings])
        latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)

    if return_type == 'image':
        image = ptp_utils.latent2image(model.vae, latents)
    else:
        image = latents
    return image, latent
|
||||
|
||||
|
||||
|
||||
# def run_and_display(prompts, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True):
|
||||
# images, latent = run_and_display(prompts, latent=latent, run_baseline=False, generator=generator)
|
||||
# if verbose:
|
||||
# ptp_utils.view_images(images)
|
||||
# return images, x_t
|
||||
|
||||
class EmptyControl:
    """No-op attention controller: leaves latents and attention maps untouched."""

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Pass the attention map through unmodified."""
        return attn

    def step_callback(self, x_t):
        """Return the latents unchanged after each diffusion step."""
        return x_t

    def between_steps(self):
        """Nothing to accumulate between steps."""
        return
|
||||
|
||||
def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, verbose=True):
    """Sample images for `prompts` under the given attention controller and show them.

    When `run_baseline` is True, the same prompts are first rendered without
    prompt-to-prompt control (EmptyControl) for comparison; that baseline
    result is then discarded and overwritten by the controlled pass.
    """
    if run_baseline:
        print("w.o. prompt-to-prompt")
        # Recursive baseline pass with run_baseline=False to avoid looping.
        images, latent = run_and_display(prompts, EmptyControl(), latent=latent, run_baseline=False, generator=generator)
        print("with prompt-to-prompt")
    images, x_t = text2image_ldm_stable(ldm_stable, prompts, controller, latent=latent, num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE, generator=generator, uncond_embeddings=uncond_embeddings)
    if verbose:
        ptp_utils.view_images(images)
    return images, x_t
|
||||
|
||||
# %%
# Demo: invert a single image and visually check the reconstruction quality.
image_path = "../data/dragon_statue_1/image.png"
prompt = "A high-resolution DSLR image of a grey dragon statue"
offsets = (0, 0, 0, 0)  # crop margins passed to load_512 (left, right, top, bottom)
img = load_512(image_path, *offsets)
ptp_utils.view_images(img)

(image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=offsets, verbose=True)

prompts = [prompt]
# Re-generate from the inverted latent (no editing) to validate the inversion.
image_inv, x_t = run_and_display(prompts, EmptyControl(), run_baseline=False, latent=x_t, uncond_embeddings=uncond_embeddings, verbose=False)
print("showing from left to right: the ground truth image, the vq-autoencoder reconstruction, the null-text inverted image")
ptp_utils.view_images([image_gt, image_enc, image_inv[0]])
|
||||
1470
textual-inversion/null_text_w_ptp.ipynb
Normal file
1470
textual-inversion/null_text_w_ptp.ipynb
Normal file
File diff suppressed because one or more lines are too long
295
textual-inversion/ptp_utils.py
Normal file
295
textual-inversion/ptp_utils.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import cv2
|
||||
from typing import Optional, Union, Tuple, List, Callable, Dict
|
||||
from IPython.display import display
|
||||
from tqdm.notebook import tqdm
|
||||
|
||||
|
||||
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of `image` with `text` rendered centered on a white strip below it."""
    h, w, c = image.shape
    strip = int(h * .2)  # height of the white band added underneath
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    canvas = np.ones((h + strip, w, c), dtype=np.uint8) * 255
    canvas[:h] = image
    text_w, text_h = cv2.getTextSize(text, font, 1, 2)[0]
    x = (w - text_w) // 2
    y = h + strip - text_h // 2
    cv2.putText(canvas, text, (x, y), font, 1, text_color, 2)
    return canvas
|
||||
|
||||
|
||||
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Display one or more HxWxC uint8 images in a padded grid (notebook helper).

    Args:
        images: A single image array, a 4D batch of images, or a list of images.
        num_rows: Number of rows in the displayed grid.
        offset_ratio: Gap between grid cells, as a fraction of the image height.
    """
    if type(images) is list:
        num_given = len(images)
    elif images.ndim == 4:
        num_given = images.shape[0]
    else:
        images = [images]
        num_given = 1
    # BUGFIX: pad with blank images up to the next multiple of num_rows.
    # The original used `num_given % num_rows` (the remainder itself), which
    # only happens to be correct for num_rows <= 2 and silently dropped
    # images for larger grids.
    num_empty = (num_rows - num_given % num_rows) % num_rows

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    # White canvas with `offset` pixels of padding between cells.
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)
|
||||
|
||||
|
||||
def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
    """One classifier-free-guided denoising step, with controller post-processing.

    Args:
        model: Diffusion pipeline (its `unet` and `scheduler` are used here).
        controller: Attention controller; its `step_callback` may modify latents.
        latents: Current latent batch.
        context: Concatenated [uncond; cond] embeddings, or a 2-element list
            of the two halves when `low_resource` is True.
        t: Current scheduler timestep.
        guidance_scale: Classifier-free guidance weight.
        low_resource: If True, run the two UNet passes sequentially to halve
            peak memory instead of batching them.
    """
    if low_resource:
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        # Batch the unconditional and conditional passes into one UNet call.
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    latents = controller.step_callback(latents)
    return latents
|
||||
|
||||
|
||||
def latent2image(vae, latents):
    """Decode VAE latents into a batch of HxWxC uint8 numpy images."""
    decoded = vae.decode(1 / 0.18215 * latents)['sample']  # undo SD latent scaling
    pixels = (decoded / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    pixels = pixels.cpu().permute(0, 2, 3, 1).numpy()
    return (pixels * 255).astype(np.uint8)
|
||||
|
||||
|
||||
def init_latent(latent, model, height, width, generator, batch_size):
    """Create (or reuse) a starting latent and expand it over the batch dimension.

    Returns the single seed latent and its batch-expanded copy on the model's device.
    """
    shape = (1, model.unet.in_channels, height // 8, width // 8)
    if latent is None:
        # Every batch element starts from the same sampled noise.
        latent = torch.randn(shape, generator=generator)
    latents = latent.expand(batch_size, *shape[1:]).to(model.device)
    return latent, latents
|
||||
|
||||
|
||||
@torch.no_grad()
def text2image_ldm(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
):
    """Sampling loop for the 256x256 LDM pipeline with attention control.

    Uses the model's BERT text encoder and VQ-VAE decoder; hooks `controller`
    into every cross-attention layer before sampling.
    Returns (images, initial latent).
    """
    register_attention_control(model, controller)
    height = width = 256
    batch_size = len(prompt)

    uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
    uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]

    text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
    text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    context = torch.cat([uncond_embeddings, text_embeddings])

    model.scheduler.set_timesteps(num_inference_steps)
    for t in tqdm(model.scheduler.timesteps):
        latents = diffusion_step(model, controller, latents, context, t, guidance_scale)

    image = latent2image(model.vqvae, latents)

    return image, latent
|
||||
|
||||
|
||||
@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
):
    """Stable Diffusion (512x512) sampling loop with attention control.

    Hooks `controller` into the UNet cross-attention layers, then runs the
    classifier-free-guided sampling loop. Returns (images, initial latent).
    """
    register_attention_control(model, controller)
    height = width = 512
    batch_size = len(prompt)

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = model.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]

    # Keep the two halves as a list in low-resource mode so diffusion_step
    # can run them through the UNet sequentially.
    context = [uncond_embeddings, text_embeddings]
    if not low_resource:
        context = torch.cat(context)
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)

    # set timesteps
    extra_set_kwargs = {"offset": 1}
    model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
    for t in tqdm(model.scheduler.timesteps):
        latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)

    image = latent2image(model.vae, latents)

    return image, latent
|
||||
|
||||
|
||||
def register_attention_control(model, controller):
    """Patch every CrossAttention module in the UNet so `controller` sees (and may
    rewrite) attention maps.

    Also sets `controller.num_att_layers` to the number of patched layers.
    A None controller is replaced by a pass-through DummyController.
    """
    def ca_forward(self, place_in_unet):
        # Build a replacement forward for one CrossAttention module;
        # `place_in_unet` tags the layer as down/mid/up.
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            # diffusers wraps the output projection in a ModuleList.
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, context=None, mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x  # self-attention attends to x itself
            k = self.to_k(context)
            v = self.to_v(context)
            q = self.reshape_heads_to_batch_dim(q)
            k = self.reshape_heads_to_batch_dim(k)
            v = self.reshape_heads_to_batch_dim(v)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if mask is not None:
                mask = mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            # Let the controller observe / rewrite the attention map.
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)

        return forward

    class DummyController:
        # Fallback controller that passes attention through untouched.

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Recursively patch CrossAttention modules, counting how many were hit.
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count
|
||||
|
||||
|
||||
def get_word_inds(text: str, word_place: int, tokenizer):
    """Map a word of `text` (by split index or literal string) to its token positions.

    Args:
        text: The prompt.
        word_place: Either a word's index in the whitespace-split prompt, or
            the word itself (all occurrences are matched).
        tokenizer: Tokenizer whose encode/decode define the token boundaries.

    Returns:
        np.ndarray of token indices (offset by +1 to account for the BOS token)
        covering the selected word(s).
    """
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        # Decode each token individually (dropping BOS/EOS); '#' continuation
        # markers are stripped so token lengths line up with the raw words.
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)  # +1 re-inserts the BOS offset
            if cur_len >= len(split_text[ptr]):
                # Finished covering the current word; move the cursor on.
                ptr += 1
                cur_len = 0
    return np.array(out)
|
||||
|
||||
|
||||
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor]=None):
    """Fill in the per-step replacement mask for one prompt in-place.

    Steps inside [start, end) — given as fractions of the step axis — are set
    to 1 for the selected word positions of prompt `prompt_ind`; all other
    steps are set to 0. Returns the (mutated) `alpha` tensor.
    """
    if type(bounds) is float:
        # A bare float means "from the beginning up to this fraction".
        bounds = 0, bounds
    num_steps = alpha.shape[0]
    start = int(bounds[0] * num_steps)
    end = int(bounds[1] * num_steps)
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])  # default: every word position
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
|
||||
|
||||
|
||||
def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    """Build the (steps+1, prompts-1, 1, 1, words) mask that schedules cross-attention replacement.

    `cross_replace_steps` is either a global fraction / interval (used under
    the "default_" key) or a dict that additionally maps specific words to
    their own replacement intervals.
    """
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    # Apply the global schedule to every edited prompt first...
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
                                                  i)
    # ...then override the schedule for specifically listed words.
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
|
||||
196
textual-inversion/seq_aligner.py
Normal file
196
textual-inversion/seq_aligner.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ScoreParams:
    """Scoring parameters for global (Needleman-Wunsch style) alignment."""

    def __init__(self, gap, match, mismatch):
        # gap: penalty added when a gap is introduced in either sequence.
        self.gap = gap
        # match / mismatch: scores used when comparing two elements.
        self.match = match
        self.mismatch = mismatch

    def mis_match_char(self, x, y):
        """Return the match score when x equals y, else the mismatch score."""
        return self.match if x == y else self.mismatch
|
||||
|
||||
|
||||
def get_matrix(size_x, size_y, gap):
    """Build a (size_x + 1) x (size_y + 1) alignment score matrix.

    The first row and first column are initialized to multiples of the gap
    penalty; all other cells start at 0.

    NOTE: this pure-Python version is shadowed by the NumPy implementation of
    the same name defined immediately after it in this file, so it was dead
    code; kept (fixed) for reference.

    Args:
        size_x: length of the first sequence (an int).
        size_y: length of the second sequence (an int).
        gap: gap penalty.

    Returns:
        A nested list of ints.
    """
    # Bug fix: size_x / size_y are integers (callers pass len(x), len(y)),
    # so iterate over them directly — the original called len() on them,
    # which raises TypeError for int arguments.
    matrix = []
    for i in range(size_x + 1):
        matrix.append([0] * (size_y + 1))
    for j in range(1, size_y + 1):
        matrix[0][j] = j * gap
    for i in range(1, size_x + 1):
        matrix[i][0] = i * gap
    return matrix
|
||||
|
||||
|
||||
def get_matrix(size_x, size_y, gap):
    """Vectorized (size_x + 1) x (size_y + 1) alignment score matrix.

    First row and column hold cumulative gap penalties; the interior is 0.
    """
    scores = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    scores[0, 1:] = np.arange(1, size_y + 1) * gap
    scores[1:, 0] = np.arange(1, size_x + 1) * gap
    return scores
|
||||
|
||||
|
||||
def get_traceback_matrix(size_x, size_y):
    """Initialize the traceback matrix for global alignment.

    Codes: 1 = came from the left (gap in x), 2 = came from above (gap in y),
    4 marks the origin cell where traceback stops.
    """
    trace = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    trace[0, 1:] = 1  # top row can only be reached by left moves
    trace[1:, 0] = 2  # first column can only be reached by up moves
    trace[0, 0] = 4   # origin sentinel
    return trace
|
||||
|
||||
|
||||
def global_align(x, y, score):
    """Needleman-Wunsch global alignment of sequences x and y.

    Fills the dynamic-programming score matrix and a traceback matrix
    (1 = left / gap in x, 2 = up / gap in y, 3 = diagonal / match-mismatch).
    Ties are broken in the order left, up, diagonal — same as the original.

    Args:
        x, y: sequences to align (token-id lists).
        score: a ScoreParams-like object with ``gap`` and ``mis_match_char``.

    Returns:
        (score_matrix, traceback_matrix).
    """
    matrix = get_matrix(len(x), len(y), score.gap)
    trace_back = get_traceback_matrix(len(x), len(y))
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            from_left = matrix[i, j - 1] + score.gap
            from_up = matrix[i - 1, j] + score.gap
            from_diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
            best = max(from_left, from_up, from_diag)
            matrix[i, j] = best
            if best == from_left:
                trace_back[i, j] = 1
            elif best == from_up:
                trace_back[i, j] = 2
            else:
                trace_back[i, j] = 3
    return matrix, trace_back
|
||||
|
||||
|
||||
def get_aligned_sequences(x, y, trace_back):
    """Walk the traceback matrix to recover the aligned sequences.

    Produces the aligned versions of x and y ('-' marks a gap) and a mapping
    from positions in y to positions in x (-1 where y has no counterpart).
    Positions in x that align to a gap in y produce no mapper entry, matching
    the original behavior.

    Returns:
        (aligned_x, aligned_y, mapper) where mapper is an int64 tensor of
        (y_index, x_index) pairs in left-to-right order.
    """
    aligned_x, aligned_y = [], []
    mapper_y_to_x = []
    i, j = len(x), len(y)
    while i > 0 or j > 0:
        move = trace_back[i, j]
        if move == 3:
            # Diagonal: x[i-1] aligned with y[j-1].
            aligned_x.append(x[i - 1])
            aligned_y.append(y[j - 1])
            i -= 1
            j -= 1
            mapper_y_to_x.append((j, i))
        elif move == 1:
            # Left: gap in x, y[j-1] unmatched.
            aligned_x.append('-')
            aligned_y.append(y[j - 1])
            j -= 1
            mapper_y_to_x.append((j, -1))
        elif move == 2:
            # Up: gap in y; intentionally no mapper entry.
            aligned_x.append(x[i - 1])
            aligned_y.append('-')
            i -= 1
        elif move == 4:
            break
    mapper_y_to_x.reverse()
    return aligned_x, aligned_y, torch.tensor(mapper_y_to_x, dtype=torch.int64)
|
||||
|
||||
|
||||
def get_mapper(x: str, y: str, tokenizer, max_len=77):
    """Align the token sequences of prompts x and y.

    Returns a per-position mapper from y-token positions to x-token positions
    (padded to max_len) and per-position alpha weights (0 where a y token has
    no counterpart in x, 1 elsewhere).
    """
    x_seq = tokenizer.encode(x)
    y_seq = tokenizer.encode(y)
    # Gap 0, match +1, mismatch -1 scoring for the global alignment.
    score = ScoreParams(0, 1, -1)
    matrix, trace_back = global_align(x_seq, y_seq, score)
    # [-1] selects the (y_index, x_index) mapper tensor from the alignment.
    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
    alphas = torch.ones(max_len)
    # 0 where a y token maps to no x token (x_index == -1), 1 otherwise.
    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
    # NOTE(review): the tail is written starting at index mapper_base.shape[0]
    # but its values start at len(y_seq); these agree only when the alignment
    # length equals len(y_seq) (otherwise the assignment sizes mismatch and
    # this raises at runtime) — confirm intended.
    mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
    return mapper, alphas
|
||||
|
||||
|
||||
def get_refinement_mapper(prompts, tokenizer, max_len=77):
    """Compute token mappers/alphas from the source prompt to each edited prompt.

    Aligns ``prompts[0]`` against every other prompt via ``get_mapper`` and
    stacks the results.

    Returns:
        (mappers, alphas): tensors stacked along a new leading dimension, one
        row per edited prompt.
    """
    base_prompt = prompts[0]
    mappers, alphas = [], []
    for other in prompts[1:]:
        mapper, alpha = get_mapper(base_prompt, other, tokenizer, max_len)
        mappers.append(mapper)
        alphas.append(alpha)
    return torch.stack(mappers), torch.stack(alphas)
|
||||
|
||||
|
||||
def get_word_inds(text: str, word_place: int, tokenizer):
|
||||
split_text = text.split(" ")
|
||||
if type(word_place) is str:
|
||||
word_place = [i for i, word in enumerate(split_text) if word_place == word]
|
||||
elif type(word_place) is int:
|
||||
word_place = [word_place]
|
||||
out = []
|
||||
if len(word_place) > 0:
|
||||
words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
|
||||
cur_len, ptr = 0, 0
|
||||
|
||||
for i in range(len(words_encode)):
|
||||
cur_len += len(words_encode[i])
|
||||
if ptr in word_place:
|
||||
out.append(i + 1)
|
||||
if cur_len >= len(split_text[ptr]):
|
||||
ptr += 1
|
||||
cur_len = 0
|
||||
return np.array(out)
|
||||
|
||||
|
||||
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    """Build a max_len x max_len soft permutation matrix mapping the tokens of
    prompt x to the tokens of prompt y for word-replacement edits.

    Requires x and y to have the same number of whitespace words; words that
    differ get their token spans cross-mapped (with weight split evenly when
    the spans have different token lengths), all other positions map
    identity-style.
    """
    words_x = x.split(' ')
    words_y = y.split(' ')
    if len(words_x) != len(words_y):
        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
    # Word positions where the two prompts differ, and their token spans.
    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
    mapper = np.zeros((max_len, max_len))
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
            # We are at the start of a replaced word's token span.
            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
            if len(inds_source_) == len(inds_target_):
                mapper[inds_source_, inds_target_] = 1
            else:
                # Unequal span lengths: spread each source token's weight
                # evenly across the target tokens.
                ratio = 1 / len(inds_target_)
                for i_t in inds_target_:
                    mapper[inds_source_, i_t] = ratio
            cur_inds += 1
            i += len(inds_source_)
            j += len(inds_target_)
        elif cur_inds < len(inds_source):
            # Before the last replacement span: identity map at (i, j).
            mapper[i, j] = 1
            i += 1
            j += 1
        else:
            # NOTE(review): after all replacement spans this writes (j, j),
            # not (i, j); i and j only differ when spans had unequal lengths —
            # confirm this asymmetry is intended.
            mapper[j, j] = 1
            i += 1
            j += 1

    return torch.from_numpy(mapper).float()
|
||||
|
||||
|
||||
|
||||
def get_replacement_mapper(prompts, tokenizer, max_len=77):
    """Stack replacement mappers from the source prompt to each edited prompt.

    Returns a tensor of shape (len(prompts) - 1, max_len, max_len).
    """
    base_prompt = prompts[0]
    mappers = [
        get_replacement_mapper_(base_prompt, other, tokenizer, max_len)
        for other in prompts[1:]
    ]
    return torch.stack(mappers)
|
||||
|
||||
927
textual-inversion/textual_inversion.py
Normal file
927
textual-inversion/textual_inversion.py
Normal file
@@ -0,0 +1,927 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Adapted with almost no modifications from
|
||||
# https://github.com/huggingface/diffusers/tree/3d2648d743e4257c550bba03242486b1f3834838/examples/textual_inversion
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# wandb is an optional logging integration; import it only when installed.
if is_wandb_available():
    import wandb

# Pillow 9.1.0 moved the resampling constants to PIL.Image.Resampling;
# select the names appropriate for the installed Pillow version.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

# Module-level logger from accelerate (multi-process aware).
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
    """Generate validation images with the current text encoder and log them.

    Builds a DiffusionPipeline around the (unwrapped) trained text encoder,
    samples ``args.num_validation_images`` images from
    ``args.validation_prompt``, and forwards them to each active tracker
    (tensorboard and/or wandb). The pipeline is deleted afterwards and CUDA
    cache is emptied.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    # create pipeline (note: unet and vae are loaded again in float32)
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder),
        tokenizer=tokenizer,
        unet=unet,
        vae=vae,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    # Swap in a faster scheduler for validation sampling.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    # Fixed seed (if provided) makes validation images comparable across epochs.
    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        # NOTE(review): autocast device is hardcoded to "cuda" here regardless
        # of accelerator.device — confirm CPU-only runs are not expected.
        with torch.autocast("cuda"):
            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
        images.append(image)

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    "validation": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    # Free the pipeline's GPU memory before training resumes.
    del pipeline
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    """Persist the learned placeholder-token embedding to ``save_path``.

    Saves a dict mapping ``args.placeholder_token`` to the (CPU, detached)
    embedding row of the new token, via ``torch.save``.
    """
    logger.info("Saving embeddings")
    embedding_layer = accelerator.unwrap_model(text_encoder).get_input_embeddings()
    learned_embeds = embedding_layer.weight[placeholder_token_id]
    torch.save({args.placeholder_token: learned_embeds.detach().cpu()}, save_path)
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line arguments for textual-inversion training.

    Returns the parsed argparse Namespace; also reconciles ``--local_rank``
    with the ``LOCAL_RANK`` environment variable and validates that a training
    data directory was supplied.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # --- checkpointing / output of the learned embedding ---
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    # --- model / tokenizer sources ---
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    # --- textual-inversion concept definition ---
    parser.add_argument(
        "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    # --- image preprocessing ---
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
    )
    # --- training schedule / optimization ---
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    # --- AdamW hyperparameters ---
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    # --- Hugging Face Hub publishing ---
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    # --- logging / precision / hardware ---
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    # --- validation-image generation ---
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=None,
        help=(
            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    # --- distributed training / checkpoint resume ---
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    # New: heavy image augmentations
    parser.add_argument(
        "--use_augmentations", action="store_true", help="Whether or not to use heavy image augmentations."
    )

    args = parser.parse_args()
    # LOCAL_RANK env var (set by torch.distributed launchers) wins over the flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify a train data directory.")

    return args
|
||||
|
||||
|
||||
# Caption templates for "object" concepts; "{}" is replaced with the
# placeholder token when building training prompts.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Caption templates for "style" concepts (used when --learnable_property=style).
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
|
||||
|
||||
|
||||
class TextualInversionDataset(Dataset):
    """Dataset yielding (tokenized prompt, pixel tensor) pairs for textual inversion.

    Each item pairs a randomly chosen caption template (with the placeholder
    token substituted in) with a preprocessed version of the training image.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",  # NOTE: shadows the builtin `set`; kept for interface compatibility
        placeholder_token="*",
        center_crop=False,
        use_augmentations=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        #self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
        # data_root is treated as a single image file, not a directory.
        self.image_paths = [self.data_root]
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        # Training epochs see each image `repeats` times.
        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL_INTERPOLATION["linear"],
            "bilinear": PIL_INTERPOLATION["bilinear"],
            "bicubic": PIL_INTERPOLATION["bicubic"],
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

        self.use_augmentations = use_augmentations
        if self.use_augmentations:
            # This is unnecessarily convoluted because I previously used the albumentations library, but
            # I wanted to remove that dependency. In torchvision, there is no good way of randomly rotating
            # and then cropping into the rotation by the correct amount such that there is no padding. But
            # this is a hack that works ok for that case.
            self.aug_transform = transforms.Compose([
                transforms.Resize(int(self.size * 5/4)),
                transforms.CenterCrop(int(self.size * 5/4)),
                transforms.RandomApply([
                    transforms.RandomRotation(degrees=10, fill=255),
                    transforms.CenterCrop(int(self.size * 5/6)),
                    transforms.Resize(self.size),
                ], p=0.75),
                transforms.RandomResizedCrop(self.size, scale=(0.85, 1.15)),
                transforms.RandomApply([transforms.ColorJitter(0.04, 0.04, 0.04, 0.04)], p=0.75),
                transforms.RandomGrayscale(p=0.10),
                transforms.RandomApply([transforms.GaussianBlur(5, (0.1, 2))], p=0.10),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])

    def __len__(self):
        # Length already accounts for `repeats` in train mode.
        return self._length

    def __getitem__(self, i):
        example = {}
        # Wrap around so any index maps onto the available image(s).
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            (
                h,
                w,
            ) = (
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        if self.use_augmentations:
            # aug_transform ends in ToTensor, so this branch yields a float
            # tensor in [0, 1].
            # NOTE(review): the non-augmented branch below rescales to
            # [-1, 1]; the two branches therefore produce different value
            # ranges — confirm this is intended.
            image = self.aug_transform(image)
        else:
            image = image.resize((self.size, self.size), resample=self.interpolation)
            image = self.flip_transform(image)
            image = np.array(image).astype(np.uint8)
            # Rescale uint8 [0, 255] to float32 [-1, 1].
            image = (image / 127.5 - 1.0).astype(np.float32)
            # HWC -> CHW for the model.
            image = torch.from_numpy(image).permute(2, 0, 1)

        example["pixel_values"] = image
        return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the full ``<owner>/<model_id>`` Hub repository name.

    When no token is given it is read from the local Hugging Face folder
    (as in the original, this happens even if ``organization`` is set).
    The owner is ``organization`` when provided, otherwise the username
    associated with the token.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is not None:
        return f"{organization}/{model_id}"
    username = whoami(token)["name"]
    return f"{username}/{model_id}"
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_dir=logging_dir,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(args)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load tokenizer
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
    # Freeze vae and unet: only the text-encoder token embeddings are trained.
    vae.requires_grad_(False)
    unet.requires_grad_(False)
    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

    if args.gradient_checkpointing:
        # Keep unet in train mode if we are using gradient checkpointing to save memory.
        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
        unet.train()
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                # NOTE(review): logger.warn is a deprecated alias of logger.warning.
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
    # Scale the base learning rate by the effective global batch size.
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = TextualInversionDataset(
        data_root=args.train_data_dir,
        tokenizer=tokenizer,
        size=args.resolution,
        placeholder_token=args.placeholder_token,
        repeats=args.repeats,
        learnable_property=args.learnable_property,
        center_crop=args.center_crop,
        set="train",
        use_augmentations=args.use_augmentations,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )
    # Back-compat shim: translate the deprecated --validation_epochs flag into
    # an equivalent --validation_steps value before it is read below.
    if args.validation_epochs is not None:
        warnings.warn(
            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
            " Deprecated validation_epochs in favor of `validation_steps`"
            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
            FutureWarning,
            stacklevel=2,
        )
        args.validation_steps = args.validation_epochs * len(train_dataset)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # Warmup/total steps are in micro-steps, hence the accumulation multiplier.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )
    # Prepare everything with our `accelerator`. Only the text encoder is
    # wrapped for training; vae/unet stay frozen and are moved manually below.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    # For mixed precision training we cast the unet and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed
    # (accelerator.prepare shards the dataloader across processes).
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint: directories are named
            # "checkpoint-<global_step>", so sort numerically by the step suffix.
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            # Translate optimizer steps back to dataloader micro-steps so the
            # training loop can skip already-seen batches of the resume epoch.
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # keep original embeddings as reference, so every row except the new
    # placeholder token can be restored after each optimizer step
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Let's make sure we don't update any embedding weights besides the newly added token:
                # restore every other row from the snapshot taken before training.
                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if args.push_to_hub and args.only_save_embeds:
            # Pushing to the hub requires the full pipeline, so the embeds-only
            # flag is overridden here.
            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
            save_full_model = True
        else:
            save_full_model = not args.only_save_embeds
        if save_full_model:
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=accelerator.unwrap_model(text_encoder),
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
            )
            pipeline.save_pretrained(args.output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(args.output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

        if args.push_to_hub:
            # `repo` is assumed to be set up earlier in main() — TODO confirm.
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    accelerator.end_training()
# Script entry point.
if __name__ == "__main__":
    main()