first commit
This commit is contained in:
96
ldm/guidance.py
Executable file
96
ldm/guidance.py
Executable file
@@ -0,0 +1,96 @@
|
||||
from typing import List, Tuple
|
||||
from scipy import interpolate
|
||||
import numpy as np
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from IPython.display import clear_output
|
||||
import abc
|
||||
|
||||
|
||||
class GuideModel(torch.nn.Module, abc.ABC):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@abc.abstractmethod
|
||||
def preprocess(self, x_img):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_loss(self, inp):
|
||||
pass
|
||||
|
||||
|
||||
class Guider(torch.nn.Module):
|
||||
def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
|
||||
"""Apply classifier guidance
|
||||
|
||||
Specify a guidance scale as either a scalar
|
||||
Or a schedule as a list of tuples t = 0->1 and scale, e.g.
|
||||
[(0, 10), (0.5, 20), (1, 50)]
|
||||
"""
|
||||
super().__init__()
|
||||
self.sampler = sampler
|
||||
self.index = 0
|
||||
self.show = verbose
|
||||
self.guide_model = guide_model
|
||||
self.history = []
|
||||
|
||||
if isinstance(scale, (Tuple, List)):
|
||||
times = np.array([x[0] for x in scale])
|
||||
values = np.array([x[1] for x in scale])
|
||||
self.scale_schedule = {"times": times, "values": values}
|
||||
else:
|
||||
self.scale_schedule = float(scale)
|
||||
|
||||
self.ddim_timesteps = sampler.ddim_timesteps
|
||||
self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
|
||||
|
||||
|
||||
def get_scales(self):
|
||||
if isinstance(self.scale_schedule, float):
|
||||
return len(self.ddim_timesteps)*[self.scale_schedule]
|
||||
|
||||
interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
|
||||
fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
|
||||
return interpolater(fractional_steps)
|
||||
|
||||
def modify_score(self, model, e_t, x, t, c):
|
||||
|
||||
# TODO look up index by t
|
||||
scale = self.get_scales()[self.index]
|
||||
|
||||
if (scale == 0):
|
||||
return e_t
|
||||
|
||||
sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
|
||||
x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
|
||||
|
||||
inp = self.guide_model.preprocess(x_img)
|
||||
loss = self.guide_model.compute_loss(inp)
|
||||
grads = torch.autograd.grad(loss.sum(), x_in)[0]
|
||||
correction = grads * scale
|
||||
|
||||
if self.show:
|
||||
clear_output(wait=True)
|
||||
print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
|
||||
self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
|
||||
plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
plt.imshow(correction[0][0].detach().cpu())
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
|
||||
e_t_mod = e_t - sqrt_1ma*correction
|
||||
if self.show:
|
||||
fig, axs = plt.subplots(1, 3)
|
||||
axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
|
||||
axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
|
||||
axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
|
||||
plt.show()
|
||||
self.index += 1
|
||||
return e_t_mod
|
||||
Reference in New Issue
Block a user