first commit
This commit is contained in:
29
.gitignore
vendored
Normal file
29
.gitignore
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
__pycache__/
|
||||
venv*
|
||||
workspace/
|
||||
out*
|
||||
saved_out*
|
||||
all_outputs*
|
||||
shap_e_model_cache/*
|
||||
slurm_logs/
|
||||
debug/
|
||||
notinclude/
|
||||
scripts/snap/yamls
|
||||
|
||||
# */validation
|
||||
*csv
|
||||
build/
|
||||
*.egg-info/
|
||||
*.so
|
||||
.vscode/
|
||||
|
||||
tmp*
|
||||
data/
|
||||
trial*/
|
||||
.vs/
|
||||
|
||||
TOKEN
|
||||
*.ckpt
|
||||
*.pt
|
||||
densegridencoder
|
||||
tets/256_tets.npz
|
||||
214
LICENSE
214
LICENSE
@@ -1,21 +1,201 @@
|
||||
MIT License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
Copyright (c) 2023 Gordon Guocheng Qian 钱国成
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
1. Definitions.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
21
activation.py
Normal file
21
activation.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
class _trunc_exp(Function):
    """``exp`` with a gradient clamped to stay finite under mixed precision.

    The forward pass is a plain ``torch.exp``; the backward pass clamps the
    saved input to 15 before exponentiating so the gradient cannot overflow
    (e.g. to fp16 inf) for large activations.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float)
    def forward(ctx, x):
        # Keep the float32-cast input for the backward pass.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        saved_x, = ctx.saved_tensors
        # d/dx exp(x) = exp(x); clamp the exponent so the gradient stays finite.
        return grad_out * torch.exp(saved_x.clamp(max=15))


# Public functional alias: trunc_exp(x) applies the autograd Function above.
trunc_exp = _trunc_exp.apply
|
||||
|
||||
def biased_softplus(x, bias=0):
    """Softplus shifted right by ``bias``, i.e. ``softplus(x - bias)``."""
    shifted = x - bias
    return torch.nn.functional.softplus(shifted)
|
||||
459
all_metrics/metric_utils.py
Executable file
459
all_metrics/metric_utils.py
Executable file
@@ -0,0 +1,459 @@
|
||||
# * evaluate use laion/CLIP-ViT-H-14-laion2B-s32B-b79K
|
||||
# best open source clip so far: laion/CLIP-ViT-bigG-14-laion2B-39B-b160k
|
||||
# code adapted from NeuralLift-360
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
import matplotlib.pyplot as plt
|
||||
# import clip
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
|
||||
from torchvision import transforms
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
from PIL import Image
|
||||
# import torchvision.transforms as transforms
|
||||
import glob
|
||||
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
|
||||
import lpips
|
||||
from os.path import join as osp
|
||||
import argparse
|
||||
import pandas as pd
|
||||
|
||||
class CLIP(nn.Module):
    """CLIP-based scorer for 3D evaluation.

    Embeds a reference image and rendered novel views with a CLIP image
    tower and reports their mean cosine similarity.

    Args:
        device: integer GPU index; the model is placed on ``f"cuda:{device}"``.
        clip_name: Hugging Face hub id of the CLIP checkpoint to load.
        size: square side length images are resized to by ``read_img_list``.
    """

    def __init__(self,
                 device,
                 clip_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
                 size=224): #'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'):
        super().__init__()
        self.size = size
        self.device = f"cuda:{device}"

        self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
            clip_name)
        self.clip_model = CLIPModel.from_pretrained(clip_name).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')

        # Normalization with the checkpoint's own image statistics; kept for
        # completeness — the scoring path below uses `self.aug` instead.
        self.normalize = transforms.Normalize(
            mean=self.feature_extractor.image_mean,
            std=self.feature_extractor.image_std)

        self.resize = transforms.Resize(224)
        self.to_tensor = transforms.ToTensor()

        # image augmentation: resize to CLIP input size and normalize with
        # the standard OpenAI CLIP mean/std.
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ])

    # * recommend to use this function for evaluation
    @torch.no_grad()
    def score_gt(self, ref_img_path, novel_views):
        """Mean CLIP similarity between the reference image and each novel view.

        Args:
            ref_img_path: list containing the reference image path.
            novel_views: list of novel-view image paths.
        """
        clip_scores = []
        for novel in novel_views:
            clip_scores.append(self.score_from_path(ref_img_path, [novel]))
        return np.mean(clip_scores)

    def similarity(self, image1_features: torch.Tensor,
                   image2_features: torch.Tensor) -> float:
        """Dot product of two [1, D] feature tensors (cosine similarity when
        both are unit-normalized, as produced by ``get_img_embeds``)."""
        with torch.no_grad(), torch.cuda.amp.autocast():
            # Reshape the transposed [D, 1] feature back to a [1, D] row,
            # then take the dot product with the other feature.
            y = image1_features.T.view(image1_features.T.shape[1],
                                       image1_features.T.shape[0])
            similarity = torch.matmul(y, image2_features.T)
            return similarity[0][0].item()

    def get_img_embeds(self, img):
        """Embed a [C, H, W] image tensor with the CLIP image tower.

        Drops the alpha channel of RGBA inputs and returns L2-normalized
        features of shape [1, D].
        """
        if img.shape[0] == 4:
            img = img[:3, :, :]

        img = self.aug(img).to(self.device)
        img = img.unsqueeze(0)  # b,c,h,w

        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1,
                                         keepdim=True)  # normalize features
        return image_z

    def score_from_feature(self, img1, img2):
        """Similarity of two image tensors after CLIP embedding."""
        img1_feature, img2_feature = self.get_img_embeds(
            img1), self.get_img_embeds(img2)
        return self.similarity(img1_feature, img2_feature)

    def read_img_list(self, img_list):
        """Load image paths as a [N, size, size, 3] uint8 RGB array.

        Fully transparent pixels of RGBA images are set to white.
        """
        size = self.size
        images = []

        for img_path in img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img.shape[2] == 4:  # Handle BGRA images
                alpha = img[:, :, 3]  # Extract alpha channel
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)  # Convert BGRA to RGB
                img[np.where(alpha == 0)] = [
                    255, 255, 255
                ]  # Set transparent pixels to white
            else:  # Handle other image formats like JPG and PNG
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
            images.append(img)

        images = np.stack(images, axis=0)
        return images

    def score_from_path(self, img1_path, img2_path):
        """CLIP similarity between two images given as single-element path lists."""
        img1, img2 = self.read_img_list(img1_path), self.read_img_list(img2_path)
        img1 = np.squeeze(img1)
        img2 = np.squeeze(img2)

        img1, img2 = self.to_tensor(img1), self.to_tensor(img2)
        return self.score_from_feature(img1, img2)
|
||||
|
||||
|
||||
def numpy_to_torch(images):
    """Convert a [N, H, W, C] float array to a [N, C, H, W] CUDA tensor,
    rescaling values with ``x * 2 - 1`` (maps [0, 1] inputs to [-1, 1];
    assumes callers pass [0, 1] data — TODO confirm)."""
    rescaled = images * 2.0 - 1.0
    channels_first = rescaled.transpose((0, 3, 1, 2))
    tensor = torch.from_numpy(channels_first).float()
    return tensor.cuda()
|
||||
|
||||
|
||||
class LPIPSMeter:
    """Accumulates LPIPS distances between paired reference/novel images.

    Args:
        net: LPIPS backbone, 'alex' or 'vgg'.
        device: torch device; defaults to cuda when available, else cpu.
        size: square side length images are resized to before comparison.
    """

    def __init__(self,
                 net='alex',
                 device=None,
                 size=224):  # or we can use 'alex', 'vgg' as network
        self.size = size
        self.net = net
        self.results = []  # per-pair LPIPS values from the last score_gt call
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.fn = lpips.LPIPS(net=net).eval().to(self.device)

    def measure(self):
        """Mean LPIPS over all accumulated pairs."""
        return np.mean(self.results)

    def report(self):
        """Human-readable summary of the current measure."""
        return f'LPIPS ({self.net}) = {self.measure():.6f}'

    def read_img_list(self, img_list):
        """Load image paths as a [N, size, size, 3] float32 RGB array in [0, 1].

        Fully transparent pixels of RGBA images are set to white.
        """
        size = self.size
        images = []

        for img_path in img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

            if img.shape[2] == 4:  # Handle BGRA images
                alpha = img[:, :, 3]  # Extract alpha channel
                img = cv2.cvtColor(img,
                                   cv2.COLOR_BGRA2BGR)  # Convert BGRA to BGR
                img = cv2.cvtColor(img,
                                   cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
                img[np.where(alpha == 0)] = [
                    255, 255, 255
                ]  # Set transparent pixels to white
            else:  # Handle other image formats like JPG and PNG
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
            images.append(img)

        images = np.stack(images, axis=0)
        images = images.astype(np.float32) / 255.0

        return images

    # * recommend to use this function for evaluation
    @torch.no_grad()
    def score_gt(self, ref_paths, novel_paths):
        """Mean LPIPS between images paired positionally from the two lists."""
        self.results = []
        for path0, path1 in zip(ref_paths, novel_paths):
            img0, img1 = self.read_img_list([path0]), self.read_img_list(
                [path1])
            # LPIPS expects [-1, 1] inputs; numpy_to_torch rescales and
            # moves to channels-first CUDA tensors.
            img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1)
            img0 = F.interpolate(img0,
                                 size=(self.size, self.size),
                                 mode='area')
            img1 = F.interpolate(img1,
                                 size=(self.size, self.size),
                                 mode='area')

            self.results.append(self.fn.forward(img0, img1).cpu().numpy())

        return self.measure()
|
||||
|
||||
|
||||
class PSNRMeter:
    """Accumulates PSNR values between paired reference/novel images.

    Args:
        size: square side length images are resized to before comparison.
    """

    def __init__(self, size=800):
        self.results = []  # per-pair PSNR values from the last update() call
        self.size = size

    def read_img_list(self, img_list):
        """Load image paths as a [N, size, size, 3] float32 RGB array in [0, 1].

        Fully transparent pixels of RGBA images are set to white.
        """
        size = self.size
        images = []
        for img_path in img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)

            if img.shape[2] == 4:  # Handle BGRA images
                alpha = img[:, :, 3]  # Extract alpha channel
                img = cv2.cvtColor(img,
                                   cv2.COLOR_BGRA2BGR)  # Convert BGRA to BGR
                img = cv2.cvtColor(img,
                                   cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
                img[np.where(alpha == 0)] = [
                    255, 255, 255
                ]  # Set transparent pixels to white
            else:  # Handle other image formats like JPG and PNG
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
            images.append(img)

        images = np.stack(images, axis=0)
        images = images.astype(np.float32) / 255.0
        return images

    def update(self, preds, truths):
        """Compute PSNR for each pred/truth pair and store the values."""
        psnr_values = []
        for img1, img2 in zip(preds, truths):
            psnr = compare_psnr(
                img1, img2,
                data_range=1.0)  # read_img_list scales images to [0, 1]
            psnr_values.append(psnr)

        self.results = psnr_values

    def measure(self):
        """Mean PSNR over all accumulated pairs."""
        return np.mean(self.results)

    def report(self):
        """Human-readable summary of the current measure."""
        return f'PSNR = {self.measure():.6f}'

    # * recommend to use this function for evaluation
    def score_gt(self, ref_paths, novel_paths):
        """Mean PSNR between images loaded from ref_paths and novel_paths."""
        self.results = []
        # [B, H, W, 3] arrays in [0, 1]
        preds = self.read_img_list(ref_paths)
        truths = self.read_img_list(novel_paths)
        self.update(preds, truths)
        return self.measure()
|
||||
|
||||
# Root folder holding the evaluation datasets; each dataset directory
# contains one sub-directory per example.
all_inputs = 'data'
# NOTE(review): these listdir calls run at import time and require
# data/nerf4 and data/realfusion15 to exist; importing this module
# without them raises FileNotFoundError — confirm this is intended.
nerf_dataset = os.listdir(osp(all_inputs, 'nerf4'))
realfusion_dataset = os.listdir(osp(all_inputs, 'realfusion15'))
# dataset name -> list of example names discovered on disk
meta_examples = {
    'nerf4': nerf_dataset,
    'realfusion15': realfusion_dataset,
}
# default set of datasets evaluated by the CLI entry point below
all_datasets = meta_examples.keys()
|
||||
|
||||
# organization 1
|
||||
def deprecated_score_from_method_for_dataset(my_scorer,
                                             method,
                                             dataset,
                                             input,
                                             output,
                                             score_type='clip',
                                             ): # psnr, lpips
    """Deprecated: score every example of ``dataset`` with ``my_scorer``.

    Returns a dict mapping example name -> score.

    NOTE(review): this function references ``pred_path`` which is not a
    parameter or local — it only works when a global ``pred_path`` exists
    (e.g. set by the __main__ loop below); the ``input``/``output``
    parameters are unused. Verify before reviving this code path.
    """
    # print("\n\n\n")
    # print(f"______{method}___{dataset}___{score_type}_________")
    scores = {}
    final_res = 0
    examples = meta_examples[dataset]
    for i in range(len(examples)):

        # compare entire folder for clip
        if score_type == 'clip':
            novel_view = osp(pred_path, examples[i], 'colors')
        # compare first image for other metrics
        else:
            if method == '3d_fuse': method = '3d_fuse_0'
            novel_view = list(
                glob.glob(
                    osp(pred_path, examples[i], 'colors',
                        'step_0000*')))[0]

        # NOTE(review): an empty reference list is passed here, unlike the
        # non-deprecated variant below which passes [ref_path] — this looks
        # like a bug; confirm against the scorers' score_gt signatures.
        score_i = my_scorer.score_gt(
            [], [novel_view])
        scores[examples[i]] = score_i
        final_res += score_i
    # print(scores, " Avg : ", final_res / len(examples))
    # print("``````````````````````")
    return scores
|
||||
|
||||
# results organization 2
|
||||
def score_from_method_for_dataset(my_scorer,
                                  input_path,
                                  pred_path,
                                  score_type='clip',
                                  rgb_name='lambertian',
                                  result_folder='results/images',
                                  first_str='*0000*'
                                  ):  # psnr, lpips
    """Score every example under ``input_path`` against predictions in ``pred_path``.

    Args:
        my_scorer: object exposing ``score_gt(ref_paths, novel_paths)``
            (e.g. CLIP, LPIPSMeter or PSNRMeter above).
        input_path: directory with one sub-directory per example, each
            containing the reference image ``rgba.png``.
        pred_path: directory whose per-example prediction sub-directories are
            matched by example-name substring and contain ``result_folder``.
        score_type: 'clip' compares the reference against every rendered
            view; any other value compares only the first view.
        rgb_name: substring identifying rendered RGB images.
        result_folder: sub-folder of each prediction holding the images.
        first_str: glob fragment selecting the first view for psnr/lpips.

    Returns:
        dict mapping example name -> score, plus an 'average' entry.
    """
    scores = {}
    final_res = 0
    examples = os.listdir(input_path)
    for example in examples:
        # reference (ground-truth) image for this example
        ref_path = osp(input_path, example, 'rgba.png')
        if score_type == 'clip':
            # compare the single GT image against the entire rendered folder
            novel_view = glob.glob(osp(pred_path, '*' + example + '*', result_folder, f'*{rgb_name}*'))
            # typo "[INOF]" fixed to "[INFO]"
            print(f'[INFO] {score_type} loss for example {example} between 1 GT and {len(novel_view)} predictions')
        else:
            # compare only the first rendered view for pixel metrics
            novel_view = glob.glob(osp(pred_path, '*' + example + '*/', result_folder, f'{first_str}{rgb_name}*'))
            print(f'[INFO] {score_type} loss for example {example} between {ref_path} and {novel_view}')
        score_i = my_scorer.score_gt([ref_path], novel_view)
        scores[example] = score_i
        final_res += score_i
    avg_score = final_res / len(examples)
    scores['average'] = avg_score
    return scores
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: evaluate prediction folders against GT datasets with
    # CLIP, PSNR and LPIPS, and write one CSV of scores per (method, dataset).
    parser = argparse.ArgumentParser(
        description="Script to accept three string arguments")
    parser.add_argument("--input_path",
                        default=all_inputs,
                        help="Specify the input path")
    # NOTE(review): help-string typo "predition" is user-visible; left
    # unchanged in this documentation-only pass.
    parser.add_argument("--pred_pattern",
                        default="out/magic123*",
                        help="Specify the pattern of predition paths")
    parser.add_argument("--results_folder",
                        default="results/images",
                        help="where are the results under each pred_path")
    parser.add_argument("--rgb_name",
                        default="lambertian",
                        help="the postfix of the image")
    parser.add_argument("--first_str",
                        default="*0000*",
                        help="the str to indicate the first view")
    parser.add_argument("--datasets",
                        default=all_datasets,
                        nargs='*',
                        help="Specify the output path")
    parser.add_argument("--device",
                        type=int,
                        default=0,
                        help="Specify the GPU device to be used")
    parser.add_argument("--save_dir", type=str, default='all_metrics/results')
    args = parser.parse_args()

    # Instantiate the three scorers once; CLIP loads its model onto the GPU.
    clip_scorer = CLIP(args.device)
    lpips_scorer = LPIPSMeter()
    psnr_scorer = PSNRMeter()

    os.makedirs(args.save_dir, exist_ok=True)

    for dataset in args.datasets:
        input_path = osp(args.input_path, dataset)

        # assume the pred_path is organized as: pred_path/methods/dataset
        pred_pattern = osp(args.pred_pattern, dataset)
        pred_paths = glob.glob(pred_pattern)
        print(f"[INFO] Following the pattern {pred_pattern}, find {len(pred_paths)} pred_paths: \n", pred_paths)
        if len(pred_paths) == 0:
            # No matching predictions: fail loudly rather than silently skip.
            raise IOError
        for pred_path in pred_paths:
            if not os.path.exists(pred_path):
                # NOTE(review): message typo "exit" (should be "exist");
                # runtime string left unchanged in this documentation pass.
                print(f'[WARN] prediction does not exit for {pred_path}')
            else:
                print(f'[INFO] evaluate {pred_path}')
                results_dict = {}
                # One column per metric; rows are example names + 'average'.
                results_dict['clip'] = score_from_method_for_dataset(
                    clip_scorer, input_path, pred_path, 'clip',
                    result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str)

                results_dict['psnr'] = score_from_method_for_dataset(
                    psnr_scorer, input_path, pred_path, 'psnr',
                    result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str)

                results_dict['lpips'] = score_from_method_for_dataset(
                    lpips_scorer, input_path, pred_path, 'lpips',
                    result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str)

                df = pd.DataFrame(results_dict)
                # Method name is taken from the second-to-last path component.
                method = pred_path.split('/')[-2]
                print(osp(pred_path, args.results_folder))
                results_str = '_'.join(args.results_folder.split('/'))
                print(method+'-'+results_str)
                print(df)
                df.to_csv(f"{args.save_dir}/{method}-{results_str}-{dataset}.csv")
|
||||
13
all_metrics/test.sh
Executable file
13
all_metrics/test.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
# Evaluation driver: runs all_metrics/metric_utils.py (CLIP / PSNR / LPIPS)
# over the nerf4 and realfusion15 datasets for different method outputs.
# The commented commands are alternative method/output layouts kept for reference.

# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/magic123/magic123-2d1-3d30-dmtet" --results_folder "results/images"

# 2d only
python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/magic123-2d/magic123-2d1*" --results_folder "results/images"

# 3d only
python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/magic123-3d/zero123-z40*" --results_folder "results/images"

# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/3d_fuse" --results_folder "color" --rgb_name "" --first_str "*_0_*"
# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/neural_lift" --results_folder "albedo" --rgb_name "albedo" --first_str "*0000*"
# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/real_fusion" --results_folder "colors" --rgb_name ""
# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/shape" --results_folder "" --rgb_name "" --first_str "0."
# python all_metrics/metric_utils.py --datasets realfusion15 --pred_pattern "all_outputs/pointe-r100" --results_folder "" --rgb_name "" --first_str "0."
|
||||
71
assets/advanced.md
Normal file
71
assets/advanced.md
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
# Code organization & Advanced tips
|
||||
|
||||
This is a simple description of the most important implementation details.
|
||||
If you are interested in improving this repo, this might be a starting point.
|
||||
Any contribution would be greatly appreciated!
|
||||
|
||||
* The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`:
|
||||
```python
|
||||
## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE.
|
||||
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
|
||||
## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen.
|
||||
latents = self.encode_imgs(pred_rgb_512)
|
||||
... # timestep sampling, noise adding and UNet noise predicting
|
||||
## 3. the SDS loss
|
||||
w = (1 - self.alphas[t])
|
||||
grad = w * (noise_pred - noise)
|
||||
# since UNet part is ignored and cannot simply audodiff, we have two ways to set the grad:
|
||||
# 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later)
|
||||
latents.backward(gradient=grad, retain_graph=True)
|
||||
return 0 # dummy loss
|
||||
# 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng)
|
||||
class SpecifyGradient(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, input_tensor, gt_grad):
|
||||
ctx.save_for_backward(gt_grad)
|
||||
# we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
|
||||
return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_scale):
|
||||
gt_grad, = ctx.saved_tensors
|
||||
gt_grad = gt_grad * grad_scale
|
||||
return gt_grad, None
|
||||
|
||||
loss = SpecifyGradient.apply(latents, grad)
|
||||
return loss # functional loss
|
||||
```
|
||||
* Other regularizations are in `./nerf/utils.py > Trainer > train_step`.
|
||||
* The generation seems quite sensitive to regularizations on weights_sum (alphas for each ray). The original opacity loss tends to make NeRF disappear (zero density everywhere), so we use an entropy loss to replace it for now (encourages alpha to be either 0 or 1).
|
||||
* NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`.
|
||||
* Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`.
|
||||
* light direction: current implementation use a plane light source, instead of a point light source.
|
||||
* View-dependent prompting: `./nerf/provider.py > get_view_direction`.
|
||||
* use `--angle_overhead, --angle_front` to set the border.
|
||||
* Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option.
|
||||
* Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`.
|
||||
|
||||
|
||||
# Debugging
|
||||
|
||||
`debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one:
|
||||
|
||||
```bash
|
||||
python main.py --text "a hamburger" --workspace trial -O --vram_O
|
||||
```
|
||||
|
||||
... with:
|
||||
|
||||
```bash
|
||||
debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O
|
||||
```
|
||||
|
||||
For more details: https://github.com/bulletmark/debugpy-run
|
||||
|
||||
# Axes and directions of polar, azimuth, etc. in NeRF and Zero123
|
||||
|
||||
<img width="1119" alt="NeRF_Zero123" src="https://github.com/ashawkey/stable-dreamfusion/assets/22424247/a0f432ff-2d08-45a4-a390-bda64f5cbc94">
|
||||
|
||||
39
assets/update_logs.md
Normal file
39
assets/update_logs.md
Normal file
@@ -0,0 +1,39 @@
|
||||
### 2023.4.19
|
||||
* Fix depth supervision, migrate depth estimation model to omnidata.
|
||||
* Add normal supervision (also by omnidata).
|
||||
|
||||
https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4
|
||||
|
||||
### 2023.4.7
|
||||
Improvement on mesh quality & DMTet finetuning support.
|
||||
|
||||
https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4
|
||||
|
||||
### 2023.3.30
|
||||
* adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate normal and mask as the latent code in a warm up stage, which shows faster convergence of shape.
|
||||
|
||||
https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4
|
||||
|
||||
### 2023.1.30
|
||||
* Use an MLP to predict the surface normals as in Magic3D to avoid finite difference / second order gradient, generation quality is greatly improved.
|
||||
* More efficient two-pass raymarching in training inspired by nerfacc.
|
||||
|
||||
https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4
|
||||
|
||||
### 2022.12.3
|
||||
* Support Stable-diffusion 2.0 base.
|
||||
|
||||
### 2022.11.15
|
||||
* Add the vanilla backbone that is pure-pytorch.
|
||||
|
||||
### 2022.10.9
|
||||
* The shading (partially) starts to work, at least it won't make scene empty. For some prompts, it shows better results (less severe Janus problem). The textureless rendering mode is still disabled.
|
||||
* Enable shading by default (--latent_iter_ratio 1000).
|
||||
|
||||
### 2022.10.5
|
||||
* Basic reproduction finished.
|
||||
* Non --cuda_ray, --tcnn are not working, need to fix.
|
||||
* Shading is not working, disabled in utils.py for now. Surface normals are bad.
|
||||
* Use an entropy loss to regularize weights_sum (alpha), the original L2 reg always leads to degenerated geometry...
|
||||
|
||||
https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4
|
||||
16
dnnultis/REAMD.me
Normal file
16
dnnultis/REAMD.me
Normal file
@@ -0,0 +1,16 @@
|
||||
# dnnutils
|
||||
dnnutils is a simple library for the commonly used utils for DNN research including:
|
||||
1. log (log/experiment directory, logging, tensorboard, wandb)
|
||||
2.
|
||||
|
||||
|
||||
|
||||
dnnutils is designed to be compitabe for timm.
|
||||
|
||||
|
||||
# Install
|
||||
|
||||
|
||||
|
||||
# TODO
|
||||
configs
|
||||
1
dnnultis/__init__.py
Normal file
1
dnnultis/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .log import *
|
||||
2
dnnultis/log/__init__.py
Normal file
2
dnnultis/log/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .logger import *
|
||||
from .wandb import *
|
||||
86
dnnultis/log/logger.py
Normal file
86
dnnultis/log/logger.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import functools
|
||||
import logging
|
||||
import os, sys
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
class _ColorfulFormatter(logging.Formatter):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._root_name = kwargs.pop("root_name") + "."
|
||||
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
||||
if len(self._abbrev_name):
|
||||
self._abbrev_name = self._abbrev_name + "."
|
||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatMessage(self, record):
|
||||
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
||||
log = super(_ColorfulFormatter, self).formatMessage(record)
|
||||
if record.levelno == logging.WARNING:
|
||||
prefix = colored("WARNING", "red", attrs=["blink"])
|
||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
||||
else:
|
||||
return log
|
||||
return prefix + " " + log
|
||||
|
||||
|
||||
# so that calling setup_logging.root multiple times won't add many handlers
|
||||
@functools.lru_cache()
|
||||
def setup_logging(output=None,
|
||||
distributed_rank=0,
|
||||
default_level=logging.INFO,
|
||||
*,
|
||||
color=True,
|
||||
name=__name__):
|
||||
"""
|
||||
Initialize the detectron2 logging.root and set its verbosity level to "INFO".
|
||||
Args:
|
||||
output (str): a file name or a directory to save log. If None, will not save log file.
|
||||
If ends with ".txt" or ".log", assumed to be a file name.
|
||||
Otherwise, logs will be saved to `output/log.txt`.
|
||||
name (str): the root module name of this logging.root
|
||||
Returns:
|
||||
logging.logging.root: a logging.root
|
||||
"""
|
||||
logging.root.setLevel(default_level)
|
||||
logging.root.propagate = False
|
||||
|
||||
plain_formatter = logging.Formatter(
|
||||
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
|
||||
datefmt="%y/%m/%d %H:%M:%S")
|
||||
# stdout logging: master only
|
||||
if distributed_rank == 0:
|
||||
ch = logging.StreamHandler(stream=sys.stdout)
|
||||
ch.setLevel(default_level)
|
||||
if color:
|
||||
formatter = _ColorfulFormatter(
|
||||
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
|
||||
datefmt="%y/%m/%d %H:%M:%S",
|
||||
root_name=name,
|
||||
)
|
||||
else:
|
||||
formatter = plain_formatter
|
||||
ch.setFormatter(formatter)
|
||||
logging.root.addHandler(ch)
|
||||
|
||||
# file logging: all workers
|
||||
if output is not None:
|
||||
if output.endswith(".txt") or output.endswith(".log"):
|
||||
filename = output
|
||||
else:
|
||||
filename = os.path.join(output, "log.txt")
|
||||
if distributed_rank > 0:
|
||||
filename = filename + f".rank{distributed_rank}"
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
fh = logging.StreamHandler(_cached_log_stream(filename))
|
||||
fh.setLevel(default_level)
|
||||
fh.setFormatter(plain_formatter)
|
||||
logging.root.addHandler(fh)
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logging.root`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _cached_log_stream(filename):
|
||||
return open(filename, "a")
|
||||
84
dnnultis/log/wandb.py
Normal file
84
dnnultis/log/wandb.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import shutil
|
||||
import os
|
||||
import subprocess
|
||||
import wandb
|
||||
|
||||
|
||||
class WandbUrls:
|
||||
def __init__(self, url):
|
||||
|
||||
hash = url.split("/")[-2]
|
||||
project = url.split("/")[-3]
|
||||
entity = url.split("/")[-4]
|
||||
|
||||
self.weight_url = url
|
||||
self.log_url = "https://app.wandb.ai/{}/{}/runs/{}/logs".format(entity, project, hash)
|
||||
self.chart_url = "https://app.wandb.ai/{}/{}/runs/{}".format(entity, project, hash)
|
||||
self.overview_url = "https://app.wandb.ai/{}/{}/runs/{}/overview".format(entity, project, hash)
|
||||
self.config_url = "https://app.wandb.ai/{}/{}/runs/{}/files/hydra-config.yaml".format(
|
||||
entity, project, hash
|
||||
)
|
||||
self.overrides_url = "https://app.wandb.ai/{}/{}/runs/{}/files/overrides.yaml".format(entity, project, hash)
|
||||
|
||||
def __repr__(self):
|
||||
msg = "=================================================== WANDB URLS ===================================================================\n"
|
||||
for k, v in self.__dict__.items():
|
||||
msg += "{}: {}\n".format(k.upper(), v)
|
||||
msg += "=================================================================================================================================\n"
|
||||
return msg
|
||||
|
||||
|
||||
class Wandb:
|
||||
IS_ACTIVE = False
|
||||
|
||||
@staticmethod
|
||||
def set_urls_to_model(model, url):
|
||||
wandb_urls = WandbUrls(url)
|
||||
model.wandb = wandb_urls
|
||||
|
||||
@staticmethod
|
||||
def _set_to_wandb_args(wandb_args, cfg, name):
|
||||
var = getattr(cfg.wandb, name, None)
|
||||
if var:
|
||||
wandb_args[name] = var
|
||||
|
||||
@staticmethod
|
||||
def launch(cfg, launch: bool):
|
||||
if launch:
|
||||
|
||||
Wandb.IS_ACTIVE = True
|
||||
|
||||
wandb_args = {}
|
||||
wandb_args["resume"] = "allow"
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "tags")
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "project")
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "name")
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "entity")
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "notes")
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "config")
|
||||
Wandb._set_to_wandb_args(wandb_args, cfg, "id")
|
||||
|
||||
try:
|
||||
commit_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
|
||||
gitdiff = subprocess.check_output(["git", "diff", "--", "':!notebooks'"]).decode()
|
||||
except:
|
||||
commit_sha = "n/a"
|
||||
gitdiff = ""
|
||||
|
||||
config = wandb_args.get("config", {})
|
||||
wandb_args["config"] = {
|
||||
**config,
|
||||
"run_path": os.getcwd(),
|
||||
"commit": commit_sha,
|
||||
"gitdiff": gitdiff
|
||||
}
|
||||
wandb.init(**wandb_args, sync_tensorboard=True)
|
||||
wandb.save(os.path.join(os.getcwd(), cfg.cfg_path))
|
||||
|
||||
@staticmethod
|
||||
def add_file(file_path: str):
|
||||
if not Wandb.IS_ACTIVE:
|
||||
raise RuntimeError("wandb is inactive, please launch first.")
|
||||
|
||||
filename = os.path.basename(file_path)
|
||||
shutil.copyfile(file_path, os.path.join(wandb.run.dir, filename))
|
||||
53
docker/Dockerfile
Normal file
53
docker/Dockerfile
Normal file
@@ -0,0 +1,53 @@
|
||||
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
|
||||
|
||||
# Remove any third-party apt sources to avoid issues with expiring keys.
|
||||
RUN rm -f /etc/apt/sources.list.d/*.list
|
||||
|
||||
RUN apt-get update
|
||||
|
||||
RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/MADRID apt-get install -y tzdata
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
python3 \
|
||||
python3-pip \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev \
|
||||
libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
# Create a working directory
|
||||
RUN mkdir /app
|
||||
WORKDIR /app
|
||||
|
||||
RUN cd /app
|
||||
RUN git clone https://github.com/ashawkey/stable-dreamfusion.git
|
||||
|
||||
|
||||
RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
|
||||
WORKDIR /app/stable-dreamfusion
|
||||
|
||||
RUN pip3 install -r requirements.txt
|
||||
RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/
|
||||
|
||||
# Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
|
||||
RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
|
||||
|
||||
RUN pip3 install git+https://github.com/openai/CLIP.git
|
||||
RUN bash scripts/install_ext.sh
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Set the default command to python3
|
||||
#CMD ["python3"]
|
||||
|
||||
80
docker/README.md
Normal file
80
docker/README.md
Normal file
@@ -0,0 +1,80 @@
|
||||
### Docker installation
|
||||
|
||||
## Build image
|
||||
To build the docker image on your own machine, which may take 15-30 mins:
|
||||
```
|
||||
docker build -t stable-dreamfusion:latest .
|
||||
```
|
||||
|
||||
If you have the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn you need to setup the nvidia-runtime for docker.
|
||||
```
|
||||
sudo apt-get install nvidia-container-runtime
|
||||
```
|
||||
Then edit `/etc/docker/daemon.json` and add the default-runtime:
|
||||
```
|
||||
{
|
||||
"runtimes": {
|
||||
"nvidia": {
|
||||
"path": "nvidia-container-runtime",
|
||||
"runtimeArgs": []
|
||||
}
|
||||
},
|
||||
"default-runtime": "nvidia"
|
||||
}
|
||||
```
|
||||
And restart docker:
|
||||
```
|
||||
sudo systemctl restart docker
|
||||
```
|
||||
Now you can build tiny-cuda-nn inside docker.
|
||||
|
||||
## Download image
|
||||
To download the image (~6GB) instead:
|
||||
```
|
||||
docker pull supercabb/stable-dreamfusion:3080_0.0.1
|
||||
docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
|
||||
```
|
||||
|
||||
## Use image
|
||||
|
||||
You can launch an interactive shell inside the container:
|
||||
|
||||
```
|
||||
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
|
||||
```
|
||||
From this shell, all the code in the repo should work.
|
||||
|
||||
To run any single command `<command...>` inside the docker container:
|
||||
```
|
||||
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "<command...>"
|
||||
```
|
||||
To train:
|
||||
```
|
||||
export TOKEN="#HUGGING FACE ACCESS TOKEN#"
|
||||
docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
|
||||
&& python3 main.py --text \"a hamburger\" --workspace trial -O"
|
||||
|
||||
```
|
||||
Run test without gui:
|
||||
```
|
||||
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
|
||||
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
|
||||
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
|
||||
main.py --workspace trial -O --test"
|
||||
```
|
||||
Run test with gui:
|
||||
```
|
||||
export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
|
||||
xhost +
|
||||
docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
|
||||
-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
|
||||
main.py --workspace trial -O --test --gui"
|
||||
xhost -
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
BIN
docs/static/ironman-val-magic123.gif
vendored
Normal file
BIN
docs/static/ironman-val-magic123.gif
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
BIN
docs/static/magic123-results.mp4
vendored
Normal file
BIN
docs/static/magic123-results.mp4
vendored
Normal file
Binary file not shown.
BIN
docs/static/magic123.gif
vendored
Normal file
BIN
docs/static/magic123.gif
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.5 MiB |
924
dpt.py
Normal file
924
dpt.py
Normal file
@@ -0,0 +1,924 @@
|
||||
import math
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import timm
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def load(self, path):
|
||||
"""Load model from file.
|
||||
Args:
|
||||
path (str): file path
|
||||
"""
|
||||
parameters = torch.load(path, map_location=torch.device('cpu'))
|
||||
|
||||
if "optimizer" in parameters:
|
||||
parameters = parameters["model"]
|
||||
|
||||
self.load_state_dict(parameters)
|
||||
|
||||
|
||||
def unflatten_with_named_tensor(input, dim, sizes):
|
||||
"""Workaround for unflattening with named tensor."""
|
||||
# tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538
|
||||
new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:]
|
||||
return input.view(*new_shape)
|
||||
|
||||
class Slice(nn.Module):
|
||||
def __init__(self, start_index=1):
|
||||
super(Slice, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
def forward(self, x):
|
||||
return x[:, self.start_index :]
|
||||
|
||||
|
||||
class AddReadout(nn.Module):
|
||||
def __init__(self, start_index=1):
|
||||
super(AddReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
def forward(self, x):
|
||||
if self.start_index == 2:
|
||||
readout = (x[:, 0] + x[:, 1]) / 2
|
||||
else:
|
||||
readout = x[:, 0]
|
||||
return x[:, self.start_index :] + readout.unsqueeze(1)
|
||||
|
||||
|
||||
class ProjectReadout(nn.Module):
|
||||
def __init__(self, in_features, start_index=1):
|
||||
super(ProjectReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
||||
|
||||
def forward(self, x):
|
||||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
||||
features = torch.cat((x[:, self.start_index :], readout), -1)
|
||||
|
||||
return self.project(features)
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dim0, dim1):
|
||||
super(Transpose, self).__init__()
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(self.dim0, self.dim1)
|
||||
return x
|
||||
|
||||
|
||||
def forward_vit(pretrained, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
glob = pretrained.model.forward_flex(x)
|
||||
|
||||
layer_1 = pretrained.activations["1"]
|
||||
layer_2 = pretrained.activations["2"]
|
||||
layer_3 = pretrained.activations["3"]
|
||||
layer_4 = pretrained.activations["4"]
|
||||
|
||||
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
||||
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
||||
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
||||
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
||||
|
||||
|
||||
unflattened_dim = 2
|
||||
unflattened_size = (
|
||||
int(torch.div(h, pretrained.model.patch_size[1], rounding_mode='floor')),
|
||||
int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')),
|
||||
)
|
||||
unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size))
|
||||
|
||||
|
||||
if layer_1.ndim == 3:
|
||||
layer_1 = unflatten(layer_1)
|
||||
if layer_2.ndim == 3:
|
||||
layer_2 = unflatten(layer_2)
|
||||
if layer_3.ndim == 3:
|
||||
layer_3 = unflatten_with_named_tensor(layer_3, unflattened_dim, unflattened_size)
|
||||
if layer_4.ndim == 3:
|
||||
layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size)
|
||||
|
||||
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
||||
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
||||
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
||||
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
||||
|
||||
return layer_1, layer_2, layer_3, layer_4
|
||||
|
||||
|
||||
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
||||
posemb_tok, posemb_grid = (
|
||||
posemb[:, : self.start_index],
|
||||
posemb[0, self.start_index :],
|
||||
)
|
||||
|
||||
gs_old = int(math.sqrt(posemb_grid.shape[0]))
|
||||
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
||||
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
|
||||
return posemb
|
||||
|
||||
|
||||
def forward_flex(self, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
pos_embed = self._resize_pos_embed(
|
||||
self.pos_embed, torch.div(h, self.patch_size[1], rounding_mode='floor'), torch.div(w, self.patch_size[0], rounding_mode='floor')
|
||||
)
|
||||
|
||||
B = x.shape[0]
|
||||
|
||||
if hasattr(self.patch_embed, "backbone"):
|
||||
x = self.patch_embed.backbone(x)
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||
|
||||
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
||||
|
||||
if getattr(self, "dist_token", None) is not None:
|
||||
cls_tokens = self.cls_token.expand(
|
||||
B, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
dist_token = self.dist_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||
else:
|
||||
cls_tokens = self.cls_token.expand(
|
||||
B, -1, -1
|
||||
) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
activations = {}
|
||||
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activations[name] = output
|
||||
|
||||
return hook
|
||||
|
||||
|
||||
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
||||
if use_readout == "ignore":
|
||||
readout_oper = [Slice(start_index)] * len(features)
|
||||
elif use_readout == "add":
|
||||
readout_oper = [AddReadout(start_index)] * len(features)
|
||||
elif use_readout == "project":
|
||||
readout_oper = [
|
||||
ProjectReadout(vit_features, start_index) for out_feat in features
|
||||
]
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
||||
|
||||
return readout_oper
|
||||
|
||||
|
||||
def _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[96, 192, 384, 768],
|
||||
size=[384, 384],
|
||||
hooks=[2, 5, 8, 11],
|
||||
vit_features=768,
|
||||
use_readout="ignore",
|
||||
start_index=1,
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||
|
||||
# 32, 48, 136, 384
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
readout_oper[0],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[0],
|
||||
out_channels=features[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
readout_oper[1],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[1],
|
||||
out_channels=features[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential(
|
||||
readout_oper[2],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential(
|
||||
readout_oper[3],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.Conv2d(
|
||||
in_channels=features[3],
|
||||
out_channels=features[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.model.start_index = start_index
|
||||
pretrained.model.patch_size = [16, 16]
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||
pretrained.model._resize_pos_embed = types.MethodType(
|
||||
_resize_pos_embed, pretrained.model
|
||||
)
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[256, 512, 1024, 1024],
|
||||
hooks=hooks,
|
||||
vit_features=1024,
|
||||
use_readout=use_readout,
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
||||
)
|
||||
|
||||
|
||||
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
||||
model = timm.create_model(
|
||||
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
||||
)
|
||||
|
||||
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
||||
return _make_vit_b16_backbone(
|
||||
model,
|
||||
features=[96, 192, 384, 768],
|
||||
hooks=hooks,
|
||||
use_readout=use_readout,
|
||||
start_index=2,
|
||||
)
|
||||
|
||||
|
||||
def _make_vit_b_rn50_backbone(
|
||||
model,
|
||||
features=[256, 512, 768, 768],
|
||||
size=[384, 384],
|
||||
hooks=[0, 1, 8, 11],
|
||||
vit_features=768,
|
||||
use_vit_only=False,
|
||||
use_readout="ignore",
|
||||
start_index=1,
|
||||
):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.model = model
|
||||
|
||||
if use_vit_only == True:
|
||||
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
||||
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
||||
else:
|
||||
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
||||
get_activation("1")
|
||||
)
|
||||
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
||||
get_activation("2")
|
||||
)
|
||||
|
||||
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
||||
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
||||
|
||||
pretrained.activations = activations
|
||||
|
||||
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
||||
|
||||
if use_vit_only == True:
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
readout_oper[0],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[0],
|
||||
out_channels=features[0],
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
readout_oper[1],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.ConvTranspose2d(
|
||||
in_channels=features[1],
|
||||
out_channels=features[1],
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding=0,
|
||||
bias=True,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
),
|
||||
)
|
||||
else:
|
||||
pretrained.act_postprocess1 = nn.Sequential(
|
||||
nn.Identity(), nn.Identity(), nn.Identity()
|
||||
)
|
||||
pretrained.act_postprocess2 = nn.Sequential(
|
||||
nn.Identity(), nn.Identity(), nn.Identity()
|
||||
)
|
||||
|
||||
pretrained.act_postprocess3 = nn.Sequential(
|
||||
readout_oper[2],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.act_postprocess4 = nn.Sequential(
|
||||
readout_oper[3],
|
||||
Transpose(1, 2),
|
||||
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
||||
nn.Conv2d(
|
||||
in_channels=vit_features,
|
||||
out_channels=features[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
),
|
||||
nn.Conv2d(
|
||||
in_channels=features[3],
|
||||
out_channels=features[3],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
|
||||
pretrained.model.start_index = start_index
|
||||
pretrained.model.patch_size = [16, 16]
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
||||
|
||||
# We inject this function into the VisionTransformer instances so that
|
||||
# we can use it with interpolated position embeddings without modifying the library source.
|
||||
pretrained.model._resize_pos_embed = types.MethodType(
|
||||
_resize_pos_embed, pretrained.model
|
||||
)
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    """Create the hybrid ViT-B/ResNet-50 backbone (384x384 input) for DPT.

    Args:
        pretrained (bool): load ImageNet-pretrained weights via timm.
        use_readout (str): readout-token handling mode for the ViT features.
        hooks (list[int] | None): transformer block indices to hook;
            defaults to [0, 1, 8, 11].
        use_vit_only (bool): if True, ignore the ResNet stem features.

    Returns:
        nn.Module: backbone with activation post-processing attached.
    """
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    # `is None` tests identity; `== None` would invoke __eq__ and can
    # misbehave for objects overriding it.
    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
|
||||
|
||||
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
    """Build the (pretrained backbone, scratch refinement convs) pair.

    Args:
        backbone (str): one of "vitl16_384", "vitb_rn50_384", "vitb16_384",
            "resnext101_wsl", "efficientnet_lite3".
        features (int): output channel count for the scratch convolutions.
        use_pretrained (bool): load pretrained backbone weights.
        groups, expand, exportable, hooks, use_vit_only, use_readout:
            forwarded to the backbone/scratch factory functions.

    Returns:
        tuple[nn.Module, nn.Module]: (pretrained backbone, scratch layers).

    Raises:
        ValueError: if `backbone` is not a supported name.
    """
    if backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        # ViT-L/16 - 85.0% Top1 (backbone)
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        # ViT-B/16 + ResNet-50 hybrid
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        # ViT-B/16 - 84.6% Top1 (backbone)
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        # ResNeXt-101 32x8d (WSL weights)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        # efficientnet_lite3
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)
    else:
        # Fail loudly with a real exception instead of print + `assert False`,
        # which is silently stripped under `python -O`.
        raise ValueError(f"Backbone '{backbone}' not implemented")

    return pretrained, scratch
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
out_shape4 = out_shape
|
||||
if expand==True:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape*2
|
||||
out_shape3 = out_shape*4
|
||||
out_shape4 = out_shape*8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    """Fetch tf_efficientnet_lite3 via torch.hub and wrap it as a
    four-stage DPT backbone.

    Args:
        use_pretrained (bool): load pretrained weights.
        exportable (bool): request an export-friendly model variant.
    """
    effnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable,
    )
    return _make_efficientnet_backbone(effnet)
|
||||
|
||||
|
||||
def _make_efficientnet_backbone(effnet):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
||||
)
|
||||
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
||||
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
||||
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_resnet_backbone(resnet):
|
||||
pretrained = nn.Module()
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
||||
)
|
||||
|
||||
pretrained.layer2 = resnet.layer2
|
||||
pretrained.layer3 = resnet.layer3
|
||||
pretrained.layer4 = resnet.layer4
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_resnext101_wsl(use_pretrained):
    """Fetch ResNeXt-101 32x8d (WSL weights) from torch.hub and wrap it.

    Note: `use_pretrained` is accepted for interface symmetry with the
    other factories, but torch.hub always loads the published weights here.
    """
    backbone = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(backbone)
|
||||
|
||||
|
||||
|
||||
class Interpolate(nn.Module):
    """Wrap `nn.functional.interpolate` as a module with fixed settings."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Store the interpolation configuration.

        Args:
            scale_factor (float): spatial scaling factor.
            mode (str): interpolation mode (e.g. "bilinear").
            align_corners (bool): passed through to `interpolate`.
        """
        super().__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample `x` with the stored settings.

        Args:
            x (tensor): input feature map.

        Returns:
            tensor: interpolated data.
        """
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
    """Pre-activation residual block: x + conv(relu(conv(relu(x))))."""

    def __init__(self, features):
        """Create the two 3x3 convolutions.

        Args:
            features (int): channel count of inputs and outputs.
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        # NOTE: inplace ReLU — the input tensor is modified in place.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the residual unit; output shape equals input shape.

        Args:
            x (tensor): input features.

        Returns:
            tensor: x + F(x).
        """
        branch = self.conv2(self.relu(self.conv1(self.relu(x))))
        return branch + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
    """Feature fusion block: merge an optional skip input through residual
    units, then upsample the result 2x bilinearly."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): channel count of all inputs and outputs.
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Fuse one or two feature maps of identical shape.

        Args:
            *xs: one (decoder path) or two (decoder path, skip) tensors.

        Returns:
            tensor: fused features, spatially upsampled by 2.
        """
        output = xs[0]

        if len(xs) == 2:
            # Bug fix: the original used `output += ...`, an in-place add
            # that mutates xs[0] — the caller's tensor — and can break
            # autograd when xs[0] is needed elsewhere in the graph.
            output = output + self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output
|
||||
|
||||
|
||||
|
||||
|
||||
class ResidualConvUnit_custom(nn.Module):
    """Pre-activation residual block with optional BatchNorm and a
    quantization-friendly skip connection."""

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): channel count for both convolutions.
            activation (nn.Module): activation applied before each conv.
            bn (bool): insert BatchNorm2d after each conv when True.
        """
        super().__init__()

        self.bn = bn
        # NOTE(review): groups is fixed to 1, so the `conv_merge` branch in
        # forward() is dead code (and conv_merge is never defined here).
        self.groups = 1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1,
            padding=1, bias=True, groups=self.groups,
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1,
            padding=1, bias=True, groups=self.groups,
        )

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation
        # FloatFunctional makes the residual add traceable for quantization.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return x + F(x), where F is (act -> conv -> [bn]) applied twice.

        Args:
            x (tensor): input features.

        Returns:
            tensor: output, same shape as x.
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.bn:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
|
||||
|
||||
|
||||
class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block: merge an optional skip input, refine, upsample
    2x bilinearly, and project with a 1x1 convolution."""

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
        """Init.

        Args:
            features (int): input channel count.
            activation (nn.Module): activation for the residual units.
            deconv (bool): stored but unused in this implementation.
            bn (bool): enable BatchNorm in the residual units.
            expand (bool): halve the channel count in the output projection.
            align_corners (bool): passed to the bilinear upsampling.
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand

        out_features = features // 2 if expand else features

        # 1x1 projection applied after upsampling.
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1,
            padding=0, bias=True, groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        # Quantization-friendly skip addition.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Fuse one or two feature maps.

        Args:
            *xs: one (decoder path) or two (decoder path, skip) tensors.

        Returns:
            tensor: refined, 2x-upsampled, projected features.
        """
        output = xs[0]

        if len(xs) == 2:
            output = self.skip_add.add(output, self.resConfUnit1(xs[1]))

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )

        return self.out_conv(output)
|
||||
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn):
    """Create a DPT fusion block: non-inplace ReLU, no deconv, no channel
    expansion, corner-aligned bilinear upsampling."""
    activation = nn.ReLU(False)
    return FeatureFusionBlock_custom(
        features,
        activation,
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )
|
||||
|
||||
|
||||
class DPT(BaseModel):
    """Dense Prediction Transformer: ViT backbone + RefineNet-style fusion
    decoder, terminated by a task-specific `head` module."""

    def __init__(
        self,
        head,
        features=256,
        backbone="vitb_rn50_384",
        readout="project",
        channels_last=False,
        use_bn=False,
    ):
        """Init.

        Args:
            head (nn.Module): output head applied to the fused features.
            features (int): channel count of the decoder feature maps.
            backbone (str): one of the hook-table keys below.
            readout (str): ViT readout-token handling mode.
            channels_last (bool): intended to use channels-last memory format.
            use_bn (bool): enable BatchNorm in the fusion blocks.
        """

        super(DPT, self).__init__()

        self.channels_last = channels_last

        # Transformer block indices at which intermediate activations are
        # hooked, per backbone variant.
        hooks = {
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            True,  # use_pretrained: always load ImageNet backbone weights here
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        # Four fusion stages, applied coarsest (4) to finest (1) in forward().
        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head


    def forward(self, x):
        """Run the backbone, re-project the four hooked feature maps, fuse
        them top-down, and apply the output head."""
        if self.channels_last == True:
            # NOTE(review): contiguous() is not in-place and its result is
            # discarded, so this line has no effect as written.
            x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        # 3x3 convs aligning each backbone feature map to `features` channels.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: each stage upsamples 2x and merges the skip input.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out
|
||||
|
||||
class DPTDepthModel(DPT):
    """DPT configured for dense per-pixel regression (e.g. depth): adds a
    convolutional head that upsamples 2x and emits `num_channels` maps."""

    def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
        """Init.

        Args:
            path (str | None): checkpoint to load via `self.load`, if given.
            non_negative (bool): clamp outputs to >= 0 with a final ReLU.
            num_channels (int): output channels per pixel.
            **kwargs: forwarded to `DPT.__init__` (may include `features`).
        """
        # dict.get with a default replaces the `if "features" in kwargs` dance.
        features = kwargs.get("features", 256)

        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            # Trailing Identity kept so Sequential indices (and hence
            # checkpoint state_dict keys) stay compatible with published weights.
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        """Return the prediction with the channel dim squeezed away when
        `num_channels == 1` (shape [B, H, W])."""
        return super().forward(x).squeeze(dim=1)
|
||||
89
encoding.py
Normal file
89
encoding.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class FreqEncoder_torch(nn.Module):
    """NeRF-style frequency (positional) encoding in pure torch.

    Maps each coordinate x to
    [x (optional), sin(f_0 x), cos(f_0 x), ..., sin(f_{N-1} x), cos(f_{N-1} x)].
    """

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        """Init.

        Args:
            input_dim (int): size of the last input dimension.
            max_freq_log2 (float): log2 of the highest frequency band.
            N_freqs (int): number of frequency bands.
            log_sampling (bool): space bands log-uniformly rather than linearly.
            include_input (bool): prepend the raw input to the encoding.
            periodic_fns (tuple): periodic functions applied per band.
        """
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns
        self.N_freqs = N_freqs

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim

        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
        else:
            self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)

        # Plain python floats avoid device/dtype bookkeeping in forward().
        self.freq_bands = self.freq_bands.numpy().tolist()

    def forward(self, input, max_level=None, **kwargs):
        """Encode `input` [..., input_dim] to [..., output_dim].

        Args:
            input (tensor): coordinates; any number of leading dimensions.
            max_level (float | None): fraction in [0, 1] of frequency bands
                to evaluate; remaining bands are zero-filled so the output
                width stays `output_dim` (coarse-to-fine annealing).
        """
        if max_level is None:
            max_level = self.N_freqs
        else:
            max_level = int(max_level * self.N_freqs)

        out = []
        if self.include_input:
            out.append(input)

        for i in range(max_level):
            freq = self.freq_bands[i]
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))

        # Zero-pad the disabled high-frequency bands. Generalized from the
        # original, which hard-coded two periodic functions (`* 2`) and a
        # strictly 2-D input (`input.shape[0]` / `input.shape[1]`).
        if self.N_freqs - max_level > 0:
            pad_width = (self.N_freqs - max_level) * len(self.periodic_fns) * input.shape[-1]
            out.append(torch.zeros(*input.shape[:-1], pad_width,
                                   device=input.device, dtype=input.dtype))

        return torch.cat(out, dim=-1)
|
||||
|
||||
def get_encoder(encoding, input_dim=3,
                multires=6,
                degree=4,
                num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',
                **kwargs):
    """Instantiate a coordinate encoder by name.

    Args:
        encoding (str): encoder type ('None', 'frequency_torch', 'frequency',
            'sphere_harmonics', 'hashgrid', 'tiledgrid', 'hashgrid_taichi').
        input_dim (int): dimensionality of the input coordinates.
        multires, degree, num_levels, level_dim, base_resolution,
        log2_hashmap_size, desired_resolution, align_corners, interpolation:
            encoder-specific hyperparameters.

    Returns:
        tuple: (encoder, output_dim) — the encoder module (or an identity
        function for 'None') and its output dimensionality.
    """
    if encoding == 'None':
        # Identity "encoder": pass coordinates through unchanged.
        return lambda x, **kwargs: x, input_dim

    if encoding == 'frequency_torch':
        encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
    elif encoding == 'frequency':  # CUDA implementation, faster than torch.
        from freqencoder import FreqEncoder
        encoder = FreqEncoder(input_dim=input_dim, degree=multires)
    elif encoding == 'sphere_harmonics':
        from shencoder import SHEncoder
        encoder = SHEncoder(input_dim=input_dim, degree=degree)
    elif encoding == 'hashgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)
    elif encoding == 'tiledgrid':
        from gridencoder import GridEncoder
        encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)
    elif encoding == 'hashgrid_taichi':
        from taichi_modules.hash_encoder import HashEncoderTaichi
        encoder = HashEncoderTaichi(batch_size=4096)  # TODO: hard-coded batch size
    else:
        raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')

    return encoder, encoder.output_dim
|
||||
1
freqencoder/__init__.py
Normal file
1
freqencoder/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .freq import FreqEncoder
|
||||
42
freqencoder/backend.py
Normal file
42
freqencoder/backend.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# JIT-compile the `_freqencoder` CUDA extension on first import.
import os
from torch.utils.cpp_extension import load

# Directory of this package; the CUDA/C++ sources live in ./src.
_src_path = os.path.dirname(os.path.abspath(__file__))

# NVCC flags: enable native half-precision operators and fast math.
nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    # Search standard Visual Studio install locations for the MSVC x64
    # host toolchain directory; returns None when nothing is found.
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

# NOTE: importing this module triggers a (possibly slow) nvcc build.
_backend = load(name='_freqencoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'freqencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
|
||||
77
freqencoder/freq.py
Normal file
77
freqencoder/freq.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import _freqencoder as _backend
|
||||
except ImportError:
|
||||
from .backend import _backend
|
||||
|
||||
|
||||
class _freq_encoder(Function):
    """Autograd bridge to the CUDA frequency-encoding kernels."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        # inputs: [B, input_dim], float
        # RETURN: [B, F], float

        # The kernels require CUDA-resident, contiguous memory.
        if not inputs.is_cuda: inputs = inputs.cuda()
        inputs = inputs.contiguous()

        B, input_dim = inputs.shape # batch size, coord dim

        # The kernel writes every element, so uninitialized `empty` is safe.
        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

        # Save outputs too: the backward kernel reuses the computed sin/cos
        # values instead of recomputing them.
        ctx.save_for_backward(inputs, outputs)
        ctx.dims = [B, input_dim, degree, output_dim]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, output_dim] upstream gradient w.r.t. the encoder output

        grad = grad.contiguous()
        inputs, outputs = ctx.saved_tensors
        B, input_dim, degree, output_dim = ctx.dims

        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)

        # Gradients flow only to `inputs`; degree/output_dim are plain ints.
        return grad_inputs, None, None


# Functional entry point used by FreqEncoder.forward.
freq_encode = _freq_encoder.apply
|
||||
|
||||
|
||||
class FreqEncoder(nn.Module):
    """CUDA-backed NeRF frequency encoder: raw input plus sin/cos bands."""

    def __init__(self, input_dim=3, degree=4):
        """Init.

        Args:
            input_dim (int): size of the last input dimension.
            degree (int): number of frequency octaves.
        """
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        # raw input + (sin, cos) per octave per coordinate
        self.output_dim = input_dim * (1 + 2 * degree)

    def __repr__(self):
        return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"

    def forward(self, inputs, **kwargs):
        """Encode [..., input_dim] -> [..., output_dim], preserving any
        number of leading dimensions."""
        lead = list(inputs.shape[:-1])
        flat = inputs.reshape(-1, self.input_dim)

        encoded = freq_encode(flat, self.degree, self.output_dim)

        return encoded.reshape(lead + [self.output_dim])
|
||||
52
freqencoder/setup.py
Normal file
52
freqencoder/setup.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Build script for the `freqencoder` CUDA extension (pip/setup.py install
# path; the JIT path lives in backend.py).
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Directory containing this setup.py; CUDA/C++ sources live in ./src.
_src_path = os.path.dirname(os.path.abspath(__file__))

# NVCC flags: enable native half-precision operators and fast math.
nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
    '-use_fast_math'
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    # Search standard Visual Studio install locations for the MSVC x64
    # host toolchain directory; returns None when nothing is found.
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='freqencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_freqencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'freqencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
|
||||
8
freqencoder/src/bindings.cpp
Normal file
8
freqencoder/src/bindings.cpp
Normal file
@@ -0,0 +1,8 @@
|
||||
#include <torch/extension.h>

#include "freqencoder.h"

// Expose the CUDA frequency-encoding launchers to Python as the
// `_freqencoder` extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
    m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
|
||||
129
freqencoder/src/freqencoder.cu
Normal file
129
freqencoder/src/freqencoder.cu
Normal file
@@ -0,0 +1,129 @@
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
// Argument-validation helpers used by the host-side launchers below.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")

// Pi as a device-side constexpr float.
inline constexpr __device__ float PI() { return 3.141592653589793f; }

// Ceiling division: number of size-`divisor` groups needed to cover `val`.
template <typename T>
__host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
|
||||
|
||||
// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
// Each thread computes exactly one element of `outputs`: either a copy of
// the raw input (first D columns) or one sin/cos band value.
__global__ void kernel_freq(
    const float * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * outputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * C) return;

    // get index
    const uint32_t b = t / C;
    const uint32_t c = t - b * C; // t % C;

    // locate
    inputs += b * D;
    outputs += t;

    // write self
    if (c < D) {
        outputs[0] = inputs[c];
    // write freq
    } else {
        // Columns beyond D hold deg * 2 band values per input dim, laid
        // out as [sin(2^0 x), cos(2^0 x), sin(2^1 x), cos(2^1 x), ...].
        const uint32_t col = c / D - 1;
        const uint32_t d = c % D;
        const uint32_t freq = col / 2;
        // cos(x) == sin(x + pi/2): a single __sinf covers both functions.
        const float phase_shift = (col % 2) * (PI() / 2);
        // scalbnf(x, freq) == x * 2^freq, the band frequency scaling.
        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
    }
}
|
||||
|
||||
// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]  (saved forward results, reused as the sin/cos values)
// grad_inputs: [B, D]
// One thread per input element: accumulates the chain-rule contributions
// of the raw-copy column and every frequency band.
__global__ void kernel_freq_backward(
    const float * __restrict__ grad,
    const float * __restrict__ outputs,
    uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
    float * grad_inputs
) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D; // t % D;

    // locate
    grad += b * C;
    outputs += b * C;
    grad_inputs += t;

    // register
    // Identity part: d(out_d)/d(in_d) = 1 for the raw-copy column.
    float result = grad[d];
    grad += D;
    outputs += D;

    // Band f contributes 2^f * (g_sin * cos - g_cos * sin); the saved
    // forward outputs provide sin (outputs[d]) and cos (outputs[D + d]).
    for (uint32_t f = 0; f < deg; f++) {
        result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
        grad += 2 * D;
        outputs += 2 * D;
    }

    // write
    grad_inputs[0] = result;
}
|
||||
|
||||
|
||||
// Host launcher for the forward encoding kernel.
// inputs: [B, D] CUDA float tensor; outputs: [B, C] preallocated buffer
// fully overwritten by the kernel.
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);

    static constexpr uint32_t N_THREADS = 128;

    // One thread per output element (B * C total).
    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());
}
|
||||
|
||||
|
||||
// Host launcher for the backward kernel.
// grad: [B, C] upstream gradient; outputs: [B, C] saved forward results;
// grad_inputs: [B, D] accumulator, fully written by the kernel.
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(outputs);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);

    static constexpr uint32_t N_THREADS = 128;

    // One thread per input element (B * D total).
    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C, grad_inputs.data_ptr<float>());
}
|
||||
10
freqencoder/src/freqencoder.h
Normal file
10
freqencoder/src/freqencoder.h
Normal file
@@ -0,0 +1,10 @@
|
||||
# pragma once

#include <stdint.h>
#include <torch/torch.h>

// Forward pass: encode `inputs` [B, D] into `outputs` [B, C] where
// C = D + D * deg * 2 (raw input followed by interleaved sin/cos bands).
// Called from Python as:
//   _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);

// Backward pass: accumulate d(loss)/d(inputs) [B, D] from `grad` [B, C],
// reusing the saved forward `outputs`. Called from Python as:
//   _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
|
||||
246
gradio_app.py
Normal file
246
gradio_app.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
from nerf.provider import NeRFDataset
|
||||
from nerf.utils import *
|
||||
|
||||
import gradio as gr
|
||||
import gc
|
||||
|
||||
print(f'[INFO] loading options..')

# fake config object, this should not be used in CMD, only allow change from gradio UI.


def _str2bool(v):
    """Parse a boolean CLI value.

    BUGFIX: the original used ``type=bool``, which treats ANY non-empty string
    (including ``"False"``) as True, because ``bool("False") is True``.
    """
    return str(v).lower() in ('true', '1', 'yes', 'y')


parser = argparse.ArgumentParser()
parser.add_argument('--text', default=None, help="text prompt")
parser.add_argument('--negative', default='', type=str, help="negative text prompt")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--eval_interval', type=int, default=10, help="evaluate on the valid set every interval epochs")
parser.add_argument('--workspace', type=str, default='trial_gradio')
parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
parser.add_argument('--seed', type=int, default=0)

parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh")
# BUGFIX: default was 1e5 (a float) — argparse's `type` only converts CLI
# strings, never the default, so downstream code would silently get a float.
parser.add_argument('--decimate_target', type=int, default=100000, help="target face number for mesh decimation")

### training options
parser.add_argument('--iters', type=int, default=10000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
parser.add_argument('--warmup_iters', type=int, default=1000, help="training iters that only use albedo shading")
parser.add_argument('--uniform_sphere_rate', type=float, default=0.5, help="likelihood of sampling camera location uniformly on the sphere surface area")
# model options
parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
parser.add_argument('--density_activation', type=str, default='softplus', choices=['softplus', 'exp'], help="density activation function")
parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
parser.add_argument('--blob_density', type=float, default=10, help="max (center) density for the density blob")
parser.add_argument('--blob_radius', type=float, default=0.3, help="control the radius for the density blob")
# network backbone
parser.add_argument('--fp16', action='store_true', help="use float16 for training")
parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, vanilla]")
parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam', 'adamw'], help="optimizer")
parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
# rendering resolution in training, decrease this if CUDA OOM.
parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")

### dataset options
parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")

parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
parser.add_argument('--lambda_smooth', type=float, default=0, help="loss scale for surface smoothness")

### GUI options
parser.add_argument('--gui', action='store_true', help="start a GUI")
parser.add_argument('--W', type=int, default=800, help="GUI width")
parser.add_argument('--H', type=int, default=800, help="GUI height")
parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

# accepts `--need_share True` / `--need_share False` as before, but now
# actually parses the value instead of treating every string as True.
parser.add_argument('--need_share', type=_str2bool, default=False, help="do you want to share gradio app to external network?")

opt = parser.parse_args()
|
||||
|
||||
# default to use -O !!!
opt.fp16 = True
opt.cuda_ray = True
opt.vram_O = True
# opt.lambda_entropy = 1e-4
# opt.lambda_opacity = 0

# select the NeRF backbone implementation requested via --backbone
if opt.backbone == 'vanilla':
    from nerf.network import NeRFNetwork
elif opt.backbone == 'grid':
    from nerf.network_grid import NeRFNetwork
else:
    raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')

print(opt)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'[INFO] loading models..')

# guidance model used to score rendered views against the text prompt
if opt.guidance == 'stable-diffusion':
    from guidance.sd_utils import StableDiffusion
    guidance = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key)
elif opt.guidance == 'clip':
    from guidance.clip_utils import CLIP
    guidance = CLIP(device)
else:
    raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')

# dataloaders are created once and shared by every generation request
train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()

print(f'[INFO] everything loaded!')

# populated lazily by submit() on the first "Generate" click
trainer = None
model = None
|
||||
# define UI

with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:

    # title
    gr.Markdown('[Stable-DreamFusion](https://github.com/ashawkey/stable-dreamfusion) Text-to-3D Example')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="a DSLR photo of a koi fish")
    iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=5000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    # gradio main func: rebuild the model/trainer from scratch, train for `iters`
    # steps while yielding intermediate renders, then yield the final video.
    def submit(text, iters, seed):

        global trainer, model

        # sync UI values into the shared config object
        opt.seed = seed
        opt.text = text
        opt.iters = iters

        seed_everything(seed)

        # clean up the previous run to free GPU memory before re-allocating
        if trainer is not None:
            del model
            del trainer
            gc.collect()
            torch.cuda.empty_cache()
            print('[INFO] clean up!')

        # simply reload everything...
        model = NeRFNetwork(opt)

        if opt.optim == 'adan':
            from optimizer import Adan
            # Adan usually requires a larger LR
            optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-15)
        elif opt.optim == 'adamw':
            optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
        else: # adam
            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

        scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed

        trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)

        # train (every ep only contain 8 steps, so we can get some vis every ~10s)
        STEPS = 8
        max_epochs = np.ceil(opt.iters / STEPS).astype(np.int32)

        # we have to get the explicit training loop out here to yield progressive results...
        loader = iter(valid_loader)

        start_t = time.time()

        for epoch in range(max_epochs):

            trainer.train_gui(train_loader, step=STEPS)

            # manual test and get intermediate results; restart the (finite)
            # validation iterator when it runs out
            try:
                data = next(loader)
            except StopIteration:
                loader = iter(valid_loader)
                data = next(loader)

            trainer.model.eval()

            # render with EMA weights if available, restoring raw weights afterwards
            if trainer.ema is not None:
                trainer.ema.store()
                trainer.ema.copy_to()

            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=trainer.fp16):
                    preds, preds_depth = trainer.test_step(data, perturb=False)

            if trainer.ema is not None:
                trainer.ema.restore()

            pred = preds[0].detach().cpu().numpy()
            # pred_depth = preds_depth[0].detach().cpu().numpy()

            pred = (pred * 255).astype(np.uint8)

            yield {
                image: gr.update(value=pred, visible=True),
                video: gr.update(visible=False),
                logs: f"training iters: {epoch * STEPS} / {iters}, lr: {trainer.optimizer.param_groups[0]['lr']:.6f}",
            }


        # test
        trainer.test(test_loader)

        results = glob.glob(os.path.join(opt.workspace, 'results', '*rgb*.mp4'))
        # BUGFIX: glob.glob always returns a list (never None), so the original
        # `assert results is not None` could never fire; check for emptiness so a
        # missing video fails loudly instead of raising IndexError below.
        assert len(results) > 0, "cannot retrieve results!"
        results.sort(key=lambda x: os.path.getmtime(x)) # sort by mtime

        end_t = time.time()

        yield {
            image: gr.update(visible=False),
            video: gr.update(value=results[-1], visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }


    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )

# concurrency_count: only allow ONE running progress, else GPU will OOM.
demo.queue(concurrency_count=1)

demo.launch(share=opt.need_share)
|
||||
1
gridencoder/__init__.py
Normal file
1
gridencoder/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .grid import GridEncoder
|
||||
40
gridencoder/backend.py
Normal file
40
gridencoder/backend.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
from torch.utils.cpp_extension import load

# JIT-compile the CUDA grid-encoder extension from ./src at import time.
_src_path = os.path.dirname(os.path.abspath(__file__))

# -U__CUDA_NO_HALF* undefines the macros that would disable CUDA half-precision
# operators in the torch extension build.
nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

_backend = load(name='_grid_encoder',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'gridencoder.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
|
||||
206
gridencoder/grid.py
Normal file
206
gridencoder/grid.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import _gridencoder as _backend
|
||||
except ImportError:
|
||||
from .backend import _backend
|
||||
|
||||
_gridtype_to_id = {
|
||||
'hash': 0,
|
||||
'tiled': 1,
|
||||
}
|
||||
|
||||
_interp_to_id = {
|
||||
'linear': 0,
|
||||
'smoothstep': 1,
|
||||
}
|
||||
|
||||
class _grid_encode(Function):
    """Custom autograd Function bridging to the CUDA multi-resolution grid encoder."""

    @staticmethod
    @custom_fwd
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape # batch size, coord dim
        L = offsets.shape[0] - 1 # level
        C = embeddings.shape[1] # embedding dim for each level
        S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution # base resolution

        # max_level is given as a fraction (it is multiplied by L); convert it to
        # an integer level count clamped to [1, L]
        max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1)

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        # zero init if we only calculate partial levels
        if max_level < L: outputs.zero_()

        if calc_grad_inputs:
            # dy_dx caches d(output)/d(input) so backward can produce grad_inputs
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
            if max_level < L: dy_dx.zero_()
        else:
            dy_dx = None

        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level]
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, L * C], gradient w.r.t. the forward output

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        # grad_inputs is only produced when forward cached dy_dx
        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        # one gradient slot per forward argument; only inputs/embeddings are differentiable
        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None



grid_encode = _grid_encode.apply
|
||||
|
||||
|
||||
class GridEncoder(nn.Module):
    """Multi-resolution grid encoder with hash or tiled storage.

    Maps D-dim coordinates in [-bound, bound] to a (num_levels * level_dim)
    feature vector by interpolating learned per-level embedding tables.
    """

    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
        super().__init__()

        # the finest resolution desired at the last level, if provided, overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim # coord dims, 2 or 3
        self.num_levels = num_levels # num levels, each level multiply resolution by 2
        self.level_dim = level_dim # encode channels per level
        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
        self.interpolation = interpolation
        self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep"
        self.align_corners = align_corners

        # allocate parameters: per-level table sizes capped at 2**log2_hashmap_size,
        # offsets[i] is the start row of level i in the flat embedding table
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize embeddings with small uniform noise."""
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"

    def forward(self, inputs, bound=1, max_level=None):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # max_level: only calculate first max_level levels (None will use all levels)
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound) # map to [0, 1]

        #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        # flatten all leading dims for the CUDA kernel, restore them afterwards
        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs

    # always run in float precision!
    @torch.cuda.amp.autocast(enabled=False)
    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
        # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.
        # Adds the total-variation gradient directly into self.embeddings.grad.

        D = self.input_dim
        C = self.embeddings.shape[1] # embedding dim for each level
        L = self.offsets.shape[0] - 1 # level
        S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = self.base_resolution # base resolution

        if inputs is None:
            # randomized in [0, 1]
            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
        else:
            inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
            inputs = inputs.view(-1, self.input_dim)
            B = inputs.shape[0]

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)

    @torch.cuda.amp.autocast(enabled=False)
    def grad_weight_decay(self, weight=0.1):
        # level-wise meaned weight decay (ref: zip-nerf)
        # Adds the decay gradient directly into self.embeddings.grad.

        B = self.embeddings.shape[0] # size of embedding
        C = self.embeddings.shape[1] # embedding dim for each level
        L = self.offsets.shape[0] - 1 # level

        if self.embeddings.grad is None:
            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')

        _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
|
||||
51
gridencoder/setup.py
Normal file
51
gridencoder/setup.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Ahead-of-time build script for the `_gridencoder` CUDA extension
# (alternative to the JIT compilation in backend.py).
_src_path = os.path.dirname(os.path.abspath(__file__))

# -U__CUDA_NO_HALF* undefines the macros that would disable CUDA half-precision
# operators in the torch extension build.
nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

setup(
    name='gridencoder', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_gridencoder', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'gridencoder.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
|
||||
10
gridencoder/src/bindings.cpp
Normal file
10
gridencoder/src/bindings.cpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#include <torch/extension.h>

#include "gridencoder.h"

// Expose the CUDA grid-encoder entry points to Python (module `_gridencoder`).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
    m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
    m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
}
|
||||
713
gridencoder/src/gridencoder.cu
Normal file
713
gridencoder/src/gridencoder.cu
Normal file
@@ -0,0 +1,713 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
||||
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
||||
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
||||
|
||||
|
||||
// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
|
||||
// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
    // requires CUDA >= 10 and ARCH >= 70
    // this is very slow compared to float or __half2, never use it.
    //return atomicAdd(reinterpret_cast<__half*>(address), val);
    // NOTE(review): deliberately has no body/return — this overload exists only
    // so the dispatch macro instantiates, and is never executed at runtime.
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ inline T div_round_up(T val, T divisor) {
|
||||
return (val + divisor - 1) / divisor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T smoothstep(T val) {
|
||||
return val*val*(3.0f - 2.0f * val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T smoothstep_derivative(T val) {
|
||||
return 6*val*(1.0f - val);
|
||||
}
|
||||
|
||||
|
||||
// Spatial hash of a D-dim integer lattice coordinate (Instant-NGP style):
// XOR of each coordinate multiplied by a fixed prime. The first "prime" is 1,
// so the first coordinate enters the hash unscaled. Supports D <= 7.
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {

    // coherent type of hashing
    constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };

    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}
|
||||
|
||||
|
||||
// Map an integer lattice coordinate to a flat index (channel `ch`) in the
// current level's C-channel embedding table. Dense row-major indexing is used
// while the dense grid fits in the table; the hash grid falls back to
// fast_hash once it does not.
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    // dense index; the loop stops early once stride exceeds the table size
    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= resolution;
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C>
|
||||
__global__ void kernel_grid(
|
||||
const float * __restrict__ inputs,
|
||||
const scalar_t * __restrict__ grid,
|
||||
const int * __restrict__ offsets,
|
||||
scalar_t * __restrict__ outputs,
|
||||
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
||||
scalar_t * __restrict__ dy_dx,
|
||||
const uint32_t gridtype,
|
||||
const bool align_corners,
|
||||
const uint32_t interp
|
||||
) {
|
||||
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (b >= B) return;
|
||||
|
||||
const uint32_t level = blockIdx.y;
|
||||
|
||||
// locate
|
||||
grid += (uint32_t)offsets[level] * C;
|
||||
inputs += b * D;
|
||||
outputs += level * B * C + b * C;
|
||||
|
||||
// check input range (should be in [0, 1])
|
||||
bool flag_oob = false;
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if (inputs[d] < 0 || inputs[d] > 1) {
|
||||
flag_oob = true;
|
||||
}
|
||||
}
|
||||
// if input out of bound, just set output to 0
|
||||
if (flag_oob) {
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
outputs[ch] = 0;
|
||||
}
|
||||
if (dy_dx) {
|
||||
dy_dx += b * D * L * C + level * D * C; // B L D C
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
dy_dx[d * C + ch] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
|
||||
|
||||
// calculate coordinate (always use float for precision!)
|
||||
float pos[D];
|
||||
float pos_deriv[D];
|
||||
uint32_t pos_grid[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
|
||||
// align_corners
|
||||
if (align_corners) {
|
||||
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
|
||||
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
|
||||
} else {
|
||||
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
|
||||
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
|
||||
}
|
||||
pos[d] -= (float)pos_grid[d];
|
||||
|
||||
// smoothstep instead of linear
|
||||
if (interp == 1) {
|
||||
pos_deriv[d] = smoothstep_derivative(pos[d]);
|
||||
pos[d] = smoothstep(pos[d]);
|
||||
} else {
|
||||
pos_deriv[d] = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// verification of alignment
|
||||
// if (level == L - 1 && b < 4) {
|
||||
// printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
||||
// }
|
||||
|
||||
// interpolate
|
||||
scalar_t results[C] = {0}; // temp results in register
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
||||
float w = 1;
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if ((idx & (1 << d)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
// writing to register (fast)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
results[ch] += w * grid[index + ch];
|
||||
}
|
||||
|
||||
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
|
||||
}
|
||||
|
||||
// writing to global memory (slow)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
outputs[ch] = results[ch];
|
||||
}
|
||||
|
||||
// prepare dy_dx
|
||||
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
||||
if (dy_dx) {
|
||||
|
||||
dy_dx += b * D * L * C + level * D * C; // B L D C
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t gd = 0; gd < D; gd++) {
|
||||
|
||||
scalar_t results_grad[C] = {0};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
||||
float w = (float)(align_corners ? resolution - 1 : resolution);
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
||||
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
||||
|
||||
if ((idx & (1 << nd)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
|
||||
}
|
||||
}
|
||||
|
||||
pos_grid_local[gd] = pos_grid[gd];
|
||||
uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
|
||||
pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1);
|
||||
uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
dy_dx[gd * C + ch] = results_grad[ch];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Backward pass w.r.t. the grid embeddings for one (sample, level) pair.
// Each thread handles N_C of the C channels of one sample and scatters its
// contribution into grad_grid with atomics (hash collisions mean several
// threads may target the same table entry).
//
//   grad:      [L, B, C] upstream gradient
//   inputs:    [B, D] coordinates, expected in [0, 1]
//   offsets:   [L + 1] per-level start offsets into the embedding table
//   grad_grid: same layout as the embedding table, accumulated in place
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ grad_grid,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    // Each thread covers N_C channels, so B * C / N_C threads span a level.
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    // First channel handled by this thread.
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch; // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);

    // check input range (should be in [0, 1])
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return; // grad is init as 0, so we simply return.
        }
    }

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        // align_corners
        if (align_corners) {
            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
        } else {
            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
        }
        pos[d] -= (float)pos_grid[d];
        // smoothstep instead of linear
        if (interp == 1) {
            pos[d] = smoothstep(pos[d]);
        }
    }

    scalar_t grad_cur[N_C] = {0}; // fetch to register
    #pragma unroll
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }

    // interpolate: distribute the upstream gradient to the 2^D corners with
    // the same trilinear weights the forward pass used.
    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, ch, hashmap_size, resolution, pos_grid_local);

        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
        // TODO: use float which is better than __half, if N_C % 2 != 0
        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c += 2) {
                // process two __half at once (by interpreting as a __half2)
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        // float, or __half when N_C % 2 != 0 (which means C == 1)
        } else {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }
}
|
||||
|
||||
|
||||
// Backward pass w.r.t. the encoder inputs: contract the upstream gradient
// with the precomputed Jacobian dy_dx. One thread per (batch, input-dim) pair.
//
//   grad:        [L, B, C] upstream gradient
//   dy_dx:       [B, L, D, C] Jacobian recorded by the forward pass
//   grad_inputs: [B, D] output, written once per thread (no atomics needed)
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * __restrict__ grad_inputs,
    uint32_t B, uint32_t L
) {
    const uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= B * D) return;

    const uint32_t b = tid / D; // batch index
    const uint32_t d = tid % D; // input dimension index

    // Jump to this sample's [L, D, C] slab of the Jacobian.
    const scalar_t * __restrict__ jac = dy_dx + b * L * D * C;

    // Chain rule: sum grad[l, b, ch] * d(output[l, ch]) / d(input[d])
    // over every level and channel.
    scalar_t acc = 0;

    #pragma unroll
    for (uint32_t l = 0; l < L; l++) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            acc += grad[l * B * C + b * C + ch] * jac[l * D * C + d * C + ch];
        }
    }

    grad_inputs[tid] = acc;
}
|
||||
|
||||
|
||||
// Dispatch the forward kernel over the compile-time channel count C.
// Launch geometry: one grid row per level, div_round_up(B, N_THREAD) blocks
// of N_THREAD threads along x.
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 };
    // Single launch site shared by every supported C.
    #define _LAUNCH_FWD(CC)                                              \
        kernel_grid<scalar_t, D, CC><<<blocks_hashgrid, N_THREAD>>>(     \
            inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx,     \
            gridtype, align_corners, interp)
    switch (C) {
        case 1:  _LAUNCH_FWD(1);  break;
        case 2:  _LAUNCH_FWD(2);  break;
        case 4:  _LAUNCH_FWD(4);  break;
        case 8:  _LAUNCH_FWD(8);  break;
        case 16: _LAUNCH_FWD(16); break;
        case 32: _LAUNCH_FWD(32); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
    }
    #undef _LAUNCH_FWD
}
|
||||
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
|
||||
// H: base resolution
|
||||
// dy_dx: [B, L * D * C]
|
||||
// Dispatch the forward pass over the compile-time input dimensionality D.
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    #define _DISPATCH_FWD_D(DD)                                                      \
        kernel_grid_wrapper<scalar_t, DD>(inputs, embeddings, offsets, outputs,      \
            B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp)
    switch (D) {
        case 2: _DISPATCH_FWD_D(2); break;
        case 3: _DISPATCH_FWD_D(3); break;
        case 4: _DISPATCH_FWD_D(4); break;
        case 5: _DISPATCH_FWD_D(5); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
    }
    #undef _DISPATCH_FWD_D
}
|
||||
|
||||
// Dispatch the backward pass over the compile-time channel count C.
// Each case launches the embedding-gradient kernel (templated on C and the
// per-thread feature count N_C) and, when dy_dx was recorded, the
// input-gradient kernel as well.
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), max_level, 1 };
    #define _LAUNCH_BWD(CC, NC)                                                              \
        kernel_grid_backward<scalar_t, D, CC, NC><<<blocks_hashgrid, N_THREAD>>>(            \
            grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,                  \
            gridtype, align_corners, interp);                                                \
        if (dy_dx) kernel_input_backward<scalar_t, D, CC><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L)
    switch (C) {
        case 1:  _LAUNCH_BWD(1, 1);  break;
        case 2:  _LAUNCH_BWD(2, 2);  break;
        case 4:  _LAUNCH_BWD(4, 2);  break;
        case 8:  _LAUNCH_BWD(8, 2);  break;
        case 16: _LAUNCH_BWD(16, 2); break;
        case 32: _LAUNCH_BWD(32, 2); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
    }
    #undef _LAUNCH_BWD
}
|
||||
|
||||
|
||||
// grad: [L, B, C], float
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// grad_embeddings: [sO, C]
|
||||
// H: base resolution
|
||||
// Dispatch the backward pass over the compile-time input dimensionality D.
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    #define _DISPATCH_BWD_D(DD)                                                              \
        kernel_grid_backward_wrapper<scalar_t, DD>(grad, inputs, embeddings, offsets,        \
            grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs,                   \
            gridtype, align_corners, interp)
    switch (D) {
        case 2: _DISPATCH_BWD_D(2); break;
        case 3: _DISPATCH_BWD_D(3); break;
        case 4: _DISPATCH_BWD_D(4); break;
        case 5: _DISPATCH_BWD_D(5); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
    }
    #undef _DISPATCH_BWD_D
}
|
||||
|
||||
|
||||
|
||||
// Python-facing entry point for the forward pass: validate the tensors
// (device, contiguity, dtype) and dispatch to the typed CUDA implementation.
//
//   inputs:     [B, D] float, in [0, 1]
//   embeddings: [sO, C]
//   offsets:    [L + 1] int
//   outputs:    [L, B, C] (level-first; see the layout comment on the kernel)
//   dy_dx:      optional Jacobian buffer; passed as nullptr when absent
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    // Dispatch on the embedding dtype (float / double / half).
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
    }));
}
|
||||
|
||||
// Python-facing entry point for the backward pass: validate the tensors and
// dispatch to the typed CUDA implementation.
//
//   grad:            [L, B, C] upstream gradient
//   inputs:          [B, D] float, in [0, 1]
//   grad_embeddings: [sO, C], accumulated by the kernels
//   dy_dx / grad_inputs: optional; passed as nullptr when absent, in which
//       case the input-gradient kernel is skipped by the wrapper.
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    // CHECK_CUDA(dy_dx);
    // CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    // CHECK_CONTIGUOUS(dy_dx);
    // CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    // CHECK_IS_FLOATING(dy_dx);
    // CHECK_IS_FLOATING(grad_inputs);

    // Dispatch on the gradient dtype (float / double / half).
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
    }));

}
|
||||
|
||||
|
||||
// Gradient of a channel-normalized total-variation regularizer, evaluated at
// the grid corner nearest each input sample. For every dimension, the corner's
// embedding is compared against its left/right neighbors; the accumulated
// signed differences (normalized by the RMS difference) are scattered into
// `grad` atomically, since different samples may land on the same table entry.
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grad_tv(
    const scalar_t * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    scalar_t * __restrict__ grad,
    const int * __restrict__ offsets,
    const float weight,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate this sample and this level's slice of the table
    inputs += b * D;
    grid += (uint32_t)offsets[level] * C;
    grad += (uint32_t)offsets[level] * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }

    // if input out of bound, do nothing
    if (flag_oob) return;

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);

    // calculate coordinate of the nearest (left-bottom) grid corner
    float pos[D];
    uint32_t pos_grid[D]; // [0, resolution - 1]

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        // align_corners
        if (align_corners) {
            pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
            pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
        } else {
            pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
            pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
        }
    }

    // total variation on pos_grid
    scalar_t results[C] = {0}; // accumulated signed differences per channel
    scalar_t idelta[C] = {0};  // accumulated squared differences (normalizer)

    uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);

    scalar_t w = weight / (2 * D);

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {

        uint32_t cur_d = pos_grid[d];
        scalar_t grad_val;

        // right side
        // BUGFIX: was `cur_d < resolution`, which is always true and allowed
        // the neighbor coordinate to reach `resolution`, one past the valid
        // grid range [0, resolution - 1]; now the boundary corner simply has
        // no right neighbor, mirroring the `cur_d > 0` guard below.
        if (cur_d + 1 < resolution) {
            pos_grid[d] = cur_d + 1;
            uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                grad_val = (grid[index + ch] - grid[index_right + ch]);
                results[ch] += grad_val;
                idelta[ch] += grad_val * grad_val;
            }
        }

        // left side
        if (cur_d > 0) {
            pos_grid[d] = cur_d - 1;
            uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                grad_val = (grid[index + ch] - grid[index_left + ch]);
                results[ch] += grad_val;
                idelta[ch] += grad_val * grad_val;
            }
        }

        // reset this dimension before handling the next one
        pos_grid[d] = cur_d;
    }

    // writing to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        // index may collide, so use atomic! 1e-9 guards rsqrtf against 0.
        atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
    }

}
|
||||
|
||||
|
||||
// Dispatch the total-variation gradient kernel over the compile-time channel
// count C. One grid row per level, B threads spread over x.
template <typename scalar_t, uint32_t D>
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    #define _LAUNCH_TV(CC)                                                   \
        kernel_grad_tv<scalar_t, D, CC><<<blocks_hashgrid, N_THREAD>>>(      \
            inputs, embeddings, grad, offsets, weight, B, L, S, H,           \
            gridtype, align_corners)
    switch (C) {
        case 1:  _LAUNCH_TV(1);  break;
        case 2:  _LAUNCH_TV(2);  break;
        case 4:  _LAUNCH_TV(4);  break;
        case 8:  _LAUNCH_TV(8);  break;
        case 16: _LAUNCH_TV(16); break;
        case 32: _LAUNCH_TV(32); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
    }
    #undef _LAUNCH_TV
}
|
||||
|
||||
|
||||
// Dispatch the total-variation gradient over the compile-time input
// dimensionality D.
template <typename scalar_t>
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    #define _DISPATCH_TV_D(DD)                                                       \
        kernel_grad_tv_wrapper<scalar_t, DD>(inputs, embeddings, grad, offsets,      \
            weight, B, C, L, S, H, gridtype, align_corners)
    switch (D) {
        case 2: _DISPATCH_TV_D(2); break;
        case 3: _DISPATCH_TV_D(3); break;
        case 4: _DISPATCH_TV_D(4); break;
        case 5: _DISPATCH_TV_D(5); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
    #undef _DISPATCH_TV_D
}
|
||||
|
||||
|
||||
// Python-facing entry point for the total-variation regularizer gradient.
// Accumulates d(TV)/d(embeddings) into `grad` in place (atomically, since
// samples can collide on the same table entry).
// NOTE(review): unlike the encode entry points above, no device/contiguity/
// dtype checks are performed here — callers must pass contiguous CUDA tensors.
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {

    // Dispatch on the embedding dtype (float / double / half).
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        embeddings.scalar_type(), "grad_total_variation", ([&] {
        grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
    }));
}
|
||||
|
||||
// Weight-decay gradient: grad += 2 * weight * grid / hashmap_size, i.e. an L2
// penalty on the embeddings whose strength is normalized by the size of the
// level each entry belongs to. One thread per embedding scalar.
// NOTE(review): B here appears to be the number of table rows (sO), since
// b ranges over [0, B * C) and indexes the flattened [sO, C] table — confirm
// against the caller.
template <typename scalar_t>
__global__ void kernel_grad_wd(
    const scalar_t * __restrict__ grid,
    scalar_t * __restrict__ grad,
    const int * __restrict__ offsets,
    const float weight,
    const uint32_t B, const uint32_t L, const uint32_t C
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B * C) return;

    // locate this thread's single scalar
    grid += b;
    grad += b;

    // decide in which level is this thread...
    uint32_t level = 0;
    const uint32_t n = b / C;  // table row index of this scalar
    // binary search b in offsets: find the last level with offsets[level] <= n
    uint32_t l = 0, r = L;
    while (l < r) {
        uint32_t m = (l + r) / 2;
        if (offsets[m] <= n) {
            level = m;
            l = m + 1;
        } else {
            r = m;
        }
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    // No atomics needed: exactly one thread owns each scalar.
    grad[0] += 2 * weight * grid[0] / hashmap_size;
}
|
||||
|
||||
// Python-facing entry point for the per-level-normalized weight-decay
// gradient; accumulates into `grad` in place. Launches one thread per
// embedding scalar (B * C threads total).
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {

    // Dispatch on the embedding dtype (float / double / half).
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        embeddings.scalar_type(), "grad_weight_decay", ([&] {
        static constexpr uint32_t N_THREAD = 1024;
        const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };
        kernel_grad_wd<scalar_t><<<blocks_hashgrid, N_THREAD>>>(embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, L, C);
    }));
}
|
||||
18
gridencoder/src/gridencoder.h
Normal file
18
gridencoder/src/gridencoder.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef _HASH_ENCODE_H
|
||||
#define _HASH_ENCODE_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// outputs: [L, B, C], float (level-first; matches the forward implementation)
|
||||
// H: base resolution
|
||||
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
|
||||
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
|
||||
|
||||
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
|
||||
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
|
||||
|
||||
#endif
|
||||
52
guidance/clip_utils.py
Normal file
52
guidance/clip_utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
import clip
|
||||
|
||||
class CLIP(nn.Module):
    """Thin wrapper around OpenAI CLIP providing text/image embeddings and a
    cosine-similarity guidance loss."""

    def __init__(self, device, **kwargs):
        super().__init__()

        self.device = device
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)

        # Resize to CLIP's expected input size and apply its normalization stats.
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def get_text_embeds(self, prompt, **kwargs):
        """Encode prompt string(s) into L2-normalized CLIP text features."""
        tokens = clip.tokenize(prompt).to(self.device)
        text_z = self.clip_model.encode_text(tokens)
        return text_z / text_z.norm(dim=-1, keepdim=True)

    def get_img_embeds(self, image, **kwargs):
        """Encode an image batch into L2-normalized CLIP image features."""
        image_z = self.clip_model.encode_image(self.aug(image))
        return image_z / image_z.norm(dim=-1, keepdim=True)

    def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
        """Negative cosine similarity between the rendered image and the
        reference embeddings in `clip_z` (keys: 'image' and/or 'text'),
        scaled by `grad_scale`."""
        image_z = self.clip_model.encode_image(self.aug(pred_rgb))
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = 0
        if 'image' in clip_z:
            loss = loss - (image_z * clip_z['image']).sum(-1).mean()
        if 'text' in clip_z:
            loss = loss - (image_z * clip_z['text']).sum(-1).mean()

        return loss * grad_scale
|
||||
|
||||
207
guidance/if_utils.py
Normal file
207
guidance/if_utils.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from transformers import logging
|
||||
from diffusers import IFPipeline, DDPMScheduler
|
||||
|
||||
# suppress partial model loading warning
|
||||
logging.set_verbosity_error()
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class SpecifyGradient(torch.autograd.Function):
    """Autograd trick used for score distillation: forward returns a dummy
    scalar loss, backward injects the precomputed gradient `gt_grad`
    (multiplied by AMP's loss scale) into `input_tensor`."""

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Dummy value 1; AMP's scaler multiplies it, and that scale arrives
        # in backward() as grad_scale.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (gt_grad,) = ctx.saved_tensors
        return gt_grad * grad_scale, None
|
||||
|
||||
def seed_everything(seed):
    """Seed torch's CPU and CUDA RNGs for reproducible sampling.

    cuDNN determinism/benchmark flags are intentionally left untouched.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
class IF(nn.Module):
|
||||
def __init__(self, device, vram_O, t_range=(0.02, 0.98)):
    """Load DeepFloyd IF-I-XL (stage I) in fp16 for diffusion guidance.

    Args:
        device: torch device to run on.
        vram_O: if True, trade speed for VRAM (channels-last unet,
            attention slicing, model CPU offload) instead of moving the
            whole pipeline to `device`.
        t_range: (min, max) fractions of the diffusion schedule used when
            sampling timesteps (avoids extreme noise levels).
            Was a mutable list default; now an immutable tuple.
    """
    super().__init__()

    self.device = device

    print(f'[INFO] loading DeepFloyd IF-I-XL...')

    model_key = "DeepFloyd/IF-I-XL-v1.0"

    is_torch2 = torch.__version__[0] == '2'

    # Create model
    pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
    # torch 2.x ships scaled-dot-product attention; xformers only helps on 1.x.
    if not is_torch2:
        pipe.enable_xformers_memory_efficient_attention()

    if vram_O:
        # Memory-saving mode: keep weights offloaded and slice attention.
        pipe.unet.to(memory_format=torch.channels_last)
        pipe.enable_attention_slicing(1)
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)

    # Convenience handles into the pipeline.
    # (Fixed: `self.unet` was previously assigned twice.)
    self.unet = pipe.unet
    self.tokenizer = pipe.tokenizer
    self.text_encoder = pipe.text_encoder
    self.scheduler = pipe.scheduler

    self.pipe = pipe

    # Timestep sampling bounds and cumulative alphas for the guidance loss.
    self.num_train_timesteps = self.scheduler.config.num_train_timesteps
    self.min_step = int(self.num_train_timesteps * t_range[0])
    self.max_step = int(self.num_train_timesteps * t_range[1])
    self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

    print(f'[INFO] loaded DeepFloyd IF-I-XL!')
|
||||
|
||||
@torch.no_grad()
def get_text_embeds(self, prompt):
    """Tokenize prompt(s) and encode them with the pipeline's text encoder.

    prompt: [str]
    """
    # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28
    prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
    tokens = self.tokenizer(
        prompt,
        padding='max_length',
        max_length=77,
        truncation=True,
        add_special_tokens=True,
        return_tensors='pt',
    )
    return self.text_encoder(tokens.input_ids.to(self.device))[0]
|
||||
|
||||
|
||||
def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):
|
||||
|
||||
# [0, 1] to [-1, 1] and make sure shape is [64, 64]
|
||||
images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
|
||||
|
||||
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
|
||||
t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)
|
||||
|
||||
# predict the noise residual with unet, NO grad!
|
||||
with torch.no_grad():
|
||||
# add noise
|
||||
noise = torch.randn_like(images)
|
||||
images_noisy = self.scheduler.add_noise(images, noise, t)
|
||||
|
||||
# pred noise
|
||||
model_input = torch.cat([images_noisy] * 2)
|
||||
model_input = self.scheduler.scale_model_input(model_input, t)
|
||||
tt = torch.cat([t] * 2)
|
||||
noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# TODO: how to use the variance here?
|
||||
# noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
# w(t), sigma_t^2
|
||||
w = (1 - self.alphas[t])
|
||||
grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
|
||||
grad = torch.nan_to_num(grad)
|
||||
|
||||
# since we omitted an item in grad, we need to use the custom function to specify the gradient
|
||||
loss = SpecifyGradient.apply(images, grad)
|
||||
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):
|
||||
|
||||
images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)
|
||||
images = images * self.scheduler.init_noise_sigma
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for i, t in enumerate(self.scheduler.timesteps):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
model_input = torch.cat([images] * 2)
|
||||
model_input = self.scheduler.scale_model_input(model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
images = self.scheduler.step(noise_pred, t, images).prev_sample
|
||||
|
||||
images = (images + 1) / 2
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
|
||||
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(negative_prompts, str):
|
||||
negative_prompts = [negative_prompts]
|
||||
|
||||
# Prompts -> text embeds
|
||||
pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]
|
||||
neg_embeds = self.get_text_embeds(negative_prompts)
|
||||
text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
|
||||
|
||||
# Text embeds -> img
|
||||
imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
|
||||
|
||||
# Img to Numpy
|
||||
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
imgs = (imgs * 255).round().astype('uint8')
|
||||
|
||||
return imgs
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # CLI smoke test: sample one 64x64 image from IF stage I and display it.
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    parser.add_argument('-H', type=int, default=64)
    parser.add_argument('-W', type=int, default=64)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    args = parser.parse_args()

    seed_everything(args.seed)

    device = torch.device('cuda')

    guidance = IF(device, args.vram_O)

    images = guidance.prompt_to_img(args.prompt, args.negative, args.H, args.W, args.steps)

    # visualize image
    plt.imshow(images[0])
    plt.show()
707
guidance/sd_utils.py
Normal file
707
guidance/sd_utils.py
Normal file
@@ -0,0 +1,707 @@
|
||||
import logging
import os
import random
from dataclasses import dataclass
from os.path import isfile
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union, Mapping

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers.utils.import_utils import is_xformers_available
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torchvision import transforms
from torchvision.io import read_image
from torchvision.transforms import functional as TVF
from torchvision.utils import make_grid, save_image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def spherical_dist_loss(x, y):
    """Squared great-circle distance between L2-normalized feature vectors.

    Both inputs are normalized along the last dim, so the result depends only
    on the angle between them (0 for parallel, pi^2/2 for antipodal).
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = (x_unit - y_unit).norm(dim=-1)
    return chord.div(2).arcsin().pow(2).mul(2)
def seed_everything(seed=None):
    """Seed Python's, NumPy's and torch's RNGs (no-op when seed is None).

    Fix: the original guard `if seed:` silently skipped seeding when seed was
    0 (a perfectly valid seed); test `seed is not None` instead.
    """
    if seed is not None:
        seed = int(seed)
        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # torch.backends.cudnn.deterministic = True
        # benchmark=True favors speed over strict cudnn determinism
        torch.backends.cudnn.benchmark = True
def save_tensor2image(x: torch.Tensor, path, channel_last=True, quality=75, **kwargs):
    """Assemble a batch of images into one grid and save it to `path`.

    `quality` only matters for lossy formats (JPEG). Extra kwargs go to
    torchvision's make_grid.
    """
    if x.ndim == 4 and channel_last:
        # make_grid expects NCHW; move channels first when input is NHWC
        x = x.permute(0, 3, 1, 2)
    grid = make_grid(x, value_range=(0, 1), **kwargs)
    TVF.to_pil_image(grid).save(path, quality=quality)
def to_pil(x: torch.Tensor, **kwargs) -> Image.Image:
    """Assemble a tensor batch into a single PIL image grid (values in [0, 1])."""
    grid = make_grid(x, value_range=(0, 1), **kwargs)
    return TVF.to_pil_image(grid)
def to_np_img(x: torch.Tensor) -> np.ndarray:
    """NCHW float tensor in [0, 1] -> NHWC uint8 numpy array in [0, 255]."""
    channel_last = x.detach().permute(0, 2, 3, 1).cpu()
    return (channel_last.numpy() * 255).round().astype(np.uint8)
class SpecifyGradient(torch.autograd.Function):
    """Autograd bridge that injects a precomputed gradient into the graph.

    forward() returns a dummy scalar "loss"; backward() hands back the stored
    gradient, multiplied by whatever scale AMP applied to that dummy loss.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Dummy scalar: AMP's GradScaler multiplies it, so the scale factor
        # reaches backward() as grad_scale.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (stored_grad,) = ctx.saved_tensors
        return stored_grad * grad_scale, None
def token_replace(prompt, negative, learned_embeds_path):
    """Substitute the '<token>' placeholder in both prompts with the learned token.

    The token name is the single key of the dict stored at learned_embeds_path.
    Raises ValueError when '<token>' is used without a path, or when the stored
    dict does not contain exactly one entry.
    """
    needs_token = '<token>' in prompt or '<token>' in negative
    if not needs_token:
        return prompt, negative

    if learned_embeds_path is None:
        raise ValueError(
            '--learned_embeds_path must be specified when using <token>')
    import torch
    keys = list(torch.load(learned_embeds_path, map_location='cpu').keys())
    if len(keys) != 1:
        raise ValueError(
            'Something is wrong with the dict passed in for --learned_embeds_path')
    token = keys[0]
    prompt = prompt.replace('<token>', token)
    negative = negative.replace('<token>', token)
    logger.info(f'Prompt after replacing <token>: {prompt}')
    logger.info(f'Negative prompt after replacing <token>: {negative}')
    return prompt, negative
@dataclass
class UNet2DConditionOutput:
    """Minimal stand-in for diffusers' UNet output, produced by the traced UNet.

    NOTE(review): unclear whether unet_traced.pt emits half or full precision;
    the annotation keeps the original's HalfTensor assumption — confirm.
    """
    sample: torch.HalfTensor
def enable_vram(pipe):
    """Apply low-VRAM settings to a diffusers pipeline (offload + slicing).

    Sequential CPU offload keeps only the active submodule on GPU; VAE and
    attention slicing trade speed for peak memory.
    """
    pipe.enable_sequential_cpu_offload()
    pipe.enable_vae_slicing()
    # channels_last tends to speed up conv-heavy UNets
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.enable_attention_slicing(1)
    # pipe.enable_model_cpu_offload()  # alternative offload strategy, unused
def get_model_path(sd_version='2.1', clip_version='large', hf_key=None):
    """Map version strings to HuggingFace model ids.

    Returns:
        (sd_path, clip_path): repo ids for the diffusion model and CLIP.
    Raises:
        ValueError: for an unsupported sd_version (when no hf_key is given).
    """
    known_sd = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
    }
    if hf_key is not None:
        # an explicit HF repo id overrides the version table
        logger.info(f'[INFO] using hugging face custom model key: {hf_key}')
        sd_path = hf_key
    elif sd_version in known_sd:
        sd_path = known_sd[sd_version]
    else:
        raise ValueError(
            f'Stable-diffusion version {sd_version} not supported.')
    clip_path = "openai/clip-vit-base-patch32" if clip_version == 'base' else "openai/clip-vit-large-patch14"
    return sd_path, clip_path
class StableDiffusion(nn.Module):
    """Stable Diffusion guidance for score-distillation (SDS) training.

    Wraps a diffusers StableDiffusionPipeline and exposes:
      * train_step -- injects the SDS gradient (optionally mixed/alternated
        with a CLIP-guidance gradient) directly into pred_rgb's graph
      * produce_latents / decode_latents / encode_imgs -- diffusion plumbing
      * prompt_to_img / img_to_img -- standalone sampling helpers
      * add_tokens_to_model* -- textual-inversion token loading
      * check_prompt -- saves sample generations for several view suffixes
    """

    def __init__(self, device, fp16, vram_O,
                 sd_version='2.1', hf_key=None,
                 t_range=[0.02, 0.98],
                 use_clip=False,
                 clip_version='base',
                 clip_iterative=True,
                 clip_t=0.4,
                 **kwargs
                 ):
        """
        Args:
            device: torch device for all modules.
            fp16: run the pipeline in float16.
            vram_O: enable aggressive VRAM savings (offload + slicing).
            sd_version: '1.5' | '2.0' | '2.1' (ignored when hf_key is given).
            hf_key: optional custom HuggingFace model id.
            t_range: fraction of the noise schedule sampled in train_step.
                (NOTE: mutable list default — never mutated, so harmless.)
            use_clip: also load a CLIP model for CLIP guidance.
            clip_version: 'base' or 'large' CLIP variant.
            clip_iterative: alternate CLIP vs SDS per step instead of mixing.
            clip_t: timestep fraction below which CLIP guidance applies.
            **kwargs: may carry 'learned_embeds_path' (textual inversion) and
                'overrride_token' (sic — keep the misspelled key).
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version
        self.vram_O = vram_O
        self.fp16 = fp16

        logger.info(f'[INFO] loading stable diffusion...')

        sd_path, clip_path = get_model_path(sd_version, clip_version, hf_key)
        self.precision_t = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            sd_path, torch_dtype=self.precision_t, local_files_only=False)

        if isfile('./unet_traced.pt'):
            # use jitted unet if a traced artifact is present in the CWD
            unet_traced = torch.jit.load('./unet_traced.pt')

            class TracedUNet(torch.nn.Module):
                """Adapter giving the traced module the pipeline UNet's interface."""

                def __init__(self):
                    super().__init__()
                    self.in_channels = pipe.unet.in_channels
                    self.device = pipe.unet.device

                def forward(self, latent_model_input, t, encoder_hidden_states):
                    sample = unet_traced(
                        latent_model_input, t, encoder_hidden_states)[0]
                    return UNet2DConditionOutput(sample=sample)
            pipe.unet = TracedUNet()

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        if kwargs.get('learned_embeds_path', None) is not None:
            learned_embeds_path = kwargs['learned_embeds_path']
            if os.path.exists(learned_embeds_path):
                logger.info(
                    f'[INFO] loading learned embeddings from {kwargs["learned_embeds_path"]}')
                self.add_tokens_to_model_from_path(learned_embeds_path, kwargs.get('overrride_token', None))
            else:
                logger.warning(f'learned_embeds_path {learned_embeds_path} does not exist!')

        if vram_O:
            # this will change device from gpu to other types (meta)
            enable_vram(pipe)
        else:
            if is_xformers_available():
                pipe.enable_xformers_memory_efficient_attention()
            pipe.to(device)

        # DDIM scheduler is used for both training-noise injection and sampling
        self.scheduler = DDIMScheduler.from_pretrained(
            sd_path, subfolder="scheduler", torch_dtype=self.precision_t, local_files_only=False)

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(
            self.device)  # for convenience

        logger.info(f'[INFO] loaded stable diffusion!')

        # for CLIP
        self.use_clip = use_clip
        if self.use_clip:
            self.clip_model = CLIPModel.from_pretrained(clip_path).to(device)
            # rebuild CLIP's preprocessing as differentiable tensor transforms
            image_processor = CLIPProcessor.from_pretrained(clip_path).image_processor
            self.image_processor = transforms.Compose([
                transforms.Resize((image_processor.crop_size['height'], image_processor.crop_size['width'])),
                transforms.Normalize(image_processor.image_mean, image_processor.image_std),
            ])
            for p in self.clip_model.parameters():
                p.requires_grad = False

        self.clip_iterative = clip_iterative
        # clip_t as an absolute timestep threshold
        self.clip_t = int(self.num_train_timesteps * clip_t)

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        """Tokenize prompt(s) and return the text encoder's last hidden state."""
        text_input = self.tokenizer(
            prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return text_embeddings

    @torch.no_grad()
    def get_all_text_embeds(self, prompt):
        """Return per-token embeddings with the pooled embedding appended.

        Concatenates the sequence output with the pooled output (unsqueezed
        to a length-1 token) along the token dimension.
        """
        text_input = self.tokenizer(
            prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))
        # text_z = text_z / text_z.norm(dim=-1, keepdim=True)

        # return all text embeddings and class embeddings
        return torch.cat([text_embeddings[0], text_embeddings[1].unsqueeze(1)], dim=1)

    # intentionally NOT @torch.no_grad(): gradients must flow for CLIP guidance
    def get_clip_img_embeds(self, img):
        """Preprocess and embed images with CLIP; returns L2-normalized features."""
        img = self.image_processor(img)
        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        return image_z

    def clip_loss(self, ref_z, pred_rgb):
        """Spherical distance between rendered-image CLIP features and ref_z."""
        image_z = self.get_clip_img_embeds(pred_rgb)
        loss = spherical_dist_loss(image_z, ref_z)
        return loss

    def set_epoch(self, epoch):
        # stored for potential timestep annealing (currently unused here)
        self.epoch = epoch

    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_clip=None, grad_scale=1.0,
                   image_ref_clip=None, text_ref_clip=None, clip_guidance=100, clip_image_loss=False,
                   density=None,
                   save_guidance_path=None
                   ):
        """One guidance step: writes the SDS/CLIP gradient into the graph.

        Calls latents.backward(gradient=grad) directly (side effect!) and
        returns (scalar loss proxy, {'loss_sds', 'loss_clipd'}) for logging.

        Args:
            text_embeddings: [uncond; cond] embeddings, batch dim 2*B.
            pred_rgb: rendered image in [0, 1]; treated as latents directly
                when as_latent=True (resized to 64 and rescaled to [-1, 1]).
            guidance_scale: classifier-free guidance weight.
            as_latent: interpret pred_rgb as latent instead of RGB.
            grad_clip: optional clamp on the final gradient.
            grad_scale: multiplier on the SDS gradient.
            image_ref_clip / text_ref_clip: CLIP reference features.
            clip_guidance: CLIP gradient weight; <= 0 disables CLIP.
            clip_image_loss: use image_ref_clip instead of text_ref_clip.
            density: optional density map to upweight the occupied bbox.
            save_guidance_path: if set, dump a denoising visualization grid.
        """
        enable_clip = self.use_clip and clip_guidance > 0 and not as_latent
        enable_sds = True
        if as_latent:
            latents = F.interpolate(
                pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(
                pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        # Since during the optimzation, the 3D is getting better.
        # mn = max(self.min_step, int(self.max_step - (self.max_step - self.min_step) / (self.opt.max_epoch // 3) * self.epoch + 0.5))
        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
        if enable_clip and self.clip_iterative:
            # alternate: CLIP only at low noise, SDS only at high noise
            # NOTE(review): `t > self.clip_t` on a tensor assumes batch size 1 — confirm
            if t > self.clip_t:
                enable_clip = False
            else:
                enable_sds = False

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t]*2)
            noise_pred = self.unet(latent_model_input, tt,
                                   encoder_hidden_states=text_embeddings).sample

            # perform guidance (high scale from paper!)
            # NOTE(review): base term is noise_pred_text (not the usual
            # uncond) — consistent throughout this file, so kept as-is.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_text + guidance_scale * \
                (noise_pred_text - noise_pred_uncond)

        if enable_clip:
            # x0 estimate from the noise prediction, then decode and compare
            # CLIP features against the reference
            pred_original_sample = (latents_noisy - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5)
            sample = pred_original_sample
            sample = sample.detach().requires_grad_()

            sample = 1 / self.vae.config.scaling_factor * sample
            out_image = self.vae.decode(sample).sample
            out_image = (out_image / 2 + 0.5)  # .clamp(0, 1)
            image_embeddings_clip = self.get_clip_img_embeds(out_image)
            ref_clip = image_ref_clip if clip_image_loss else text_ref_clip
            loss_clip = spherical_dist_loss(image_embeddings_clip, ref_clip).mean() * clip_guidance * 50  # 100
            # negative: we descend along -d(loss)/d(sample)
            grad_clipd = - torch.autograd.grad(loss_clip, sample, retain_graph=True)[0]
        else:
            grad_clipd = 0

        if density is not None:
            # double the gradient weight inside the bounding box of the
            # occupied (nonzero-density) region
            with torch.no_grad():
                density = F.interpolate(density.detach(), (64, 64), mode='bilinear', align_corners=False)
                ids = torch.nonzero(density.squeeze())
                spatial_weight = torch.ones_like(density, device=density.device)
                try:
                    up = ids[:, 0].min()
                    down = ids[:, 0].max() + 1
                    ll = ids[:, 1].min()
                    rr = ids[:, 1].max() + 1
                    spatial_weight[:, :, up:down, ll:rr] += 1
                except:
                    # empty density: keep uniform weights
                    pass

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])[:, None, None, None]
        # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])

        if enable_sds:
            grad_sds = grad_scale * w * (noise_pred - noise)
            loss_sds = grad_sds.abs().mean().detach()
        else:
            grad_sds = 0.
            loss_sds = 0.

        if enable_clip:
            grad_clipd = w * grad_clipd.detach()
            loss_clipd = grad_clipd.abs().mean().detach()
        else:
            grad_clipd = 0.
            loss_clipd = 0.

        grad = grad_clipd + grad_sds

        if grad_clip is not None:
            grad = grad.clamp(-grad_clip, grad_clip)

        if density is not None:
            grad = grad * spatial_weight / 2

        grad = torch.nan_to_num(grad)

        # since we omitted an item in grad, we need to use the custom function to specify the gradient
        # loss = SpecifyGradient.apply(latents, grad)
        # loss = loss.abs().mean().detach()
        # side effect: pushes grad into pred_rgb's graph right here
        latents.backward(gradient=grad, retain_graph=True)
        loss = grad.abs().mean().detach()

        if not enable_clip:
            loss_sds = loss

        if save_guidance_path:
            with torch.no_grad():
                # save a grid visualizing denoising across the schedule
                images = []
                os.makedirs(os.path.dirname(save_guidance_path), exist_ok=True)
                timesteps = torch.arange(-1, 1000, 100, dtype=torch.long, device=self.device)
                timesteps[0] *= 0  # clamp the leading -1 to t=0
                for t in timesteps:
                    if as_latent:
                        pred_rgb_512 = self.decode_latents(latents)
                    latents_noisy = self.scheduler.add_noise(latents, noise, t)

                    # pred noise
                    latent_model_input = torch.cat([latents_noisy] * 2)

                    noise_pred = self.unet(latent_model_input, t,
                                           encoder_hidden_states=text_embeddings).sample

                    # perform guidance (high scale from paper!)
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_text + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                    pred_original_sample = self.decode_latents((latents_noisy - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5))

                    # visualize predicted denoised image
                    # claforte: discuss this with Vikram!!
                    result_hopefully_less_noisy_image = self.decode_latents(latents - w*(noise_pred - noise))

                    # visualize noisier image
                    result_noisier_image = self.decode_latents(latents_noisy)

                    # add in the last col, w/o rendered view contraint, using random noise as latent.
                    latent_model_input = torch.cat([noise] * 2)
                    noise_pred = self.unet(latent_model_input, t,
                                           encoder_hidden_states=text_embeddings).sample
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_text + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)
                    noise_diffusion_out = self.decode_latents((noise - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5))
                    # all input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                    image = torch.cat([pred_rgb_512, pred_original_sample, result_noisier_image, result_hopefully_less_noisy_image, noise_diffusion_out], dim=0)
                    images.append(image)
                viz_images = torch.cat(images, dim=0)
                save_image(viz_images, save_guidance_path, nrow=5)

        return loss, {'loss_sds': loss_sds, 'loss_clipd': loss_clipd}

    @torch.no_grad()
    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Run the denoising loop and return final latents.

        text_embeddings is [neg; pos] stacked (batch dim 2*B); latents, when
        given, seed the loop instead of fresh Gaussian noise.
        """
        if latents is None:
            latents = torch.randn(
                (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latents] * 2)
                # latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_text + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents)[
                    'prev_sample']

        return latents

    def decode_latents(self, latents):
        """VAE-decode latents into images in [0, 1] (differentiable)."""
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        """VAE-encode images in [0, 1] to sampled, scaled latents.

        imgs: [B, 3, H, W]
        """
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def encode_imgs_mean(self, imgs):
        """Like encode_imgs but uses the posterior mean (deterministic).

        imgs: [B, 3, H, W]
        """
        imgs = 2 * imgs - 1

        latents = self.vae.encode(imgs).latent_dist.mean
        latents = latents * self.vae.config.scaling_factor

        return latents

    @torch.no_grad()
    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, to_numpy=True):
        """Full text-to-image pipeline; returns uint8 numpy (or float tensor)."""
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts] * len(prompts)

        prompts = tuple(prompts)
        negative_prompts = tuple(negative_prompts)
        # Prompts -> text embeds
        pos_embeds = self.get_text_embeds(prompts)  # [1, 77, 768]
        neg_embeds = self.get_text_embeds(negative_prompts)
        text_embeds = torch.cat(
            [neg_embeds, pos_embeds], dim=0)  # [2, 77, 768]

        # Text embeds -> img latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
                                       num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents.to(
            text_embeds.dtype))  # [1, 3, 512, 512]

        # Img to Numpy
        if to_numpy:
            imgs = to_np_img(imgs)
        return imgs

    @torch.no_grad()
    def img_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, img=None, to_numpy=True, t=50):
        """Image-to-image: noise the input to level t, then denoise with prompts.

        `img` may be a path or a [B, 3, H, W] tensor in [0, 1].

        Known issues:
            1. Not able to reconstruct images even with no noise.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        pos_embeds = self.get_text_embeds(prompts)  # [1, 77, 768]
        neg_embeds = self.get_text_embeds(negative_prompts)
        text_embeds = torch.cat(
            [neg_embeds, pos_embeds], dim=0)  # [2, 77, 768]

        # image to latent
        # interp to 512x512 to be fed into vae.
        if isinstance(img, str):
            img = TVF.to_tensor(Image.open(img))[None, :3].cuda()

        img_512 = F.interpolate(
            img.to(text_embeds.dtype), (512, 512), mode='bilinear', align_corners=False)

        # encode image into latents with vae
        latents = self.encode_imgs(img_512).repeat(
            text_embeds.shape[0] // 2, 1, 1, 1)

        noise = torch.randn_like(latents)
        if t > 0:
            latents_noise = self.scheduler.add_noise(
                latents, noise, torch.tensor(t).to(torch.int32))
        else:
            latents_noise = latents

        # Text embeds -> img latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents_noise,
                                       num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents.to(
            text_embeds.dtype))  # [1, 3, 512, 512]

        # Img to Numpy
        if to_numpy:
            imgs = to_np_img(imgs)
        return imgs

    def add_tokens_to_model(self, learned_embeds: Mapping[str, Tensor], override_token: Optional[Union[str, dict]] = None) -> None:
        r"""Adds tokens to the tokenizer and text encoder of a model.

        Args:
            learned_embeds: mapping of token string -> embedding tensor.
            override_token: replacement token name (str), or a per-token
                mapping (dict) keyed by the original token names.
        Raises:
            ValueError: if a token already exists in the tokenizer.
        """
        # Loop over learned embeddings
        new_tokens = []
        for token, embedding in learned_embeds.items():
            embedding = embedding.to(
                self.text_encoder.get_input_embeddings().weight.dtype)
            if override_token is not None:
                token = override_token if isinstance(
                    override_token, str) else override_token[token]

            # Add the token to the tokenizer
            num_added_tokens = self.tokenizer.add_tokens(token)
            if num_added_tokens == 0:
                raise ValueError((f"The tokenizer already contains the token {token}. Please pass a "
                                  "different `token` that is not already in the tokenizer."))

            # Resize the token embeddings
            self.text_encoder._resize_token_embeddings(len(self.tokenizer))

            # Get the id for the token and assign the embeds
            token_id = self.tokenizer.convert_tokens_to_ids(token)
            self.text_encoder.get_input_embeddings(
            ).weight.data[token_id] = embedding
            new_tokens.append(token)

        logger.info(
            f'Added {len(new_tokens)} tokens to tokenizer and text embedding: {new_tokens}')

    def add_tokens_to_model_from_path(self, learned_embeds_path: str, override_token: Optional[Union[str, dict]] = None) -> None:
        r"""Loads tokens from a file and adds them to the tokenizer and text encoder of a model."""
        learned_embeds: Mapping[str, Tensor] = torch.load(
            learned_embeds_path, map_location='cpu')
        self.add_tokens_to_model(learned_embeds, override_token)

    def check_prompt(self, opt):
        """Save sample generations for the prompt plus each view suffix.

        Writes grids and the prompt text under <opt.workspace>/prompt_check.
        `opt` must support attribute access plus .get() (e.g. an EasyDict).
        """
        texts = ['', ', front view', ', side view', ', back view']
        for view_text in texts:
            text = opt.text + view_text
            logger.info(f'Checking stable diffusion model with prompt: {text}')
            # Generate
            image_check = self.prompt_to_img(
                prompts=[text] * opt.get('prompt_check_nums', 5), guidance_scale=7.5, to_numpy=False,
                num_inference_steps=opt.get('num_inference_steps', 50))
            # Save
            output_dir_check = Path(opt.workspace) / 'prompt_check'
            output_dir_check.mkdir(exist_ok=True, parents=True)
            to_pil(image_check).save(output_dir_check / f'generations_{view_text}.png')
            (output_dir_check / 'prompt.txt').write_text(text)
||||
|
||||
if __name__ == '__main__':
    # CLI entry point: sanity-check a StableDiffusion model (optionally with
    # textual-inversion embeddings) by generating and saving sample images.
    import argparse
    import matplotlib.pyplot as plt
    from easydict import EasyDict as edict
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument('--text', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--workspace', default='out/sd', type=str)
    parser.add_argument('--image_path', default=None, type=str)
    parser.add_argument('--learned_embeds_path', type=str,
                        default=None, help="path to learned embeds")
    parser.add_argument('--sd_version', type=str, default='1.5',
                        choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None,
                        help="hugging face Stable diffusion model key")
    parser.add_argument('--fp16', action='store_true',
                        help="use float16 for training")
    parser.add_argument('--vram_O', action='store_true',
                        help="optimization for low VRAM usage")
    # NOTE: flag name is misspelled ('gudiance'); kept as-is for CLI
    # backward compatibility.
    parser.add_argument('--gudiance_scale', type=float, default=100)
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument('--noise_t', type=int, default=50)
    parser.add_argument('--prompt_check_nums', type=int, default=5)
    opt, unknown = parser.parse_known_args()

    # seed_everything(opt.seed)
    device = torch.device('cuda')
    opt = edict(vars(opt))
    workspace = opt.workspace

    # Remember the raw prompt/negative so each embedding can derive its own
    # token-substituted variant from the originals.
    opt.original_text = opt.text
    opt.original_negative = opt.negative
    if opt.learned_embeds_path is not None:
        # cmd:
        # python guidance/sd_utils.py --text "A high-resolution DSLR image of <token>" --learned_embeds_path out/learned_embeds/ --workspace out/teddy_bear
        # A directory may hold several learned_embeds*.bin checkpoints; check each.
        if os.path.isdir(opt.learned_embeds_path):
            learned_embeds_paths = glob.glob(os.path.join(opt.learned_embeds_path, 'learned_embeds*bin'))
        else:
            learned_embeds_paths = [opt.learned_embeds_path]

        for learned_embeds_path in learned_embeds_paths:
            embed_name = os.path.basename(learned_embeds_path).split('.')[0]
            opt.workspace = os.path.join(workspace, embed_name)
            sd = StableDiffusion(device, opt.fp16, opt.vram_O,
                                 opt.sd_version, opt.hf_key,
                                 learned_embeds_path=learned_embeds_path
                                 )
            # Add textual-inversion tokens and substitute them into the prompt.
            if learned_embeds_path is not None:
                opt.text, opt.negative = token_replace(
                    opt.original_text, opt.original_negative, learned_embeds_path)
            # BUG FIX: `logger.info(opt.text, opt.negative)` treated opt.text as a
            # %-format string, so the call raised a formatting error inside
            # logging and the message was lost; pass both values as lazy %-args.
            logger.info('%s %s', opt.text, opt.negative)
            sd.check_prompt(opt)
    else:
        # BUG FIX: `sd` was never constructed on this path, so the calls below
        # raised NameError; build the model without textual-inversion embeddings.
        sd = StableDiffusion(device, opt.fp16, opt.vram_O,
                             opt.sd_version, opt.hf_key)
        if opt.image_path is not None:
            save_promt = '_'.join(opt.text.split(' ')) + '_' + opt.image_path.split(
                '/')[-1].split('.')[0] + '_' + str(opt.noise_t) + '_' + str(opt.num_inference_steps)
            imgs = sd.img_to_img([opt.text]*opt.prompt_check_nums, [opt.negative]*opt.prompt_check_nums, opt.H, opt.W, opt.num_inference_steps,
                                 to_numpy=False, img=opt.image_path, t=opt.noise_t, guidance_scale=opt.gudiance_scale)
        else:
            save_promt = '_'.join(opt.text.split(' '))
            imgs = sd.prompt_to_img([opt.text]*opt.prompt_check_nums, [opt.negative]
                                    * opt.prompt_check_nums, opt.H, opt.W, opt.num_inference_steps, to_numpy=False)
        # Save the generated grid under the workspace.
        output_dir_check = Path(opt.workspace)
        output_dir_check.mkdir(exist_ok=True, parents=True)

        to_pil(imgs).save(output_dir_check / f'{save_promt}.png')
||||
81
guidance/shape_utils.py
Normal file
81
guidance/shape_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
from shap_e.models.transmitter.base import Transmitter
|
||||
from shap_e.models.query import Query
|
||||
from shap_e.models.nerstf.renderer import NeRSTFRenderer
|
||||
from shap_e.util.collections import AttrDict
|
||||
from shap_e.diffusion.sample import sample_latents
|
||||
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
||||
from shap_e.models.download import load_model, load_config
|
||||
from shap_e.util.image_util import load_image
|
||||
from shap_e.models.nn.meta import subdict
|
||||
import torch
|
||||
import gc
|
||||
|
||||
|
||||
# Rotation matrices applied to query positions (`pos @ R` in
# get_shape_from_image) — presumably mapping the query/camera frame into
# alternative canonical orientations of the Shap-E latent; TODO confirm
# the exact frame convention against the shap_e renderer.
camera_to_shapes = [
    torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32),  # identity
    torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32),  # to bird view
    torch.tensor([[0, 1, 0], [0, 0, 1], [-1, 0, 0]], dtype=torch.float32),  # to rotated bird view
    torch.tensor([[0, -1, 0], [0, 0, 1], [-1, 0, 0]], dtype=torch.float32),  # to rotated bird view
    torch.tensor([[-1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32),  # to bird view
]
||||
|
||||
|
||||
def get_density(
    render,
    query: Query,
    params: Dict[str, torch.Tensor],
    options: AttrDict[str, Any],
) -> torch.Tensor:
    """Return the density predicted by ``render``'s NeRSTF head at ``query``.

    Only the ``"nerstf"`` sub-tree of ``params`` is forwarded to the head.
    """
    assert render.nerstf is not None
    nerstf_out = render.nerstf(query, params=subdict(params, "nerstf"), options=options)
    return nerstf_out.density
||||
|
||||
|
||||
@torch.no_grad()
def get_shape_from_image(image_path, pos,
                         rpst_type='sdf',  # or 'density'
                         get_color=True,
                         shape_guidance=3, device='cuda'):
    """Sample a Shap-E latent from a single image and query it at ``pos``.

    Loads the Shap-E transmitter and image-conditioned diffusion model,
    samples one latent conditioned on the image, then evaluates the implicit
    field at the given positions once per orientation in ``camera_to_shapes``.

    Args:
        image_path: path to the conditioning image.
        pos: query positions; used as ``pos @ R`` with 3x3 rotations, so the
            last dimension must be 3 (assumed (N, 3) — TODO confirm).
        rpst_type: 'sdf' for signed distance, anything else for density.
        get_color: also query the texture field at each position.
        shape_guidance: classifier-free guidance scale for latent sampling.
        device: torch device for the Shap-E models.

    Returns:
        (rpsts, colors): per-orientation lists; ``colors`` entries are None
        when ``get_color`` is False.
    """
    # NOTE(review): models are re-loaded on every call — expensive; presumably
    # acceptable because this is invoked once per run.
    xm = load_model('transmitter', device=device)
    model = load_model('image300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    # One latent, image-conditioned, sampled with Karras sigmas.
    latent = sample_latents(
        batch_size=1,
        model=model,
        diffusion=diffusion,
        guidance_scale=shape_guidance,
        model_kwargs=dict(images=[load_image(image_path)]),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )[0]

    # Decode the latent bottleneck into renderer parameters.
    params = (xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
        latent[None]
    )

    rpsts, colors = [], []
    for camera_to_shape in camera_to_shapes:
        # Rotate the query positions into this canonical orientation.
        query = Query(
            position=pos @ camera_to_shape.to(pos.device),
            direction=None,
        )

        if rpst_type == 'sdf':
            rpst = xm.renderer.get_signed_distance(query, params, AttrDict())
        else:
            rpst = get_density(xm.renderer, query, params, AttrDict())
        rpsts.append(rpst.squeeze())

        if get_color:
            color = xm.renderer.get_texture(query, params, AttrDict())
        else:
            color = None
        colors.append(color)

    return rpsts, colors
||||
332
guidance/zero123_utils.py
Normal file
332
guidance/zero123_utils.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import math
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
import sys
|
||||
from os import path
|
||||
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
class SpecifyGradient(torch.autograd.Function):
    """Autograd bridge that injects a precomputed gradient.

    The forward pass ignores the input's value and returns a scalar dummy;
    the backward pass hands back ``gt_grad`` scaled by the incoming output
    gradient (which carries AMP's loss scale when a GradScaler is in use).
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Dummy value of 1: AMP's scaler multiplies it, so the scale
        # reappears as `grad_scale` in backward.
        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_scale):
        (stored_grad,) = ctx.saved_tensors
        # Gradient flows only to the input; `gt_grad` gets no gradient.
        return stored_grad * grad_scale, None
||||
|
||||
# load model
def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
    """Instantiate a latent-diffusion model from `config` and load `ckpt` weights.

    Folds EMA weights into the live model (then drops them to free memory)
    and, when `vram_O` is set, deletes the first-stage decoder. Returns the
    model in eval mode on `device`.

    Args:
        config: OmegaConf config whose `model` node is instantiable.
        ckpt: path to a PyTorch-Lightning checkpoint with a 'state_dict' key.
        device: target device for the loaded model.
        vram_O: if True, drop the first-stage decoder to save VRAM
            (the model can then only encode, not decode, images).
        verbose: print global step and missing/unexpected state-dict keys.
    """
    pl_sd = torch.load(ckpt, map_location='cpu')

    if 'global_step' in pl_sd and verbose:
        print(f'[INFO] Global Step: {pl_sd["global_step"]}')

    sd = pl_sd['state_dict']

    model = instantiate_from_config(config.model)
    # strict=False: tolerate missing/unexpected keys (reported below).
    m, u = model.load_state_dict(sd, strict=False)

    if len(m) > 0 and verbose:
        print('[INFO] missing keys: \n', m)
    if len(u) > 0 and verbose:
        print('[INFO] unexpected keys: \n', u)

    # manually load ema and delete it to save GPU memory
    if model.use_ema:
        if verbose:
            print('[INFO] loading EMA...')
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    return model
||||
|
||||
class Zero123(nn.Module):
    """Zero-1-to-3 novel-view guidance model.

    Wraps the zero123 latent-diffusion checkpoint and exposes:
    - `train_step`: SDS-style gradient injection for novel-view supervision,
    - `__call__`: plain DDIM sampling of a novel view for verification.
    """

    # NOTE(review): `t_range` is a mutable default argument; it is never
    # mutated here, but consider switching to a tuple.
    def __init__(self, device, fp16,
                 config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
                 ckpt='./pretrained/zero123/105000.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.vram_O = vram_O
        self.t_range = t_range
        self.opt = opt

        self.config = OmegaConf.load(config)
        self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)

        # timesteps: use diffuser for convenience... hope it's alright.
        self.num_train_timesteps = self.config.model.params.timesteps

        # DDIM scheduler mirroring the training beta schedule from the config.
        self.scheduler = DDIMScheduler(
            self.num_train_timesteps,
            self.config.model.params.linear_start,
            self.config.model.params.linear_end,
            beta_schedule='scaled_linear',
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        # Restrict the SDS noise-level sampling range to t_range.
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

    @torch.no_grad()
    def get_img_embeds(self, x):
        # x: image tensor [B, 3, 256, 256] in [0, 1]
        x = x * 2 - 1
        # Per-image CLIP conditioning and first-stage latents.
        c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x]  # .tile(n_samples, 1, 1)
        v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
        return c, v

    def angle_between(self, sph_v1, sph_v2):
        """Pairwise angles (radians) between two lists of spherical coords.

        Each entry is (r, theta, phi); returns a [len(sph_v1), len(sph_v2)]
        tensor of angles between the corresponding Cartesian directions.
        """
        def sph2cart(sv):
            r, theta, phi = sv[0], sv[1], sv[2]
            return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])
        def unit_vector(v):
            return v / torch.linalg.norm(v)
        def angle_between_2_sph(sv1, sv2):
            v1, v2 = sph2cart(sv1), sph2cart(sv2)
            v1_u, v2_u = unit_vector(v1), unit_vector(v2)
            # Clip the dot product to guard arccos against rounding.
            return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
        angles = torch.empty(len(sph_v1), len(sph_v2))
        for i, sv1 in enumerate(sph_v1):
            for j, sv2 in enumerate(sph_v2):
                angles[i][j] = angle_between_2_sph(sv1, sv2)
        return angles

    def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):
        """One SDS guidance step for a rendered novel view.

        Injects the (weighted, multi-reference) score-distillation gradient
        into `pred_rgb`'s latents via `latents.backward` and returns a
        detached scalar proxy loss (mean |grad|).
        """
        # pred_rgb: tensor [1, 3, H, W] in [0, 1]

        # adjust SDS scale based on how far the novel view is from the known view
        ref_radii = embeddings['ref_radii']
        ref_polars = embeddings['ref_polars']
        ref_azimuths = embeddings['ref_azimuths']
        v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1)  # polar,azimuth,radius are all actually delta wrt default
        v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
        angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
        if self.opt.zero123_grad_scale == 'angle':
            grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale  # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
        elif self.opt.zero123_grad_scale == 'None':
            grad_scale = 1.0  # claforte: I think this might converge faster...?
        else:
            assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'

        if as_latent:
            # Treat the render directly as a 32x32 latent in [-1, 1].
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256)

        # Random SDS timestep per batch element within [min_step, max_step].
        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # Set weights acc to closeness in angle
        if len(ref_azimuths) > 1:
            inv_angles = 1/angles
            inv_angles[inv_angles > 100] = 100
            inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
            inv_angles[inv_angles < 0.1] = 0
        else:
            inv_angles = torch.tensor([1.]).to(self.device)

        # Multiply closeness-weight by user-given weights
        zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
        zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
        zero123_ws[zero123_ws < 0.1] = 0

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            # Duplicate for classifier-free guidance (uncond + cond).
            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            noise_preds = []
            # Loop through each ref image
            for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
                                                    embeddings['c_crossattn'], embeddings['c_concat'],
                                                    ref_polars, ref_azimuths, ref_radii):
                # polar,azimuth,radius are all actually delta wrt default
                p = polar + ref_polars[0] - ref_polar
                a = azimuth + ref_azimuths[0] - ref_azimuth
                # NOTE(review): only the >180 side is wrapped; values below
                # -180 are left as-is — confirm whether that can occur here.
                a[a > 180] -= 360  # range in [-180, 180]
                r = radius + ref_radii[0] - ref_radius
                # Relative-pose embedding: (polar, sin(-az), cos(az), radius).
                T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
                cond = {}
                clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
                cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
                cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
                noise_pred = self.model.apply_model(x_in, t_in, cond)
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                noise_preds.append(zero123_w[:, None, None, None] * noise_pred)

        # Weighted average of the per-reference noise predictions.
        noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]

        # Standard SDS weighting w(t) = 1 - alpha_bar(t).
        w = (1 - self.alphas[t])
        grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        if save_guidance_path:
            with torch.no_grad():
                if as_latent:
                    pred_rgb_256 = self.decode_latents(latents)  # claforte: test!

                # visualize predicted denoised image
                result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))

                # visualize noisier image
                result_noisier_image = self.decode_latents(latents_noisy)

                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)
                save_image(viz_images, save_guidance_path)

        # since we omitted an item in grad, we need to use the custom function to specify the gradient
        # loss = SpecifyGradient.apply(latents, grad)
        latents.backward(gradient=grad, retain_graph=True)
        loss = grad.abs().mean().detach()
        return loss

    # verification
    @torch.no_grad()
    def __call__(self,
            image,  # image tensor [1, 3, H, W] in [0, 1]
            polar=0, azimuth=0, radius=0,  # new view params
            scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256,  # diffusion params
            c_crossattn=None, c_concat=None, post_process=True,
            ):
        """Sample a novel view of `image` at the given relative pose via DDIM."""

        if c_crossattn is None:
            # NOTE(review): get_img_embeds returns a (c, v) tuple, but the code
            # below indexes `embeddings['c_crossattn']` as a dict — this path
            # looks broken; confirm intended usage (callers may always pass
            # c_crossattn/c_concat explicitly).
            embeddings = self.get_img_embeds(image)

        T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
        T = T[None, None, :].to(self.device)

        cond = {}
        clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1))
        cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
        cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]

        # produce latents loop
        latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
        self.scheduler.set_timesteps(ddim_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1)] * 2).to(self.device)

            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']

        imgs = self.decode_latents(latents)
        # Convert to HWC numpy for display unless the caller wants tensors.
        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

        return imgs

    def decode_latents(self, latents):
        # zs: [B, 4, 32, 32] Latent space image
        # with self.model.ema_scope():
        imgs = self.model.decode_first_stage(latents)
        # Map from [-1, 1] back to [0, 1].
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs  # [B, 3, 256, 256] RGB space image

    def encode_imgs(self, imgs):
        # imgs: [B, 3, 256, 256] RGB space image
        # with self.model.ema_scope():
        imgs = imgs * 2 - 1
        latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
        return latents  # [B, 4, 32, 32] Latent space image
||||
|
||||
|
||||
if __name__ == '__main__':
    # CLI smoke test: load an image, synthesize a novel view at the given
    # relative (polar, azimuth, radius) offset, and display it.
    import cv2
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str)
    parser.add_argument('--fp16', action='store_true', help="use float16 for training")  # no use now, can only run in fp32

    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')

    opt = parser.parse_args()

    device = torch.device('cuda')

    print(f'[INFO] loading image from {opt.input} ...')
    # NOTE(review): IMREAD_UNCHANGED can return a 4-channel (BGRA) image,
    # which COLOR_BGR2RGB does not handle — confirm inputs are 3-channel.
    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype(np.float32) / 255.0
    # HWC uint8 -> NCHW float tensor in [0, 1] on the GPU.
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

    print(f'[INFO] loading model ...')
    zero123 = Zero123(device, opt.fp16, opt=opt)

    print(f'[INFO] running model ...')
    outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
    plt.imshow(outputs[0])
    plt.show()
||||
24
install.sh
Normal file
24
install.sh
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
# Environment setup: load cluster modules (or apt-install eigen on Ubuntu),
# pin the CUDA architectures to compile for, create a Python env, and
# install the Python dependencies plus custom CUDA extensions.

# for KAUST cluster
module load cuda/11.7.0
module load gcc/7.5.0
module load eigen

# for aws ubuntu. install eigen
#sudo apt update && sudo apt upgrade
#sudo apt install libeigen3-dev

# a100: 8.0; v100: 7.0; 2080ti: 7.5; titan xp: 6.1
export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0"

# use python venv
python -m venv venv_magic123
source venv_magic123/bin/activate

# use conda
# conda create -n magic123 python=3.10 ipython -y
# conda activate magic123

pip3 install torch torchvision
pip3 install -r requirements.txt
bash scripts/install_ext.sh
||||
77
ldm/extras.py
Executable file
77
ldm/extras.py
Executable file
@@ -0,0 +1,77 @@
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from ldm.util import instantiate_from_config
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
|
||||
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """Temporarily silence every logging call at or below ``highest_level``.

    Saves the module-level disable threshold, raises it for the duration of
    the ``with`` body, and restores it afterwards — even if the body raises.

    :param highest_level: the maximum logging level in use.
        This would only need to be changed if a custom level greater than
        CRITICAL is defined.

    https://gist.github.com/simon-weber/7853144
    """
    # Two kind-of hacks: there is no public getter for the current global
    # disable threshold, so we read `logging.root.manager.disable` — an
    # undocumented but non-private attribute that `logging.disable()` sets.
    saved_threshold = logging.root.manager.disable

    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(saved_threshold)
||||
|
||||
def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from a training directory.

    Searches ``train_dir`` recursively for exactly one ``*{epoch}.ckpt`` and
    at least one ``*-project.yaml``; when several configs match, the
    lexicographically last one is selected.

    :raises AssertionError: if the checkpoint is not unique or no config exists.
    """
    train_dir = Path(train_dir)
    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
    config = list(train_dir.rglob("*-project.yaml"))
    # BUG FIX: this previously asserted on len(ckpt), so a missing config
    # slipped through and crashed with an IndexError at config[0].
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)
|
||||
|
||||
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from config and a ckpt

    if config is a path will use omegaconf to load
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    # Silence the (noisy) library logging during model construction.
    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Kept (though unused) so a checkpoint without 'global_step' fails
        # loudly here rather than later.
        global_step = pl_sd["global_step"]
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        # strict=False: tolerate missing/unexpected keys, reported below.
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            # BUG FIX: the unexpected keys were announced but never printed.
            print(u)
        model.to(device)
        model.eval()
        model.cond_stage_model.device = device
        return model
|
||||
96
ldm/guidance.py
Executable file
96
ldm/guidance.py
Executable file
@@ -0,0 +1,96 @@
|
||||
from typing import List, Tuple
|
||||
from scipy import interpolate
|
||||
import numpy as np
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from IPython.display import clear_output
|
||||
import abc
|
||||
|
||||
|
||||
class GuideModel(torch.nn.Module, abc.ABC):
    """Interface for guidance models: preprocess a decoded image, then score it."""

    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        """Map a decoded image batch to the guide model's input space."""

    @abc.abstractmethod
    def compute_loss(self, inp):
        """Return the guidance loss for a preprocessed input."""
|
||||
|
||||
|
||||
class Guider(torch.nn.Module):
    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance

        Specify a guidance scale as either a scalar
        Or a schedule as a list of tuples t = 0->1 and scale, e.g.
        [(0, 10), (0.5, 20), (1, 50)]
        """
        super().__init__()
        self.sampler = sampler
        # Current position in the DDIM trajectory; advanced by modify_score.
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        # Per-step [loss, scale, grad_min, grad_max] records for debugging.
        self.history = []

        # NOTE(review): isinstance against typing.Tuple/typing.List is
        # deprecated — builtin (tuple, list) would be the safe spelling.
        if isinstance(scale, (Tuple, List)):
            times = np.array([x[0] for x in scale])
            values = np.array([x[1] for x in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        """Per-DDIM-step guidance scales (constant or interpolated schedule)."""
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps)*[self.scale_schedule]

        # Interpolate the (t in [0,1], scale) schedule at each DDIM step's
        # fractional position in the full DDPM trajectory.
        interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
        fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        """Shift the predicted noise `e_t` by the guide model's gradient."""

        # TODO look up index by t
        scale = self.get_scales()[self.index]

        if (scale == 0):
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        with torch.enable_grad():
            # Differentiate the guide loss w.r.t. the current latent.
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            # 1/0.18215 undoes the latent scale factor before decoding.
            x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale

            if self.show:
                clear_output(wait=True)
                print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
                self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
                plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
                plt.axis('off')
                plt.show()
                plt.imshow(correction[0][0].detach().cpu())
                plt.axis('off')
                plt.show()

        # Classifier-guidance update of the score estimate.
        e_t_mod = e_t - sqrt_1ma*correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod
|
||||
98
ldm/lr_scheduler.py
Executable file
98
ldm/lr_scheduler.py
Executable file
@@ -0,0 +1,98 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
    """Linear warm-up followed by half-cosine decay.

    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            # Linear ramp from lr_start to lr_max over the warm-up phase.
            multiplier = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
        else:
            # Cosine decay from lr_max to lr_min, clamped at the end of decay.
            progress = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            progress = min(progress, 1.0)
            multiplier = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(progress * np.pi))
        self.last_lr = multiplier
        return multiplier

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # Cumulative cycle boundaries, starting at 0, for interval lookup.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Index of the cycle that global step ``n`` falls into."""
        for cycle_idx, boundary in enumerate(self.cum_cycles[1:]):
            if n <= boundary:
                return cycle_idx

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        # Work in steps relative to the start of the current cycle.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")
        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            # Linear ramp from f_start to f_max over this cycle's warm-up.
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            # Cosine decay from f_max to f_min across the rest of the cycle.
            t = (n - warm_up) / (self.cycle_lengths[cycle] - warm_up)
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Variant with a linear (rather than cosine) decay after warm-up."""

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        # Steps relative to the start of the current cycle.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            # Linear ramp from f_start to f_max over this cycle's warm-up.
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            # Linear decay toward f_min over the full cycle length.
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
        self.last_f = f
        return f
|
||||
|
||||
443
ldm/models/autoencoder.py
Executable file
443
ldm/models/autoencoder.py
Executable file
@@ -0,0 +1,443 @@
|
||||
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.modules.ema import LitEma
from ldm.util import instantiate_from_config
|
||||
|
||||
class VQModel(pl.LightningModule):
    """VQ-GAN style first-stage model: Encoder -> VectorQuantizer -> Decoder,
    trained adversarially with two optimizers (autoencoder / discriminator)
    using the ``optimizer_idx`` training_step convention of PL < 2.0.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """
        Args:
            ddconfig: Encoder/Decoder kwargs; must contain "z_channels".
            lossconfig: config instantiated into the adversarial loss module.
            n_embed: codebook size of the vector quantizer.
            embed_dim: dimensionality of codebook entries.
            ckpt_path: optional checkpoint to restore from (non-strict).
            ignore_keys: state-dict key prefixes to drop on restore.  Changed
                from a mutable ``[]`` default to a tuple; it is only iterated,
                so behaviour is unchanged.
            image_key: batch-dict key under which images are found.
            colorize_nlabels: if set, register a random 3-channel projection
                for visualising inputs with more than 3 channels.
            monitor: metric name used for checkpointing.
            batch_resize_range: optional (lower, upper) per-batch resize range.
            scheduler_config: optional LR-schedule config applied to both optimizers.
            lr_g_factor: multiplier on the generator learning rate.
            remap, sane_index_shape: forwarded to the vector quantizer.
            use_ema: keep an EMA copy of the weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights; restores training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=()):
        """Non-strict restore from ``path``, dropping keys matching ``ignore_keys`` prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        # Update the EMA shadow weights after every optimizer step.
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode to quantized latents. Returns (quant, emb_loss, info)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode to continuous latents, skipping quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode quantized latents back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Full autoencode pass; optionally also return codebook indices."""
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch images from ``batch[k]`` as float BCHW; optional random per-batch resize."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            # EMA metrics are logged inside _validation_step via the "_ema" suffix.
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val" + suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val" + suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # BUGFIX: the original called `version.parse(...)` but `packaging.version`
        # was never imported in this file (NameError at validation time).
        # Compare the parsed (major, minor) tuple with the stdlib instead.
        pl_version = tuple(int(p) for p in pl.__version__.split(".")[:2] if p.isdigit())
        if pl_version >= (1, 4):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            # Renamed from `scheduler` to avoid shadowing the list built below.
            sched_obj = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=sched_obj.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=sched_obj.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Used by the adversarial loss to balance gradient magnitudes.
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict of input / reconstruction images for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels in [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
||||
class VQModelInterface(VQModel):
    """VQModel variant whose ``encode`` stops before the quantization layer;
    ``decode`` can then quantize its input on demand."""

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        # Return pre-quantization latents only.
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if force_not_quantize:
            quant = h
        else:
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
||||
|
||||
class AutoencoderKL(pl.LightningModule):
    """KL-regularized autoencoder: Encoder produces the moments of a diagonal
    Gaussian posterior; Decoder reconstructs from sampled latents.  Trained
    adversarially with two optimizers (autoencoder / discriminator).
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        """
        Args:
            ddconfig: Encoder/Decoder kwargs; must contain "z_channels" and
                have "double_z" set (encoder outputs mean and logvar).
            lossconfig: config instantiated into the adversarial loss module.
            embed_dim: latent dimensionality.
            ckpt_path: optional checkpoint to restore from (non-strict).
            ignore_keys: state-dict key prefixes to drop on restore.  Changed
                from a mutable ``[]`` default to a tuple; only iterated, so
                behaviour is unchanged.
            image_key: batch-dict key under which images are found.
            colorize_nlabels: if set, register a random 3-channel projection
                for visualising inputs with more than 3 channels.
            monitor: metric name used for checkpointing.
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        # 2x channels: the encoder emits mean and logvar of the posterior.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Non-strict restore from ``path``, dropping keys matching ``ignore_keys`` prefixes."""
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` into a DiagonalGaussianDistribution posterior."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent sample ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Autoencode; sample from the posterior or take its mode."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch images from ``batch[k]`` as float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        # Used by the adversarial loss to balance gradient magnitudes.
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return a dict of inputs, reconstructions and prior samples for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation map to 3 channels in [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
    """No-op first-stage model: encode/decode/quantize all pass the input
    through unchanged.  With ``vq_interface=True``, ``quantize`` mimics the
    (quant, emb_loss, info) return shape of a real vector quantizer.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # BUGFIX: call nn.Module.__init__ *before* setting attributes.
        # Assigning on an uninitialized Module only works by accident for
        # plain Python values and breaks for tensors/parameters/submodules.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Identity."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity; optionally shaped like a VectorQuantizer return value."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity."""
        return x
0
ldm/models/diffusion/__init__.py
Executable file
0
ldm/models/diffusion/__init__.py
Executable file
267
ldm/models/diffusion/classifier.py
Executable file
267
ldm/models/diffusion/classifier.py
Executable file
@@ -0,0 +1,267 @@
|
||||
import os
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn import functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from copy import deepcopy
|
||||
from einops import rearrange
|
||||
from glob import glob
|
||||
from natsort import natsorted
|
||||
|
||||
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
||||
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||
|
||||
# Maps a conditioning key to the classifier backbone used for it.
__models__ = {
    'class_label': EncoderUNetModel,
    'segmentation': UNetModel,
}
|
||||
|
||||
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores the requested mode,
    pinning a frozen module's train/eval state so it can never change."""
    return self
|
||||
|
||||
class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noised latents of a frozen diffusion model.

    Loads a pretrained diffusion model (frozen), noises its first-stage
    latents to random timesteps, and trains a UNet-based classifier
    (class-label or segmentation head) on the noisy latents.
    """

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        """
        Args:
            diffusion_path: run directory of the diffusion model; the newest
                ``*-project.yaml`` under ``configs/`` is loaded from it.
            num_classes: number of output classes of the classifier.
            ckpt_path: optional classifier checkpoint to restore.
            pool: pooling mode for the class-label classifier head.
            label_key: conditioning key; falls back to the diffusion model's
                ``cond_stage_key`` when that attribute exists.
            diffusion_ckpt_path: checkpoint for the diffusion model itself.
            scheduler_config: optional LR-scheduler config.
            weight_decay: AdamW weight decay.
            log_steps: number of noise levels to visualise in log_images.
            monitor: metric name used for checkpointing.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
        """Non-strict restore; ``only_model`` restores only the classifier weights.

        ``ignore_keys`` default changed from mutable ``list()`` to a tuple;
        it is only iterated, so behaviour is unchanged.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model, freeze it, and pin it to eval mode."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        # Prevent .train() from ever flipping the frozen model back.
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Build the classifier head from the diffusion UNet config."""
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        # Classifier consumes the diffusion model's latent channels.
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Noise latents ``x`` to timestep(s) ``t`` using the diffusion forward process."""
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch images from ``batch[k]`` as float BCHW."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Fetch targets; segmentation maps are downsampled to latent resolution."""
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Top-k accuracy of ``logits`` against integer ``labels``."""
        _, top_ks = torch.topk(logits, k, dim=1)
        hits = (top_ks == labels[:, None]).float().sum(dim=-1)
        if reduction == "mean":
            return hits.mean().item()
        elif reduction == "none":
            return hits
        # BUGFIX: an unrecognized reduction previously fell through and
        # silently returned None.
        raise ValueError(f"unknown reduction: {reduction!r}")

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, top-1/top-5 accuracy and current LR for the active split."""
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """Noise a batch of latents to timestep ``t`` (random when None) and classify."""
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        # Accuracy buckets, one per logged noise level.
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            # Renamed from `scheduler` to avoid shadowing the list built below.
            sched_obj = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=sched_obj.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Visualise inputs, labels and classifier predictions across noise levels."""
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f'inputs@t{current_time}'] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, 'b h w c -> b c h w')

                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log
||||
328
ldm/models/diffusion/ddim.py
Executable file
328
ldm/models/diffusion/ddim.py
Executable file
@@ -0,0 +1,328 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from einops import rearrange
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def to(self, device):
|
||||
"""Same as to in torch module
|
||||
Don't really underestand why this isn't a module in the first place"""
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_v = getattr(self, k).to(device)
|
||||
setattr(self, k, new_v)
|
||||
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
def ddim_sampling(self, cond, shape,
                  x_T=None, ddim_use_original_steps=False,
                  callback=None, timesteps=None, quantize_denoised=False,
                  mask=None, x0=None, img_callback=None, log_every_t=100,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                  t_start=-1):
    """Run the iterative DDIM denoising loop.

    cond: conditioning passed straight through to p_sample_ddim.
    shape: (batch, C, H, W) of the latent to produce.
    x_T: optional starting latent; sampled from N(0, I) when None.
    mask/x0: inpainting — known region of x0 is re-noised and pasted each step.
    Returns (final latent, dict of logged intermediates).
    """
    device = self.model.betas.device
    b = shape[0]
    if x_T is None:
        img = torch.randn(shape, device=device)
    else:
        img = x_T

    if timesteps is None:
        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
    elif timesteps is not None and not ddim_use_original_steps:
        # interpret an int `timesteps` as a fraction of the precomputed DDIM schedule
        subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
        timesteps = self.ddim_timesteps[:subset_end]

    # NOTE(review): with the default t_start=-1 this slice drops the final
    # (noisiest) timestep even when no explicit start is requested — appears
    # to be an intentional modification in this fork, but confirm.
    timesteps = timesteps[:t_start]

    intermediates = {'x_inter': [img], 'pred_x0': [img]}
    # walk the schedule from high noise to low noise
    time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
    total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
    # print(f"Running DDIM Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

    for i, step in enumerate(iterator):
        index = total_steps - i - 1  # index into the ddim alpha/sigma tables
        ts = torch.full((b,), step, device=device, dtype=torch.long)

        if mask is not None:
            # inpainting: keep the masked region pinned to a freshly-noised x0
            assert x0 is not None
            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img

        outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                  quantize_denoised=quantize_denoised, temperature=temperature,
                                  noise_dropout=noise_dropout, score_corrector=score_corrector,
                                  corrector_kwargs=corrector_kwargs,
                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                  unconditional_conditioning=unconditional_conditioning,
                                  dynamic_threshold=dynamic_threshold)
        img, pred_x0 = outs
        if callback:
            # NOTE(review): unlike the PLMS sampler's callback(i), this callback
            # REPLACES the running latent with its return value — callers must
            # return a latent tensor; confirm this contract with call sites.
            img = callback(i, img, pred_x0)
        if img_callback:
            img_callback(pred_x0, i)

        if index % log_every_t == 0 or index == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

    return img, intermediates
|
||||
|
||||
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                  dynamic_threshold=None):
    """Single DDIM denoising step.

    x: current latent x_t; c: conditioning; t: timestep tensor (batch,);
    index: position in the precomputed alpha/sigma tables.
    Applies classifier-free guidance when unconditional_conditioning is given
    with a scale != 1. Returns (x_prev, pred_x0).
    """
    b, *_, device = *x.shape, x.device

    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
        e_t = self.model.apply_model(x, t, c)
    else:
        # classifier-free guidance: run cond and uncond in one batched pass
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        if isinstance(c, dict):
            assert isinstance(unconditional_conditioning, dict)
            c_in = dict()
            for k in c:
                if isinstance(c[k], list):
                    c_in[k] = [torch.cat([
                        unconditional_conditioning[k][i],
                        c[k][i]]) for i in range(len(c[k]))]
                else:
                    c_in[k] = torch.cat([
                        unconditional_conditioning[k],
                        c[k]])
        else:
            c_in = torch.cat([unconditional_conditioning, c])
        e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

    if score_corrector is not None:
        assert self.model.parameterization == "eps"
        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
    # select parameters corresponding to the currently considered timestep
    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

    # current prediction for x_0 (eps-parameterization inversion)
    pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
    # (removed stray debug print of t / schedule values that spammed stdout
    # on every sampling step)

    if quantize_denoised:
        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

    if dynamic_threshold is not None:
        pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)

    # direction pointing to x_t
    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
    if noise_dropout > 0.:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
    return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
           unconditional_guidance_scale=1.0, unconditional_conditioning=None):
    """Deterministic DDIM inversion: push a clean latent x0 forward through
    t_enc reverse-DDIM updates so it can later be decoded back.

    return_intermediates: when truthy, also collect every
    (num_steps // return_intermediates)-th intermediate latent.
    Returns (x_encoded, dict with 'x_encoded', 'intermediate_steps' and,
    optionally, 'intermediates').
    """
    num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

    assert t_enc <= num_reference_steps
    num_steps = t_enc

    if use_original_steps:
        # NOTE(review): reads self.alphas_cumprod / self.alphas_cumprod_prev,
        # which only exist after make_schedule() has run — confirm callers
        # always schedule first.
        alphas_next = self.alphas_cumprod[:num_steps]
        alphas = self.alphas_cumprod_prev[:num_steps]
    else:
        alphas_next = self.ddim_alphas[:num_steps]
        alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

    x_next = x0
    intermediates = []
    inter_steps = []
    for i in tqdm(range(num_steps), desc='Encoding Image'):
        # NOTE(review): the model is conditioned on the loop index i rather
        # than the underlying DDPM timestep (e.g. self.ddim_timesteps[i]) —
        # verify this is intended.
        t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
        if unconditional_guidance_scale == 1.:
            noise_pred = self.model.apply_model(x_next, t, c)
        else:
            # classifier-free guidance, batched cond/uncond pass
            assert unconditional_conditioning is not None
            e_t_uncond, noise_pred = torch.chunk(
                self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                       torch.cat((unconditional_conditioning, c))), 2)
            noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

        # one reverse-DDIM (inversion) update: rescale x and add weighted noise
        xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
        weighted_noise_pred = alphas_next[i].sqrt() * (
                (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
        x_next = xt_weighted + weighted_noise_pred
        if return_intermediates and i % (
                num_steps // return_intermediates) == 0 and i < num_steps - 1:
            intermediates.append(x_next)
            inter_steps.append(i)
        elif return_intermediates and i >= num_steps - 2:
            # always keep the last couple of states when logging intermediates
            intermediates.append(x_next)
            inter_steps.append(i)

    out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
    if return_intermediates:
        out.update({'intermediates': intermediates})
    return x_next, out
|
||||
|
||||
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    """Diffuse x0 forward to noise level t in a single jump (q-sample).

    Fast, but the mapping is stochastic so it does not allow exact
    reconstruction. t is an index into the alpha tables, not a raw value.
    """
    if use_original_steps:
        sqrt_ac = self.sqrt_alphas_cumprod
        sqrt_omac = self.sqrt_one_minus_alphas_cumprod
    else:
        sqrt_ac = torch.sqrt(self.ddim_alphas)
        sqrt_omac = self.ddim_sqrt_one_minus_alphas

    eps = torch.randn_like(x0) if noise is None else noise
    signal = extract_into_tensor(sqrt_ac, t, x0.shape) * x0
    perturbation = extract_into_tensor(sqrt_omac, t, x0.shape) * eps
    return signal + perturbation
|
||||
|
||||
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
           use_original_steps=False):
    """Run the reverse DDIM chain from an encoded latent back to a clean one.

    t_start: number of schedule steps to walk (the schedule is truncated to
    its first t_start entries, then traversed from noisiest to cleanest).
    """
    if use_original_steps:
        timesteps = np.arange(self.ddpm_num_timesteps)
    else:
        timesteps = self.ddim_timesteps
    timesteps = timesteps[:t_start]

    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    # print(f"Running DDIM Sampling with {total_steps} timesteps")

    progress = tqdm(time_range, desc='Decoding image', total=total_steps)
    x_dec = x_latent
    batch = x_latent.shape[0]
    for i, step in enumerate(progress):
        index = total_steps - i - 1
        ts = torch.full((batch,), step, device=x_latent.device, dtype=torch.long)
        x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
    return x_dec
|
||||
1994
ldm/models/diffusion/ddpm.py
Executable file
1994
ldm/models/diffusion/ddpm.py
Executable file
File diff suppressed because it is too large
Load Diff
259
ldm/models/diffusion/plms.py
Executable file
259
ldm/models/diffusion/plms.py
Executable file
@@ -0,0 +1,259 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
class PLMSSampler(object):
    """Pseudo Linear Multistep (PLMS) sampler (Liu et al., "Pseudo Numerical
    Methods for Diffusion Models on Manifolds") wrapping a trained latent
    diffusion model. Schedule tensors are precomputed by make_schedule()."""

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        # Stores a schedule tensor as a plain attribute, force-moving it to
        # CUDA first. NOTE(review): hard-codes torch.device("cuda"), which
        # breaks CPU-only runs — confirm whether self.model.device should be
        # used instead.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM-style alpha/sigma tables used by p_sample_plms.

        PLMS is a deterministic multistep method, hence ddim_eta must be 0.
        """
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters (subsampled schedule)
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               dynamic_threshold=None,
               **kwargs
               ):
        """Public entry point: build the schedule for S steps and run PLMS
        sampling for a latent of `shape` (C, H, W) at the given batch size.

        Returns (samples, intermediates) from plms_sampling.
        """
        # sanity-check that conditioning batch matches the requested batch size
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    dynamic_threshold=dynamic_threshold,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        """Iterative PLMS loop; keeps a rolling window of up to 3 previous
        eps predictions (old_eps) for the multistep update.

        Returns (final latent, dict of logged intermediates).
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # interpret an int `timesteps` as a fraction of the DDIM schedule
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # traverse the schedule from high noise to low noise
        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1  # index into the ddim alpha/sigma tables
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            # next (less noisy) timestep, needed by the 2nd-order bootstrap step
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                # inpainting: pin the masked region to a freshly-noised x0
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next,
                                      dynamic_threshold=dynamic_threshold)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                # keep at most the 3 most recent eps for the 4th-order update
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
                      dynamic_threshold=None):
        """Single PLMS step: estimate eps (with optional classifier-free
        guidance), combine it with up to 3 stored eps via an Adams-Bashforth
        multistep formula, then take a DDIM-style update.

        Returns (x_prev, pred_x0, e_t) where e_t is this step's raw eps
        estimate (appended to old_eps by the caller).
        """
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            # eps prediction, with classifier-free guidance when requested
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                if isinstance(c, dict):
                    assert isinstance(unconditional_conditioning, dict)
                    c_in = dict()
                    for k in c:
                        if isinstance(c[k], list):
                            c_in[k] = [torch.cat([
                                unconditional_conditioning[k][i],
                                c[k][i]]) for i in range(len(c[k]))]
                        else:
                            c_in[k] = torch.cat([
                                unconditional_conditioning[k],
                                c[k]])
                else:
                    c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            if dynamic_threshold is not None:
                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order): bootstrap with one extra eval
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
|
||||
50
ldm/models/diffusion/sampling_util.py
Executable file
50
ldm/models/diffusion/sampling_util.py
Executable file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # trailing None indices insert singleton axes after the existing ones
    return x[(...,) + (None,) * missing]
|
||||
|
||||
|
||||
def renorm_thresholding(x0, value):
    """Dynamic thresholding of a predicted x0 (Imagen-style).

    Normalizes x0 to [-1, 1], clips each batch element at the `value`-quantile
    of its absolute values (never below 1.0, so in-range data passes through
    unchanged), rescales by the threshold, and maps back to the original range.

    Fixes in this revision: the original body referenced `rearrange` (never
    imported here) and `self.model.device` inside a free function — both
    NameErrors at runtime. Now implemented with pure torch ops on x0's own
    device, replacing the temporary numpy-on-CPU clipping hack.
    """
    pred_max = x0.max()
    pred_min = x0.min()
    pred_x0 = (x0 - pred_min) / (pred_max - pred_min)  # 0 ... 1
    pred_x0 = 2 * pred_x0 - 1.  # -1 ... 1

    # per-sample quantile of |x|, floored at 1.0 so normal data is untouched
    s = torch.quantile(
        pred_x0.reshape(pred_x0.shape[0], -1).abs(),
        value,
        dim=-1
    )
    s.clamp_(min=1.0)
    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))

    # clip to [-s, s] and rescale (min/max form works on older torch too)
    pred_x0 = torch.min(torch.max(pred_x0, -s), s) / s

    # undo the renormalization
    pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1
    pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min  # orig range
    return pred_x0
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
    """Rescale x0 so each sample's RMS is at most `value`.

    Samples whose RMS is already below `value` are scaled up to it; the
    clamp keeps the divisor from dropping under `value`.
    """
    rms = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value)
    # broadcast the per-sample scalar back over the remaining dims
    s = rms[(...,) + (None,) * (x0.ndim - 1)]
    return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
    """Per-pixel thresholding for (b, c, h, w) tensors: scales each spatial
    location so its channel-RMS does not exceed `value`."""
    rms = x0.pow(2).mean(1, keepdim=True).sqrt()
    s = rms.clamp(min=value)
    return x0 * (value / s)
|
||||
266
ldm/modules/attention.py
Executable file
266
ldm/modules/attention.py
Executable file
@@ -0,0 +1,266 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
    """Return True when *val* is anything other than None."""
    return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
    """Order-preserving de-duplication; returns a dict keys view."""
    return dict.fromkeys(arr, True).keys()
|
||||
|
||||
|
||||
def default(val, d):
    """Return *val* unless it is None, else the fallback *d* (called first
    when it is a function, so lazy defaults are supported)."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
    """Most negative finite value representable in t's dtype (used for
    attention masking)."""
    info = torch.finfo(t.dtype)
    return -info.max
|
||||
|
||||
|
||||
def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)] where d is the last
    dimension; returns the same tensor for chaining."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
    """Gated-GELU unit: one linear layer produces value and gate halves; the
    value is modulated by GELU of the gate."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> GELU (or a GEGLU gate) -> Dropout ->
    Linear, expanding the width by `mult`."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for param in module.parameters():
        with torch.no_grad():
            param.zero_()
    return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
    """GroupNorm with 32 groups, eps=1e-6 and learnable affine parameters."""
    return torch.nn.GroupNorm(num_channels=in_channels, num_groups=32, affine=True, eps=1e-6)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
    """Linear-complexity attention over a 2-D feature map: keys are
    softmax-normalized and aggregated with values before being queried."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        # aggregate values against normalized keys, then read out with queries
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(attended)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial grid of a (b, c, h, w)
    feature map, with 1x1-conv projections and a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        hidden = self.norm(x)
        q = self.q(hidden)
        k = self.k(hidden)
        v = self.v(hidden)

        # attention weights: (h*w) x (h*w) dot-product similarity, 1/sqrt(c) scaled
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        hidden = torch.einsum('bij,bjk->bik', v, w_)
        hidden = rearrange(hidden, 'b c (h w) -> b c h w', h=h)
        hidden = self.proj_out(hidden)

        return x + hidden
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head attention over `context` when given, otherwise
    self-attention over x. Optional boolean mask hides context positions."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        n_heads = self.heads

        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        # fold heads into the batch dimension
        split = lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)
        q, k, v = split(q), split(k), split(v)

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if mask is not None:
            mask = rearrange(mask, 'b ... -> b (...)')
            neg_inf = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=n_heads)
            sim.masked_fill_(~mask, neg_inf)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: attn1 (self-attention, or cross-attention
    to context when disable_self_attn), attn2 (cross-attention to context),
    then a gated feed-forward — each with a residual connection. The forward
    pass is wrapped in gradient checkpointing when enabled."""

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        # self-attention unless disabled, in which case it conditions on context
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # cross-attention to context (degenerates to self-attn when context is None)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        attn1_ctx = context if self.disable_self_attn else None
        x = x + self.attn1(self.norm1(x), context=attn1_ctx)
        x = x + self.attn2(self.norm2(x), context=context)
        x = x + self.ff(self.norm3(x))
        return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data: project the input with a 1x1
    conv, flatten to (b, h*w, c) tokens, run `depth` BasicTransformerBlocks,
    reshape back to an image and add the residual. The output projection is
    zero-initialized so the block starts as an identity.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        blocks = [
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
                                  disable_self_attn=disable_self_attn)
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        residual = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + residual
|
||||
0
ldm/modules/diffusionmodules/__init__.py
Executable file
0
ldm/modules/diffusionmodules/__init__.py
Executable file
835
ldm/modules/diffusionmodules/model.py
Executable file
835
ldm/modules/diffusionmodules/model.py
Executable file
@@ -0,0 +1,835 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.attention import LinearAttention
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings (DDPM / fairseq convention).

    Matches the tensor2tensor implementation, differing slightly from the
    description in Section 3.5 of "Attention Is All You Need".

    :param timesteps: 1-D tensor of timestep indices.
    :param embedding_dim: output embedding width; odd widths are zero-padded.
    :return: tensor of shape (len(timesteps), embedding_dim).
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # geometric frequency ladder: exp(-log(10000) * i / (half_dim - 1))
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        * (-math.log(10000) / (half_dim - 1))
    )
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat((torch.sin(args), torch.cos(args)), dim=1)
    if embedding_dim % 2 == 1:  # zero pad the final column for odd widths
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
    """Swish / SiLU activation: x * sigmoid(x)."""
    return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
    """GroupNorm with the epsilon/affine settings used throughout this file."""
    return torch.nn.GroupNorm(
        num_channels=in_channels, num_groups=num_groups, eps=1e-6, affine=True
    )
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # channel-preserving 3x3 conv applied after the interpolation
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x) if self.with_conv else x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """2x downsampling via a strided 3x3 conv (with manual asymmetric padding)
    or, without a conv, via 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # pad right/bottom by one so the stride-2 conv halves the resolution
        x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(x)
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
    """Pre-activation residual block (GroupNorm -> swish -> conv, twice),
    optionally conditioned on a timestep embedding, with a learned shortcut
    when the channel count changes.

    :param conv_shortcut: if True use a 3x3 conv shortcut, else a 1x1 ("nin") conv.
    :param temb_channels: width of the timestep embedding; 0 disables conditioning.
    """
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            # projects the timestep embedding into this block's channel width
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            # a shortcut projection is only needed when channels change
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        # temb may be None (callers with temb_channels=0 pass None)
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # add the projected timestep embedding as a per-channel bias
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        # residual connection
        return x+h
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        # single head with full-width head dim, so the constructor signature
        # matches AttnBlock(in_channels) and make_attn can swap them freely
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions of a feature map,
    with a residual connection around the attention."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convolutions act as per-position linear projections for q/k/v
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        # scale by 1/sqrt(c) before the softmax
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        # residual connection
        return x+h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
    """Factory for the spatial attention variants used by the encoder/decoder."""
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    # only "linear" remains after the assert above
    return LinAttnBlock(in_channels)
|
||||
|
||||
|
||||
class Model(nn.Module):
    """UNet-style diffusion backbone: timestep-conditioned downsampling path,
    attention-augmented middle, and skip-connected upsampling path.

    :param ch: base channel count; per-level widths are ch * ch_mult[level].
    :param attn_resolutions: spatial resolutions at which attention is inserted.
    :param use_timestep: if True, forward() requires a timestep tensor `t`.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # timestep-embedding width is 4x the base channel count
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two-layer MLP applied to the sinusoidal embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # attach attention at the resolutions listed in attn_resolutions
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            # one extra block per level consumes the skip from the down path
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling; hs collects skip activations for the up path
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                # pop the matching skip activation from the down path
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        # weight of the final output convolution
        return self.conv_out.weight
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Convolutional encoder: downsampling ResnetBlock stages with optional
    attention, an attention-augmented middle, and a conv head producing
    z_channels (or 2*z_channels when double_z, e.g. for mean+logvar)."""

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # no timestep conditioning in the encoder
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # attention only at the configured spatial resolutions
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # timestep embedding (unused in the encoder)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Convolutional decoder: conv-in from z, attention-augmented middle, then
    upsampling ResnetBlock stages and a conv head to out_ch channels.

    :param give_pre_end: return the feature map before the output norm/act/conv.
    :param tanh_out: squash the final output with tanh.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        # no timestep conditioning in the decoder
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling (built from the lowest resolution upwards)
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        # remember the last input shape for inspection/debugging
        self.last_z_shape = z.shape

        # timestep embedding (unused in the decoder)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
    """Minimal decoder: 1x1 conv, three ResnetBlocks, 1x1 conv, 2x upsample,
    then norm/act/3x3 conv to out_channels."""

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                     ResnetBlock(in_channels=in_channels,
                                                 out_channels=2 * in_channels,
                                                 temb_channels=0, dropout=0.0),
                                     ResnetBlock(in_channels=2 * in_channels,
                                                 out_channels=4 * in_channels,
                                                 temb_channels=0, dropout=0.0),
                                     ResnetBlock(in_channels=4 * in_channels,
                                                 out_channels=2 * in_channels,
                                                 temb_channels=0, dropout=0.0),
                                     nn.Conv2d(2*in_channels, in_channels, 1),
                                     Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            # indices 1-3 are ResnetBlocks and take a (None) timestep embedding
            if i in [1,2,3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
    """Decoder without skip connections: per-resolution stacks of ResnetBlocks
    with 2x upsampling between resolutions, then norm/act/conv out."""

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            # num_res_blocks + 1 ResnetBlocks per resolution level
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                # no timestep conditioning (temb_channels=0), hence None
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
||||
|
||||
|
||||
class LatentRescaler(nn.Module):
    """Rescale a latent by an arbitrary factor: conv-in, residual blocks,
    nearest-size interpolation by `factor`, attention, residual blocks, conv-out."""

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels,
                                 mid_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                                     out_channels=mid_channels,
                                                     temb_channels=0,
                                                     dropout=0.0) for _ in range(depth)])

        self.conv_out = nn.Conv2d(mid_channels,
                                  out_channels,
                                  kernel_size=1,
                                  )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        # resize to (round(H * factor), round(W * factor))
        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x
|
||||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
    """An Encoder followed by a LatentRescaler, fused into a single module."""

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # channel width at the encoder output, before rescaling
        enc_out_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch,
            ch_mult=ch_mult, z_channels=enc_out_chn, double_z=False,
            resolution=resolution, attn_resolutions=attn_resolutions,
            dropout=dropout, resamp_with_conv=resamp_with_conv, out_ch=None)
        self.rescaler = LatentRescaler(
            factor=rescale_factor, in_channels=enc_out_chn,
            mid_channels=enc_out_chn, out_channels=out_ch,
            depth=rescale_module_depth)

    def forward(self, x):
        return self.rescaler(self.encoder(x))
|
||||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
    """A LatentRescaler followed by a Decoder, fused into a single module."""

    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # channel width handed from the rescaler to the decoder
        dec_in_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch, z_channels=dec_in_chn,
            attn_resolutions=attn_resolutions, dropout=dropout,
            resamp_with_conv=resamp_with_conv, in_channels=None,
            num_res_blocks=num_res_blocks, ch_mult=ch_mult,
            resolution=resolution, ch=ch)
        self.rescaler = LatentRescaler(
            factor=rescale_factor, in_channels=z_channels,
            mid_channels=dec_in_chn, out_channels=dec_in_chn,
            depth=rescale_module_depth)

    def forward(self, x):
        return self.decoder(self.rescaler(x))
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
    """Rescale a latent by a fractional factor, then decode with power-of-two
    upsampling stages until out_size is reached."""

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        # number of 2x decoder stages needed to bridge in_size -> out_size
        n_stages = int(np.log2(out_size // in_size)) + 1
        # fractional remainder handled by the rescaler
        factor_up = 1. + (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(
            factor=factor_up, in_channels=in_channels,
            mid_channels=2 * in_channels, out_channels=in_channels)
        self.decoder = Decoder(
            out_ch=out_channels, resolution=out_size, z_channels=in_channels,
            num_res_blocks=2, attn_resolutions=[], in_channels=None,
            ch=in_channels, ch_mult=[ch_mult for _ in range(n_stages)])

    def forward(self, x):
        return self.decoder(self.rescaler(x))
|
||||
|
||||
|
||||
class Resize(nn.Module):
    """Resize feature maps by an arbitrary scale factor via interpolation.

    :param in_channels: only required for the (unimplemented) learned path.
    :param learned: if True, a learned downsampling conv would be used;
        currently not implemented.
    :param mode: interpolation mode passed to ``F.interpolate``.
    """
    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # fix: was ``self.__class__.__name`` (AttributeError) -> ``__name__``
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            raise NotImplementedError()
            # unreachable until the learned path is implemented:
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=4,
                                        stride=2,
                                        padding=1)

    def forward(self, x, scale_factor=1.0):
        # identity when no rescaling is requested
        if scale_factor==1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x
|
||||
|
||||
class FirstStagePostProcessor(nn.Module):
    """Encode inputs with a frozen pretrained first-stage model, then project
    the latent and downsample it through a stack of ResnetBlock/Downsample
    stages (one per entry of ch_mult)."""

    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            # default projection width: the pretrained encoder's base channels
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
                              stride=1,padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)


    def instantiate_pretrained(self, config):
        # build the first-stage model from config and freeze all its parameters
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False


    @torch.no_grad()
    def encode_with_pretrained(self,x):
        c = self.pretrained_model.encode(x)
        # NOTE(review): DiagonalGaussianDistribution does not appear in this
        # file's visible imports -- this isinstance would raise NameError;
        # confirm the import exists elsewhere in the file.
        if isinstance(c, DiagonalGaussianDistribution):
            # take the distribution mode rather than sampling
            c = c.mode()
        return c

    def forward(self,x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        # alternate ResnetBlock (no timestep embedding) and average-pool downsample
        for submodel, downmodel in zip(self.model,self.downsampler):
            z = submodel(z,temb=None)
            z = downmodel(z)

        if self.do_reshape:
            # flatten spatial grid into a token sequence: (b, h*w, c)
            z = rearrange(z,'b c h w -> b (h w) c')
        return z
|
||||
|
||||
996
ldm/modules/diffusionmodules/openaimodel.py
Executable file
996
ldm/modules/diffusionmodules/openaimodel.py
Executable file
@@ -0,0 +1,996 @@
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
    # intentional no-op placeholder ("dummy replace"): fp16 conversion disabled
    pass
|
||||
|
||||
def convert_module_to_f32(x):
    # intentional no-op placeholder ("dummy replace"): fp32 conversion disabled
    pass
|
||||
|
||||
|
||||
## go
|
||||
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # +1 position for the prepended mean token (see forward)
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        # prepend the spatial mean as an extra token
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # keep only the output at the mean-token position
        return x[:, :, 0]
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    # marker interface: TimestepEmbedSequential routes `emb` to these modules
    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                # timestep-aware layers receive the embedding
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                # cross-attention layers receive the conditioning context
                x = layer(x, context)
            else:
                x = layer(x)
        return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D: keep the leading (inner-most-first) dim, upsample the last two
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
||||
|
||||
def forward(self,x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """
    2x downsampling with either a strided convolution or average pooling.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals are strided only over the two innermost dimensions.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # norm -> SiLU -> 3x3 conv; the conv is the last element so that
        # _forward can split it off when up/down-sampling (see below).
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # h_upd resamples the hidden stream, x_upd resamples the skip input,
        # so both sides of the residual sum have matching spatial size.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding; twice the width when using
        # scale-shift norm, since it is split into (scale, shift) later.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Final conv is zero-initialized so a fresh block is the identity
        # mapping (plus skip connection).
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # Skip path: identity when channel counts match, otherwise a 3x3
        # (use_conv) or 1x1 conv to change the channel count.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )


    def _forward(self, x, emb):
        if self.updown:
            # Resample between the (norm, SiLU) prefix and the conv, and
            # resample the raw input in parallel for the skip connection.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the [N x C] embedding over all spatial dims of h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: modulate normalized features with the
            # embedding-derived scale/shift before the rest of out_layers.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            # Additive conditioning before the whole out_layers stack.
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        # Either a fixed head count, or a fixed per-head channel width.
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        # 1x1 conv producing concatenated q, k, v.
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized projection so a fresh block is an identity mapping.
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # Bug fix: the checkpoint flag was hard-coded to True (with a TODO
        # noting it), forcing gradient checkpointing -- and its recompute
        # cost -- even when use_checkpoint=False. Pass the configured flag,
        # consistent with ResBlock.forward.
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        # Flatten all spatial dims into a single token axis, run attention
        # with a residual connection, then restore the original shape.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # Attention costs two matmuls of equal size: one to form the T x T
    # weight matrix, one to combine the value vectors with it.
    matmul_ops = 2 * b * (num_spatial ** 2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # Legacy layout: heads are separated first, then each head's slab is
        # carved into q, k and v along the channel axis.
        q, k, v = qkv.reshape(batch * self.n_heads, head_dim * 3, tokens).split(
            head_dim, dim=1
        )
        # Scale q and k each by ch^(-1/4) instead of dividing the product --
        # more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum("bts,bcs->bct", weight, v)
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, tokens = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        # New layout: q, k and v are split first; heads are separated
        # afterwards with view()/reshape().
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k each by ch^(-1/4) -- more stable with f16 than
        # dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(batch * self.n_heads, head_dim, tokens),
            (k * scale).view(batch * self.n_heads, head_dim, tokens),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = th.einsum(
            "bts,bcs->bct", weight, v.reshape(batch * self.n_heads, head_dim, tokens)
        )
        return out.reshape(batch, -1, tokens)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None
    ):
        super().__init__()
        # Cross-attention conditioning requires both the transformer flag and
        # its context dimension; each assert checks the other's presence.
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        # num_res_blocks may be a single int (same count for every level) or
        # a per-level list matching channel_mult.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        #self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # MLP that maps the sinusoidal timestep embedding into the dimension
        # consumed by every ResBlock.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # --- Down path ------------------------------------------------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        # Channel count of every down-path output, popped later to size the
        # skip connections of the up path.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling rate relative to the input
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # --- Middle block ---------------------------------------------------
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # --- Up path: levels in reverse, skip channels popped off the list --
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa
                            )
                        )
                # Upsample as the last layer of the last block of each level
                # (except the outermost level).
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # NOTE(review): the final conv takes model_channels inputs, which
        # assumes channel_mult[0] == 1 so that ch == model_channels here --
        # confirm against the configs used with this model.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []  # down-path activations, reused as skip connections below
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Concatenate the matching down-path activation along channels.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
|
||||
|
||||
|
||||
class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # MLP projecting the sinusoidal timestep embedding for the ResBlocks.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Down path only (encoder half of the UNet).
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling rate relative to the input
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the last one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch
        # Pooling head that turns the feature map into an [N x K] output.
        self.pool = pool
        if pool == "adaptive":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                zero_module(conv_nd(dims, ch, out_channels, 1)),
                nn.Flatten(),
            )
        elif pool == "attention":
            assert num_head_channels != -1
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d(
                    (image_size // ds), ch, num_head_channels, out_channels
                ),
            )
        elif pool == "spatial":
            # Consumes the concatenated per-stage spatially-averaged features
            # collected in forward() (hence self._feature_size inputs).
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.out_channels),
            )
        elif pool == "spatial_v2":
            self.out = nn.Sequential(
                nn.Linear(self._feature_size, 2048),
                normalization(2048),
                nn.SiLU(),
                nn.Linear(2048, self.out_channels),
            )
        else:
            raise NotImplementedError(f"Unexpected {pool} pooling")

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                # Collect per-stage features, spatially averaged over H and W.
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)
|
||||
|
||||
267
ldm/modules/diffusionmodules/util.py
Executable file
267
ldm/modules/diffusionmodules/util.py
Executable file
@@ -0,0 +1,267 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a beta (noise-variance) schedule for the diffusion process.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion steps (length of the result).
    :param linear_start: first beta for the linear-family schedules.
    :param linear_end: last beta for the linear-family schedules.
    :param cosine_s: small offset used by the cosine schedule.
    :return: a float64 numpy array of n_timestep betas.
    :raises ValueError: if `schedule` is not a known name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared back.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Bug fix: this previously used np.clip on a torch tensor, whose
        # result can be a numpy array (version-dependent), making the final
        # `.numpy()` call below fail. Tensor.clamp keeps it a tensor.
        betas = betas.clamp(min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the subset of DDPM timesteps used by the DDIM sampler.

    :param ddim_discr_method: 'uniform' (even stride) or 'quad' (quadratic spacing).
    :param num_ddim_timesteps: number of sampler steps requested.
    :param num_ddpm_timesteps: total number of training (DDPM) timesteps.
    :param verbose: print the selected steps.
    :return: numpy array of selected timesteps, shifted by one.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray([t for t in range(0, num_ddpm_timesteps, stride)])
    elif ddim_discr_method == 'quad':
        sqrt_grid = np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)
        ddim_timesteps = (sqrt_grid ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # Shift by one to get the final alpha values right (the ones from the first
    # scale-to-data step during sampling).
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Compute the (sigma, alpha, alpha_prev) variance schedule for DDIM sampling.

    :param alphacums: cumulative alpha products over all DDPM steps.
    :param ddim_timesteps: DDPM timesteps selected for the DDIM sampler.
    :param eta: DDIM stochasticity parameter (0 gives deterministic sampling).
    :param verbose: print the resulting schedule.
    :return: (sigmas, alphas, alphas_prev) numpy arrays.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    prev_values = [alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray(prev_values)

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    def _beta_at(step):
        # Ratio of consecutive alpha_bar values gives 1 - beta for this step.
        lo = step / num_diffusion_timesteps
        hi = (step + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(hi) / alpha_bar(lo), max_beta)

    return np.array([_beta_at(i) for i in range(num_diffusion_timesteps)])
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
    """Gather a[t] per batch element and reshape to broadcast against x_shape.

    :param a: 1-D tensor of per-timestep values.
    :param t: [batch] tensor of timestep indices.
    :param x_shape: shape of the tensor the result will broadcast against.
    :return: tensor of shape [batch, 1, 1, ...] matching x_shape's rank.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        # Checkpointing disabled: plain call.
        return func(*inputs)
    packed = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
    """Autograd Function implementing gradient checkpointing: the forward pass
    runs under no_grad (so activations are not stored), and the backward pass
    re-runs the function under enable_grad to rebuild the graph before
    differentiating. See `checkpoint` for the calling convention."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # args packs the tensor inputs (first `length`) followed by extra
        # parameters the function depends on; split them back apart.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach and re-enable grad so recomputation builds a fresh graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop references promptly to free the recomputed graph.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # Leading (None, None) matches the grad slots for run_function and length.
        return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and tile the raw timestep value.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad with a zero column so the output is exactly [N, dim].
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
|
||||
|
||||
|
||||
def zero_module(module):
    """Zero all parameters of `module` in place and return it."""
    for param in module.parameters():
        param.detach().zero_()
    return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
    """Multiply all parameters of `module` by `scale` in place and return it."""
    for param in module.parameters():
        param.detach().mul_(scale)
    return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
|
||||
|
||||
|
||||
def normalization(channels):
    """
    Make a standard normalization layer: GroupNorm with 32 groups that
    normalizes in float32 regardless of the input dtype.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    """Sigmoid Linear Unit activation: x * sigmoid(x)."""
    def forward(self, x):
        return torch.sigmoid(x) * x
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32, with the result cast back to the input dtype."""
    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    factories = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in factories:
        raise ValueError(f"unsupported dimensions: {dims}")
    return factories[dims](*args, **kwargs)
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
    """
    Create a linear module (thin wrapper over nn.Linear).
    """
    return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    :raises ValueError: if dims is not 1, 2, or 3.
    """
    factories = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in factories:
        raise ValueError(f"unsupported dimensions: {dims}")
    return factories[dims](*args, **kwargs)
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
    """Runs two separately-configured conditioners and packages their outputs
    in the {'c_concat': [...], 'c_crossattn': [...]} format expected downstream."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        return {
            'c_concat': [self.concat_conditioner(c_concat)],
            'c_crossattn': [self.crossattn_conditioner(c_crossattn)],
        }
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
    """Draw standard gaussian noise of `shape` on `device`.

    With repeat=True a single sample is drawn and tiled across the batch
    dimension (all batch entries share identical noise)."""
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
|
||||
0
ldm/modules/distributions/__init__.py
Executable file
0
ldm/modules/distributions/__init__.py
Executable file
92
ldm/modules/distributions/distributions.py
Executable file
92
ldm/modules/distributions/distributions.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
    """Minimal distribution interface: subclasses provide sampling and the mode."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: both sampling and the mode return the fixed value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
    """Diagonal gaussian over image-like latents.

    `parameters` packs mean and log-variance along dim 1; with
    deterministic=True the distribution degenerates to a point mass at the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp log-variance so exp() stays numerically stable.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero spread: sampling collapses onto the mean.
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """KL divergence against `other`, or against N(0, I) when other is None.
        Summed over dims [1, 2, 3] (assumes NCHW-shaped latents)."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                                   dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        """Negative log-likelihood of `sample`, summed over `dims`."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    anchor = next((obj for obj in (mean1, logvar1, mean2, logvar2)
                   if isinstance(obj, torch.Tensor)), None)
    assert anchor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        v if isinstance(v, torch.Tensor) else torch.tensor(v).to(anchor)
        for v in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
|
||||
76
ldm/modules/ema.py
Executable file
76
ldm/modules/ema.py
Executable file
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Shadow copies are registered as buffers on this module (so they travel
    with checkpoints); parameter names have '.' stripped because buffer names
    may not contain dots. `forward(model)` updates the shadows, `copy_to(model)`
    writes them back, and `store`/`restore` let callers temporarily swap EMA
    weights in for evaluation.
    """
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # Maps original parameter names -> buffer-safe shadow names.
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates >= 0 enables decay warm-up; -1 disables it.
        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
                             else torch.tensor(-1,dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                #remove as '.'-character is not allowed in buffers
                s_name = name.replace('.','')
                self.m_name2s_name.update({name:s_name})
                self.register_buffer(s_name,p.clone().detach().data)

        # Plain list (not buffers): scratch space for store()/restore().
        self.collected_params = []

    def forward(self,model):
        # One EMA update step: shadow <- shadow - (1 - decay) * (shadow - param).
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: effective decay ramps up from ~0.09 toward self.decay.
            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    # Non-trainable params must not have been registered in __init__.
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        # Overwrite the model's trainable parameters with the EMA shadows.
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        # NOTE: relies on zip() pairing — parameters must be passed in the same
        # order as they were given to store().
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
|
||||
0
ldm/modules/encoders/__init__.py
Executable file
0
ldm/modules/encoders/__init__.py
Executable file
550
ldm/modules/encoders/modules.py
Executable file
550
ldm/modules/encoders/modules.py
Executable file
@@ -0,0 +1,550 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
import kornia
|
||||
|
||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
from ldm.util import default
|
||||
import clip
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement `encode`."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
    """Pass-through conditioner: the input is its own encoding."""

    def encode(self, x):
        return x
|
||||
|
||||
class FaceClipEncoder(AbstractEncoder):
    """Encodes a fixed face crop and the face-blanked remainder of the image
    separately with a CLIP image encoder, concatenating the two embeddings
    along the token axis (output shape [N, 2, 768])."""
    def __init__(self, augment=True, retreival_key=None):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        self.augment = augment
        self.retreival_key = retreival_key

    def forward(self, img):
        encodings = []
        with torch.no_grad():
            # Fixed crop box for the face region.
            # assumes 512x512 inputs (512 - x_offset below) — TODO confirm against caller
            x_offset = 125
            if self.retreival_key:
                # Assumes retrieved image are packed into the second half of channels
                face = img[:,3:,190:440,x_offset:(512-x_offset)]
                other = img[:,:3,...].clone()
            else:
                face = img[:,:,190:440,x_offset:(512-x_offset)]
                other = img.clone()

            if self.augment:
                face = K.RandomHorizontalFlip()(face)

            # Blank the face region so the second embedding carries context only.
            other[:,:,190:440,x_offset:(512-x_offset)] *= 0
            encodings = [
                self.encoder.encode(face),
                self.encoder.encode(other),
            ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        # A list input denotes the unconditional ("dropped") condition.
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)

        return self(img)
|
||||
|
||||
class FaceIdClipEncoder(AbstractEncoder):
    """Concatenates a face-identity embedding of the (resized) image with a
    CLIP image embedding of the image with its central face region zeroed."""
    def __init__(self):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        for p in self.encoder.parameters():
            p.requires_grad = False
        # NOTE(review): hard-coded absolute checkpoint path — should be configurable.
        self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)

    def forward(self, img):
        encodings = []
        with torch.no_grad():
            # ID branch sees the whole image resized to the face model's input size.
            face = kornia.geometry.resize(img, (256, 256),
                                          interpolation='bilinear', align_corners=True)

            other = img.clone()
            # Zero the fixed central face box so the CLIP branch encodes context only.
            # assumes 512x512 inputs — TODO confirm against caller
            other[:,:,184:452,122:396] *= 0
            encodings = [
                self.id.encode(face),
                self.encoder.encode(other),
            ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        # A list input denotes the unconditional ("dropped") condition.
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)

        return self(img)
|
||||
|
||||
class ClassEmbedder(nn.Module):
    """Embeds integer class labels pulled from a batch dict, shaped [N, 1, embed_dim]
    for use as a cross-attention context."""

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        lookup_key = self.key if key is None else key
        # this is for use in crossattn: add a singleton sequence dimension
        labels = batch[lookup_key][:, None]
        return self.embedding(labels)
|
||||
|
||||
|
||||
class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers over pre-tokenized input ids."""

    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer),
        )

    def forward(self, tokens):
        tokens = tokens.to(self.device)  # meh
        return self.transformer(tokens, return_embeddings=True)

    def encode(self, x):
        return self(x)
|
||||
|
||||
|
||||
class BERTTokenizer(AbstractEncoder):
    """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

    def __init__(self, device="cuda", vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to reuquirements
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        encoded = self.tokenizer(
            text, truncation=True, max_length=self.max_length, return_length=True,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt",
        )
        return encoded["input_ids"].to(self.device)

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if self.vq_interface:
            # Mimic the (quant, loss, info) return shape of a VQ model.
            return None, None, [None, None, tokens]
        return tokens

    def decode(self, text):
        return text
|
||||
|
||||
|
||||
class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizr model and add some transformer encoder layers"""

    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer),
            emb_dropout=embedding_dropout,
        )

    def forward(self, text):
        # Raw strings go through the tokenizer; otherwise assume token ids.
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)#.to(self.device)
        else:
            tokens = text
        return self.transformer(tokens, return_embeddings=True)

    def encode(self, text):
        # output of length 77
        return self(text)
|
||||
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    # Intentionally ignores `mode` and changes nothing.
    return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=token_ids).last_hidden_state

    def encode(self, text):
        return self(text)
|
||||
|
||||
from ldm.thirdp.psp.id_loss import IDFeatures
|
||||
import kornia.augmentation as K
|
||||
|
||||
class FrozenFaceEncoder(AbstractEncoder):
    """Frozen face-identity feature extractor followed by a trainable linear
    mapper to the 768-dim conditioning space (output shape [N, 1, 768])."""
    def __init__(self, model_path, augment=False):
        super().__init__()
        self.loss_fn = IDFeatures(model_path)
        # face encoder is frozen
        for p in self.loss_fn.parameters():
            p.requires_grad = False
        # Mapper is trainable
        self.mapper = torch.nn.Linear(512, 768)
        p = 0.25
        if augment:
            self.augment = K.AugmentationSequential(
                K.RandomHorizontalFlip(p=0.5),
                K.RandomEqualize(p=p),
                # K.RandomPlanckianJitter(p=p),
                # K.RandomPlasmaBrightness(p=p),
                # K.RandomPlasmaContrast(p=p),
                # K.ColorJiggle(0.02, 0.2, 0.2, p=p),
            )
        else:
            self.augment = False

    def forward(self, img):
        # A list input denotes the unconditional ("dropped") condition.
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 1, 768), device=self.mapper.weight.device)

        # NOTE(review): `self.augment` is set to False (never None) in __init__,
        # so this check is always True; `False(...)` would raise if augment were
        # disabled and this branch were actually exercised — verify intent.
        if self.augment is not None:
            # Transforms require 0-1
            img = self.augment((img + 1)/2)
            img = 2*img - 1

        feat = self.loss_fn(img, crop=True)
        # [N, 512] -> [N, 1, 768]: add a token dimension, then project.
        feat = self.mapper(feat.unsqueeze(1))
        return feat

    def encode(self, img):
        return self(img)
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        self.freeze()

    def freeze(self):
        """Switch the text model to eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=token_ids).last_hidden_state

    def encode(self, text):
        return self(text)
|
||||
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionModel
|
||||
class ClipImageProjector(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    """
    # Trainable image conditioner: CLIP ViT hidden states are linearly mapped
    # 1024 -> 768 and zero-padded along the token axis to `max_length`.
    def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
        super().__init__()
        self.model = CLIPVisionModel.from_pretrained(version)
        # Intentionally left trainable (contrast with FrozenCLIPEmbedder.freeze()).
        self.model.train()
        self.max_length = max_length   # TODO: typical value?
        self.antialias = True
        self.mapper = torch.nn.Linear(1024, 768)
        # CLIP preprocessing statistics (per-channel RGB mean/std).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        # Cache the text embedding of the empty prompt as the unconditional output.
        null_cond = self.get_null_cond(version, max_length)
        self.register_buffer('null_cond', null_cond)

    @torch.no_grad()
    def get_null_cond(self, version, max_length):
        # Embed "" with the frozen CLIP *text* encoder to serve as the null condition.
        device = self.mean.device
        embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
        null_cond = embedder([""])
        return null_cond

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # A list input denotes condition dropout: return the cached null condition.
        if isinstance(x, list):
            return self.null_cond
        # x is assumed to be in range [-1,1]
        x = self.preprocess(x)
        outputs = self.model(pixel_values=x)
        last_hidden_state = outputs.last_hidden_state
        last_hidden_state = self.mapper(last_hidden_state)
        # Zero-pad the token axis up to max_length.
        return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])

    def encode(self, im):
        return self(im)
|
||||
|
||||
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
    """Frozen CLIP text embedding followed by a trainable 768->768 projection."""

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, text):
        return self.projection(self.embedder(text))

    def encode(self, text):
        return self(text)
|
||||
|
||||
class FrozenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg
    """
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cpu',
            antialias=False,
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        # CLIP preprocessing statistics (per-channel RGB mean/std).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, 768, device=device)
        return self.model.encode_image(self.preprocess(x)).float()

    def encode(self, im):
        # Add a singleton token dimension: [N, 768] -> [N, 1, 768].
        return self(im).unsqueeze(1)
|
||||
|
||||
from torchvision import transforms
|
||||
import random
|
||||
|
||||
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg
    """
    # Encodes `max_crops` random resized crops of each image as separate CLIP
    # tokens, producing [N, max_crops, 768] conditioning.
    def __init__(
            self,
            model='ViT-L/14',
            jit=False,
            device='cpu',
            antialias=True,
            max_crops=5,
        ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        # CLIP preprocessing statistics (per-channel RGB mean/std).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.max_crops = max_crops

    def preprocess(self, x):

        # Expects inputs in the range -1, 1
        # Take max_crops random resized square crops of the input.
        randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
        max_crops = self.max_crops
        patches = []
        crops = [randcrop(x) for _ in range(max_crops)]
        patches.extend(crops)
        x = torch.cat(patches, dim=0)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, self.max_crops, 768, device=device)
        batch_tokens = []
        for im in x:
            patches = self.preprocess(im.unsqueeze(0))
            tokens = self.model.encode_image(patches).float()
            # Randomly zero individual crop tokens (10% chance) as token dropout.
            for t in tokens:
                if random.random() < 0.1:
                    t *= 0
            batch_tokens.append(tokens.unsqueeze(0))

        return torch.cat(batch_tokens, dim=0)

    def encode(self, im):
        return self(im)
|
||||
|
||||
class SpatialRescaler(nn.Module):
    """Rescales spatial dimensions by `multiplier`, `n_stages` times, using
    torch interpolation; optionally remaps channels with a 1x1 conv."""

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
            self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)

    def forward(self,x):
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)
|
||||
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
|
||||
|
||||
class LowScaleEncoder(nn.Module):
    """Encode an image with a frozen first-stage model, then corrupt the latent
    with diffusion noise at a random level — used to condition noise-augmented
    upscaling diffusion models on a degraded low-resolution latent."""

    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
                 scale_factor=1.0):
        super().__init__()
        self.max_noise_level = max_noise_level
        # First-stage autoencoder instantiated from its config dict.
        self.model = instantiate_from_config(model_config)
        # NOTE(review): register_schedule registers buffers on self and returns
        # None, so this attribute is always None; kept as-is.
        self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
                                                            linear_end=linear_end)
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Precompute the diffusion schedule and register its terms as buffers."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        """Diffuse x_start to step t: sqrt(a_bar_t)*x0 + sqrt(1-a_bar_t)*eps."""
        # NOTE(review): `default` is presumably imported from ldm.util at the
        # top of this module — confirm.
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        """Return (noised, resized latent; per-sample noise level)."""
        z = self.model.encode(x).sample()
        z = z * self.scale_factor
        # Random noise level per batch element, in [0, max_noise_level).
        noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        z = self.q_sample(z, noise_level)
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
        return z, noise_level

    def decode(self, z):
        """Undo the scale factor and decode through the first-stage model."""
        z = z / self.scale_factor
        return self.model.decode(z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: run both frozen text encoders on a few sample sentences
    # (including non-ASCII input) and report parameter counts plus the shape
    # of the produced embeddings.
    from ldm.util import count_params
    sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
    model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
    count_params(model, True)
    z = model(sentences)
    print(z.shape)

    model = FrozenCLIPEmbedder().cuda()
    count_params(model, True)
    z = model(sentences)
    print(z.shape)

    print("done.")
|
||||
676
ldm/modules/evaluate/adm_evaluator.py
Executable file
676
ldm/modules/evaluate/adm_evaluator.py
Executable file
@@ -0,0 +1,676 @@
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from multiprocessing import cpu_count
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Iterable, Optional, Tuple
|
||||
import yaml
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow.compat.v1 as tf
|
||||
from scipy import linalg
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
||||
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
||||
|
||||
FID_POOL_NAME = "pool_3:0"
|
||||
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
||||
|
||||
REQUIREMENTS = f"This script has the following requirements: \n" \
|
||||
'tensorflow-gpu>=2.0' + "\n" + 'scipy' + "\n" + "requests" + "\n" + "tqdm"
|
||||
|
||||
|
||||
def main():
    """Compute IS, FID, sFID, precision and recall for a sample batch against
    a reference batch, print them, and save them as YAML next to the samples.

    Expects two .npz files (``--ref_batch``, ``--sample_batch``) containing an
    ``arr_0`` image array, and optionally precomputed statistics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_batch", help="path to reference batch npz file")
    parser.add_argument("--sample_batch", help="path to sample batch npz file")
    args = parser.parse_args()

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    is_ = evaluator.compute_inception_score(sample_acts[0])
    print("Inception Score:", is_)
    fid = sample_stats.frechet_distance(ref_stats)
    print("FID:", fid)
    sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
    print("sFID:", sfid)
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)

    savepath = '/'.join(args.sample_batch.split('/')[:-1])
    results_file = os.path.join(savepath, 'evaluation_metrics.yaml')
    print(f'Saving evaluation results to "{results_file}"')

    results = {
        'IS': is_,
        'FID': fid,
        'sFID': sfid,
        # Fixed: this key previously carried a stray trailing colon
        # ('Precision:'), inconsistent with every other key in the file.
        'Precision': prec,
        'Recall': recall
    }

    with open(results_file, 'w') as f:
        yaml.dump(results, f, default_flow_style=False)
|
||||
|
||||
class InvalidFIDException(Exception):
    """Raised when a FID computation cannot be carried out."""
|
||||
|
||||
|
||||
class FIDStatistics:
    """Mean/covariance summary of Inception activations, for FID computation."""

    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        # mu: [D] feature mean; sigma: [D, D] feature covariance.
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1 = np.atleast_1d(self.mu)
        mu2 = np.atleast_1d(other.mu)
        sigma1 = np.atleast_2d(self.sigma)
        sigma2 = np.atleast_2d(other.sigma)

        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

        diff = mu1 - mu2

        # The covariance product might be almost singular.
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            warnings.warn(
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can introduce a slight imaginary component.
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
|
||||
|
||||
|
||||
class Evaluator:
    """Computes Inception-based generative metrics (Inception Score, FID/sFID
    inputs, precision/recall) through a TF1 session holding the frozen
    InceptionV3 graph."""

    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            # NHWC image batches, fed as float32 in [0, 255].
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            # 2048-dim pool_3 features, for the classifier softmax head.
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        """Run one dummy batch so TF graph setup cost is paid up front."""
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Stream images from an npz file and return (pool_3, spatial) features."""
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(self, batches: Iterable[np.ndarray], silent=False) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: an iterator over NHWC numpy arrays in [0, 255].
        :param silent: if True, suppress the tqdm progress bar.
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        it = batches if silent else tqdm(batches)
        for batch in it:
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            # Flatten each sample's features to one row.
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        """Load precomputed (mu, sigma[, _s]) from the npz if present, else
        compute statistics from the given activations."""
        obj = np.load(npz_path)
        if "mu" in list(obj.keys()):
            return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                obj["mu_s"], obj["sigma_s"]
            )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        """Mean and covariance of a [N x D] feature matrix."""
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        """Inception Score: exp(mean KL(p(y|x) || p(y))), averaged over splits."""
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        """Improved precision/recall via k-NN hypersphere manifold estimates."""
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))
|
||||
|
||||
|
||||
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        """Run one dummy precision/recall pass so the TF graph is built eagerly."""
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        """k-NN distances (hypersphere radii) for every feature vector."""
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            # NOTE(review): 'max_realisim_score' (sic) spelling kept for
            # backward compatibility with existing consumers of this dict.
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # Fix: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `bool` is the documented drop-in replacement.
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )
|
||||
|
||||
|
||||
class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # Compute in float16 first for speed/memory; the tf.cond below
            # falls back to a float32 pass if half precision overflowed.
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            # Row of batch 1 is "in" if it lies inside any batch-2 hypersphere,
            # and symmetrically for batch 2 against batch-1 radii.
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        """Membership of each batch's rows inside the other batch's hyperspheres."""
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )
|
||||
|
||||
|
||||
def _batch_pairwise_distances(U, V):
    """
    Compute pairwise squared Euclidean distances between two batches of feature vectors.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vectors.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances via |u|^2 - 2u.v + |v|^2;
        # clamp at 0 to absorb tiny negatives from floating-point cancellation.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

    return D
|
||||
|
||||
|
||||
class NpzArrayReader(ABC):
    """Abstract reader that yields row-batches of one array in an npz file."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return up to ``batch_size`` rows, or None when exhausted."""

    @abstractmethod
    def remaining(self) -> int:
        """Number of rows not yet consumed."""

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Return a length-aware iterable over all remaining batches."""
        def gen_fn():
            while True:
                chunk = self.read_batch(batch_size)
                if chunk is None:
                    return
                yield chunk

        rem = self.remaining()
        # Ceiling division: one extra batch for any partial remainder.
        num_batches = (rem + batch_size - 1) // batch_size
        return BatchIterator(gen_fn, num_batches)
|
||||
|
||||
|
||||
class BatchIterator:
    """Iterable with a precomputed length, backed by a generator factory."""

    def __init__(self, gen_fn, length):
        # gen_fn: zero-arg callable returning a fresh iterator per pass.
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen_fn()
|
||||
|
||||
|
||||
class StreamingNpzArrayReader(NpzArrayReader):
    """Stream rows of an .npy member straight from an open file handle."""

    def __init__(self, arr_f, shape, dtype):
        # arr_f: file-like positioned at the start of the raw array data.
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0  # rows consumed so far

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        rows_left = self.shape[0] - self.idx
        if rows_left <= 0:
            return None

        bs = min(batch_size, rows_left)
        self.idx += bs

        # Zero-itemsize dtypes carry no bytes; fabricate an array directly.
        if self.dtype.itemsize == 0:
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        item_count = bs * np.prod(self.shape[1:])
        raw = _read_bytes(self.arr_f, int(item_count * self.dtype.itemsize), "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape([bs, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
|
||||
|
||||
|
||||
class MemoryNpzArrayReader(NpzArrayReader):
    """Serve row-batches from an array fully loaded into memory."""

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0  # rows consumed so far

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load ``arr_name`` from the npz at ``path`` and wrap it in a reader."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None
        batch = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return batch

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)
|
||||
|
||||
|
||||
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """Yield an NpzArrayReader for ``arr_name``: stream from disk when the
    .npy header permits it, otherwise fall back to loading into memory."""
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        else:
            # Unknown .npy format version: let numpy parse the whole file.
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = header
        if fortran or dtype.hasobject:
            # Fortran order / object dtype cannot be streamed row-wise.
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
||||
|
||||
|
||||
def _read_bytes(fp, size, error_template="ran out of data"):
|
||||
"""
|
||||
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
||||
|
||||
Read from file-like object until size bytes are read.
|
||||
Raises ValueError if not EOF is encountered before size bytes are read.
|
||||
Non-blocking objects only supported if they derive from io objects.
|
||||
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
||||
requested.
|
||||
"""
|
||||
data = bytes()
|
||||
while True:
|
||||
# io files (default in python3) return None or raise on
|
||||
# would-block, python2 file will truncate, probably nothing can be
|
||||
# done about that. note that regular files can't be non-blocking
|
||||
try:
|
||||
r = fp.read(size - len(data))
|
||||
data += r
|
||||
if len(r) == 0 or len(data) == size:
|
||||
break
|
||||
except io.BlockingIOError:
|
||||
pass
|
||||
if len(data) != size:
|
||||
msg = "EOF: reading %s, expected %d bytes got %d"
|
||||
raise ValueError(msg % (error_template, size, len(data)))
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _open_npy_file(path: str, arr_name: str):
|
||||
with open(path, "rb") as f:
|
||||
with zipfile.ZipFile(f, "r") as zip_f:
|
||||
if f"{arr_name}.npy" not in zip_f.namelist():
|
||||
raise ValueError(f"missing {arr_name} in npz file")
|
||||
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
||||
yield arr_f
|
||||
|
||||
|
||||
def _download_inception_model():
    """Fetch the frozen InceptionV3 graph def if not already cached locally."""
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    with requests.get(INCEPTION_V3_URL, stream=True) as r:
        r.raise_for_status()
        # Download to a temp name and rename atomically, so a partially
        # written file never satisfies the existence check above.
        tmp_path = INCEPTION_V3_PATH + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=8192)):
                f.write(chunk)
        os.rename(tmp_path, INCEPTION_V3_PATH)
|
||||
|
||||
|
||||
def _create_feature_graph(input_batch):
    """Import InceptionV3 and return (pool_3, spatial) feature tensors for FID/sFID."""
    _download_inception_model()
    # Random prefix keeps imported node names unique across repeated imports
    # into the same graph.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={f"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    _update_shapes(pool3)
    # Only the first 7 spatial channels are used for sFID.
    spatial = spatial[..., :7]
    return pool3, spatial
|
||||
|
||||
|
||||
def _create_softmax_graph(input_batch):
    """Build a class-probability head by reusing InceptionV3's final matmul weights."""
    _download_inception_model()
    # Random prefix keeps imported node names unique across repeated imports.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
    )
    # Second input of the matmul op is the pretrained weight matrix.
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)
|
||||
|
||||
|
||||
def _update_shapes(pool3):
    """Relax the imported graph's pinned batch dimension so any batch size works."""
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape] TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        # Un-pin the leading (batch) dimension.
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                # Directly overwrite the cached shape on the tensor object.
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3
|
||||
|
||||
|
||||
def _numpy_partition(arr, kth, **kwargs):
|
||||
num_workers = min(cpu_count(), len(arr))
|
||||
chunk_size = len(arr) // num_workers
|
||||
extra = len(arr) % num_workers
|
||||
|
||||
start_idx = 0
|
||||
batches = []
|
||||
for i in range(num_workers):
|
||||
size = chunk_size + (1 if i < extra else 0)
|
||||
batches.append(arr[start_idx : start_idx + size])
|
||||
start_idx += size
|
||||
|
||||
with ThreadPool(num_workers) as pool:
|
||||
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Print the dependency list up front: missing packages are the most
    # common failure mode of this script.
    print(REQUIREMENTS)
    main()
|
||||
630
ldm/modules/evaluate/evaluate_perceptualsim.py
Executable file
630
ldm/modules/evaluate/evaluate_perceptualsim.py
Executable file
@@ -0,0 +1,630 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torchvision import models
|
||||
from PIL import Image
|
||||
|
||||
from ldm.modules.evaluate.ssim import ssim
|
||||
|
||||
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
|
||||
def normalize_tensor(in_feat, eps=1e-10):
    """L2-normalize an NCHW feature map along its channel dimension."""
    n, _, h, w = in_feat.size()
    # Per-position channel norm, broadcast back over channels; eps avoids
    # division by zero for all-zero feature vectors.
    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1)).view(n, 1, h, w)
    return in_feat / (norm_factor.expand_as(in_feat) + eps)
|
||||
|
||||
|
||||
def cos_sim(in0, in1):
    """Mean channel-wise cosine similarity between two NCHW feature maps,
    averaged over spatial positions; returns a length-N vector."""
    a = normalize_tensor(in0)
    b = normalize_tensor(in1)
    n, _, h, w = in0.size()

    # Dot product of unit vectors per spatial position, then average over
    # height and width in two steps (matching the original reduction order).
    sim_map = torch.sum(a * b, dim=1).view(n, 1, h, w)
    return torch.mean(
        torch.mean(sim_map, dim=2).view(n, 1, 1, w),
        dim=3,
    ).view(n)
|
||||
|
||||
|
||||
class squeezenet(torch.nn.Module):
    """Pretrained SqueezeNet-1.1 feature extractor exposing activations from
    seven consecutive slices, for perceptual-similarity comparisons."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = models.squeezenet1_1(
            pretrained=pretrained
        ).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        self.slice7 = torch.nn.Sequential()
        self.N_slices = 7  # number of tap points returned by forward()
        # Partition the pretrained feature stack into consecutive slices.
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(10, 11):
            self.slice5.add_module(str(x), pretrained_features[x])
        for x in range(11, 12):
            self.slice6.add_module(str(x), pretrained_features[x])
        for x in range(12, 13):
            self.slice7.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            # Freeze weights: used purely as a fixed feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return a named tuple of the seven intermediate slice activations."""
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        h = self.slice6(h)
        h_relu6 = h
        h = self.slice7(h)
        h_relu7 = h
        vgg_outputs = namedtuple(
            "SqueezeOutputs",
            ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
        )
        out = vgg_outputs(
            h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7
        )

        return out
|
||||
|
||||
|
||||
class alexnet(torch.nn.Module):
    """AlexNet feature extractor split into 5 sequential slices."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        features = models.alexnet(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Layer-index boundaries of the 5 slices inside alexnet.features.
        bounds = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 12)]
        seqs = [self.slice1, self.slice2, self.slice3, self.slice4, self.slice5]
        for seq, (lo, hi) in zip(seqs, bounds):
            for x in range(lo, hi):
                seq.add_module(str(x), features[x])
        if not requires_grad:
            # Freeze the backbone when used purely as a fixed feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the activations after each of the 5 slices as a namedtuple."""
        activations = []
        h = X
        for seq in (self.slice1, self.slice2, self.slice3,
                    self.slice4, self.slice5):
            h = seq(h)
            activations.append(h)
        alexnet_outputs = namedtuple(
            "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
        )
        return alexnet_outputs(*activations)
|
||||
|
||||
|
||||
class vgg16(torch.nn.Module):
    """VGG-16 feature extractor split into 5 sequential slices (relu1_2..relu5_3)."""

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Layer-index boundaries of the 5 slices inside vgg16.features.
        bounds = [(0, 4), (4, 9), (9, 16), (16, 23), (23, 30)]
        seqs = [self.slice1, self.slice2, self.slice3, self.slice4, self.slice5]
        for seq, (lo, hi) in zip(seqs, bounds):
            for x in range(lo, hi):
                seq.add_module(str(x), features[x])
        if not requires_grad:
            # Freeze the backbone when used purely as a fixed feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return relu1_2..relu5_3 activations as a namedtuple."""
        activations = []
        h = X
        for seq in (self.slice1, self.slice2, self.slice3,
                    self.slice4, self.slice5):
            h = seq(h)
            activations.append(h)
        vgg_outputs = namedtuple(
            "VggOutputs",
            ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
        )
        return vgg_outputs(*activations)
|
||||
|
||||
|
||||
class resnet(torch.nn.Module):
    """Feature extractor over a torchvision ResNet (depth 18/34/50/101/152)."""

    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        # Dispatch on the requested depth. An unsupported `num` leaves
        # self.net unset and fails below, as in the original if/elif chain.
        constructors = {
            18: models.resnet18,
            34: models.resnet34,
            50: models.resnet50,
            101: models.resnet101,
            152: models.resnet152,
        }
        if num in constructors:
            self.net = constructors[num](pretrained=pretrained)
        self.N_slices = 5

        # Expose the stem and residual stages directly as submodules.
        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        """Return stem output plus the four residual-stage outputs."""
        h = self.relu(self.bn1(self.conv1(X)))
        h_relu1 = h
        h = self.layer1(self.maxpool(h))
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple(
            "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]
        )
        return outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
||||
|
||||
# Off-the-shelf deep network
|
||||
class PNet(torch.nn.Module):
    """Pre-trained network with all channels equally weighted by default"""

    def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
        """Build the backbone feature extractor.

        Args:
            pnet_type: one of "vgg"/"vgg16", "alex", "resnetNN" (NN in
                {18, 34, 50, 101, 152}), or "squeeze".
            pnet_rand: if True, use randomly initialized (not pretrained) weights.
            use_gpu: move the network and normalization constants to CUDA.
        """
        super(PNet, self).__init__()

        self.use_gpu = use_gpu

        self.pnet_type = pnet_type
        self.pnet_rand = pnet_rand

        # Per-channel input normalization constants (broadcast over B, H, W).
        self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
        self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)

        # Select the backbone; always frozen (requires_grad=False).
        if self.pnet_type in ["vgg", "vgg16"]:
            self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
        elif self.pnet_type == "alex":
            self.net = alexnet(
                pretrained=not self.pnet_rand, requires_grad=False
            )
        elif self.pnet_type[:-2] == "resnet":
            # e.g. "resnet18" -> num=18 (last two characters).
            self.net = resnet(
                pretrained=not self.pnet_rand,
                requires_grad=False,
                num=int(self.pnet_type[-2:]),
            )
        elif self.pnet_type == "squeeze":
            self.net = squeezenet(
                pretrained=not self.pnet_rand, requires_grad=False
            )

        # Number of feature layers the backbone exposes.
        self.L = self.net.N_slices

        if use_gpu:
            self.net.cuda()
            self.shift = self.shift.cuda()
            self.scale = self.scale.cuda()

    def forward(self, in0, in1, retPerLayer=False):
        """Perceptual distance between in0 and in1, summed over feature layers.

        Returns the per-batch-element distance; with retPerLayer also returns
        the list of per-layer scores.
        """
        in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
        in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)

        outs0 = self.net.forward(in0_sc)
        outs1 = self.net.forward(in1_sc)

        if retPerLayer:
            all_scores = []
        # Accumulate (1 - cosine similarity) over all feature layers.
        for (kk, out0) in enumerate(outs0):
            cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
            if kk == 0:
                val = 1.0 * cur_score
            else:
                val = val + cur_score
            if retPerLayer:
                all_scores += [cur_score]

        if retPerLayer:
            return (val, all_scores)
        else:
            return val
|
||||
|
||||
|
||||
|
||||
|
||||
# The SSIM metric
|
||||
def ssim_metric(img1, img2, mask=None):
    """Per-image SSIM between two batches (no scalar averaging)."""
    return ssim(img1, img2, size_average=False, mask=mask)
|
||||
|
||||
|
||||
# The PSNR metric
|
||||
def psnr(img1, img2, mask=None, reshape=False):
    """Peak signal-to-noise ratio per batch element, assuming inputs in [0, 1].

    Args:
        img1, img2: (B, 3, H, W) tensors.
        mask: optional (B, 1, H, W) binary mask; the MSE is averaged only
            over masked pixels (factor 3 accounts for the channel dimension).
        reshape: use .reshape instead of .view (for non-contiguous tensors).

    Returns:
        (B,) tensor of PSNR values in dB; identical images give +inf.
    """
    # Fixed: removed a redundant second `b = img1.size(0)` inside the mask
    # branch and replaced `not (mask is None)` with the idiomatic test.
    b = img1.size(0)
    if mask is not None:
        mse_err = (img1 - img2).pow(2) * mask
        if reshape:
            mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
                3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
            )
        else:
            mse_err = mse_err.view(b, -1).sum(dim=1) / (
                3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
            )
    else:
        if reshape:
            mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
        else:
            mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)

    psnr = 10 * (1 / mse_err).log10()
    return psnr
|
||||
|
||||
|
||||
# The perceptual similarity metric
|
||||
def perceptual_sim(img1, img2, vgg16):
    """Perceptual distance between two [0, 1] images via the given network.

    Inputs are remapped to [-1, 1] before being fed to the network.
    """
    return vgg16(img1 * 2 - 1, img2 * 2 - 1)
|
||||
|
||||
def load_img(img_name, size=None):
    """Load an image file as a (1, 3, H, W) CUDA tensor, optionally resized.

    Args:
        img_name: path to the image file.
        size: int for a square resize, or an (H, W) pair; None keeps the
            original resolution.

    Raises:
        Re-raises any loading/resize error after logging it.
    """
    try:
        img = Image.open(img_name)

        if type(size) == int:
            img = img.resize((size, size))
        elif size is not None:
            # PIL's resize takes (width, height); size is given as (H, W).
            img = img.resize((size[1], size[0]))

        img = transform(img).cuda()
        img = img.unsqueeze(0)
    except Exception as e:
        # Fixed: removed a dead placeholder-tensor assignment that sat
        # between the logging and the re-raise — it could never reach a
        # caller because the exception is re-raised immediately.
        print("Failed at loading %s " % img_name)
        print(e)
        raise
    return img
|
||||
|
||||
|
||||
def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
    """Evaluate perceptual similarity, SSIM and PSNR on images stored on disk.

    `folder` is expected to contain one sub-directory per sample; within each,
    `pred_img` and `tgt_img` are glob patterns selecting the predicted
    image(s) and exactly one ground-truth image. For every sample the best
    score over all predicted images is kept (min for perceptual distance,
    max for SSIM/PSNR).

    Args:
        folder: root directory holding per-sample sub-folders.
        pred_img: glob pattern of predicted images inside each sub-folder.
        tgt_img: glob pattern matching exactly one target image per folder.
        take_every_other: if True, consecutive sample pairs are merged by
            keeping the better score of each pair (halves the sample count).

    Returns:
        Dict mapping metric name to a (mean, std) tuple over all samples.
    """

    # Load VGG16 for feature similarity
    vgg16 = PNet().to("cuda")
    vgg16.eval()
    vgg16.cuda()

    values_percsim = []
    values_ssim = []
    values_psnr = []
    folders = os.listdir(folder)
    for i, f in tqdm(enumerate(sorted(folders))):
        pred_imgs = glob.glob(folder + f + "/" + pred_img)
        tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
        assert len(tgt_imgs) == 1

        # Best-over-predictions accumulators: lower is better for perceptual
        # distance, higher is better for SSIM and PSNR.
        perc_sim = 10000
        ssim_sim = -10
        psnr_sim = -10
        for p_img in pred_imgs:
            t_img = load_img(tgt_imgs[0])
            p_img = load_img(p_img, size=t_img.shape[2:])
            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
            perc_sim = min(perc_sim, t_perc_sim)

            ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
            psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())

        values_percsim += [perc_sim]
        values_ssim += [ssim_sim]
        values_psnr += [psnr_sim]

    if take_every_other:
        # Merge consecutive sample pairs, keeping the better of each pair.
        n_valuespercsim = []
        n_valuesssim = []
        n_valuespsnr = []
        for i in range(0, len(values_percsim) // 2):
            n_valuespercsim += [
                min(values_percsim[2 * i], values_percsim[2 * i + 1])
            ]
            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]

        values_percsim = n_valuespercsim
        values_ssim = n_valuesssim
        values_psnr = n_valuespsnr

    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))

    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))

    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))

    return {
        "Perceptual similarity": (avg_percsim, std_percsim),
        "PSNR": (avg_psnr, std_psnr),
        "SSIM": (avg_ssim, std_ssim),
    }
|
||||
|
||||
|
||||
def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list,
                                            take_every_other,
                                            simple_format=True):
    """Evaluate perceptual similarity, SSIM and PSNR over explicit path lists.

    For each target image the best score over its (possibly multiple)
    predictions is kept: min for perceptual distance, max for SSIM/PSNR.
    Samples with infinite PSNR (identical/ambiguous pairs) are excluded from
    the PSNR statistics and only counted and logged.

    Args:
        pred_imgs_list: entry i is a path (or list of paths) of predictions
            for target i.
        tgt_imgs_list: list of target image paths.
        take_every_other: merge consecutive sample pairs, keeping the better
            score of each pair.
        simple_format: return plain-float lists (YAML-friendly) instead of
            numpy-scalar tuples.

    Returns:
        Dict mapping metric name to (mean, std).
    """

    # Load VGG16 for feature similarity
    vgg16 = PNet().to("cuda")
    vgg16.eval()
    vgg16.cuda()

    values_percsim = []
    values_ssim = []
    values_psnr = []
    equal_count = 0
    ambig_count = 0
    for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
        pred_imgs = pred_imgs_list[i]
        tgt_imgs = [tgt_img]
        assert len(tgt_imgs) == 1

        if type(pred_imgs) != list:
            pred_imgs = [pred_imgs]

        # Best-over-predictions accumulators.
        perc_sim = 10000
        ssim_sim = -10
        psnr_sim = -10
        assert len(pred_imgs) > 0
        for p_img in pred_imgs:
            t_img = load_img(tgt_imgs[0])
            p_img = load_img(p_img, size=t_img.shape[2:])
            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
            perc_sim = min(perc_sim, t_perc_sim)

            ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
            psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())

        values_percsim += [perc_sim]
        values_ssim += [ssim_sim]
        # Fixed: `np.float` was deprecated in NumPy 1.20 and removed in 1.24,
        # so `np.float("inf")` raises AttributeError on modern NumPy; the
        # builtin float("inf") is the correct comparison value.
        if psnr_sim != float("inf"):
            values_psnr += [psnr_sim]
        else:
            # Infinite PSNR: either the prediction equals the target exactly,
            # or the comparison is otherwise degenerate — count, don't score.
            if torch.allclose(p_img, t_img):
                equal_count += 1
                print("{} equal src and wrp images.".format(equal_count))
            else:
                ambig_count += 1
                print("{} ambiguous src and wrp images.".format(ambig_count))

    if take_every_other:
        # Merge consecutive sample pairs, keeping the better of each pair.
        n_valuespercsim = []
        n_valuesssim = []
        n_valuespsnr = []
        for i in range(0, len(values_percsim) // 2):
            n_valuespercsim += [
                min(values_percsim[2 * i], values_percsim[2 * i + 1])
            ]
            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]

        values_percsim = n_valuespercsim
        values_ssim = n_valuesssim
        values_psnr = n_valuespsnr

    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))

    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))

    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))

    if simple_format:
        # just to make yaml formatting readable
        return {
            "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
            "PSNR": [float(avg_psnr), float(std_psnr)],
            "SSIM": [float(avg_ssim), float(std_ssim)],
        }
    else:
        return {
            "Perceptual similarity": (avg_percsim, std_percsim),
            "PSNR": (avg_psnr, std_psnr),
            "SSIM": (avg_ssim, std_ssim),
        }
|
||||
|
||||
|
||||
def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list,
                                                 take_every_other, resize=False):
    """Like compute_perceptual_similarity_from_list, but additionally returns
    the per-prediction ("individual") scores so callers can do top-k selection.

    Args:
        pred_imgs_list: entry i must be a LIST of prediction paths for
            target i (a non-list entry triggers an assert).
        tgt_imgs_list: list of target image paths.
        take_every_other: not supported here (asserts); pair-merging must be
            applied after top-k selection.
        resize: if True, load targets at a fixed 256x256 resolution.

    Returns:
        Dict with "avg_of_best" (mean/std of best-per-sample scores) and
        "individual" (per-sample, per-prediction score arrays).
    """

    # Load VGG16 for feature similarity
    vgg16 = PNet().to("cuda")
    vgg16.eval()
    vgg16.cuda()

    values_percsim = []
    values_ssim = []
    values_psnr = []
    individual_percsim = []
    individual_ssim = []
    individual_psnr = []
    for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
        pred_imgs = pred_imgs_list[i]
        tgt_imgs = [tgt_img]
        assert len(tgt_imgs) == 1

        if type(pred_imgs) != list:
            assert False
            pred_imgs = [pred_imgs]

        # Best-over-predictions accumulators plus full per-prediction lists.
        perc_sim = 10000
        ssim_sim = -10
        psnr_sim = -10
        sample_percsim = list()
        sample_ssim = list()
        sample_psnr = list()
        for p_img in pred_imgs:
            if resize:
                t_img = load_img(tgt_imgs[0], size=(256,256))
            else:
                t_img = load_img(tgt_imgs[0])
            p_img = load_img(p_img, size=t_img.shape[2:])

            t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
            sample_percsim.append(t_perc_sim)
            perc_sim = min(perc_sim, t_perc_sim)

            t_ssim = ssim_metric(p_img, t_img).item()
            sample_ssim.append(t_ssim)
            ssim_sim = max(ssim_sim, t_ssim)

            t_psnr = psnr(p_img, t_img).item()
            sample_psnr.append(t_psnr)
            psnr_sim = max(psnr_sim, t_psnr)

        values_percsim += [perc_sim]
        values_ssim += [ssim_sim]
        values_psnr += [psnr_sim]
        individual_percsim.append(sample_percsim)
        individual_ssim.append(sample_ssim)
        individual_psnr.append(sample_psnr)

    if take_every_other:
        # Deliberately disabled: merging pairs before top-k selection would
        # bias the "individual" score arrays.
        assert False, "Do this later, after specifying topk to get proper results"
        n_valuespercsim = []
        n_valuesssim = []
        n_valuespsnr = []
        for i in range(0, len(values_percsim) // 2):
            n_valuespercsim += [
                min(values_percsim[2 * i], values_percsim[2 * i + 1])
            ]
            n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
            n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]

        values_percsim = n_valuespercsim
        values_ssim = n_valuesssim
        values_psnr = n_valuespsnr

    avg_percsim = np.mean(np.array(values_percsim))
    std_percsim = np.std(np.array(values_percsim))

    avg_psnr = np.mean(np.array(values_psnr))
    std_psnr = np.std(np.array(values_psnr))

    avg_ssim = np.mean(np.array(values_ssim))
    std_ssim = np.std(np.array(values_ssim))

    individual_percsim = np.array(individual_percsim)
    individual_psnr = np.array(individual_psnr)
    individual_ssim = np.array(individual_ssim)

    return {
        "avg_of_best": {
            "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
            "PSNR": [float(avg_psnr), float(std_psnr)],
            "SSIM": [float(avg_ssim), float(std_ssim)],
        },
        "individual": {
            "PSIM": individual_percsim,
            "PSNR": individual_psnr,
            "SSIM": individual_ssim,
        }
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: evaluate a results folder and write mean/std of each
    # metric to --output_file while echoing them to stdout.
    args = argparse.ArgumentParser()
    args.add_argument("--folder", type=str, default="")
    args.add_argument("--pred_image", type=str, default="")
    args.add_argument("--target_image", type=str, default="")
    args.add_argument("--take_every_other", action="store_true", default=False)
    args.add_argument("--output_file", type=str, default="")

    opts = args.parse_args()

    folder = opts.folder
    pred_img = opts.pred_image
    tgt_img = opts.target_image

    results = compute_perceptual_similarity(
        folder, pred_img, tgt_img, opts.take_every_other
    )

    # Fixed: open/close replaced with a context manager so the file is
    # closed even if formatting/printing raises.
    with open(opts.output_file, 'w') as f:
        for key in results:
            head = "%s for %s: \n" % (key, opts.folder)
            vals = "\t {:0.4f} | {:0.4f} \n".format(
                results[key][0], results[key][1]
            )
            print(head)
            print(vals)

            f.write(head)
            f.write(vals)
|
||||
147
ldm/modules/evaluate/frechet_video_distance.py
Executable file
147
ldm/modules/evaluate/frechet_video_distance.py
Executable file
@@ -0,0 +1,147 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python2, python3
|
||||
"""Minimal Reference implementation for the Frechet Video Distance (FVD).
|
||||
|
||||
FVD is a metric for the quality of video generation models. It is inspired by
|
||||
the FID (Frechet Inception Distance) used for images, but uses a different
|
||||
embedding to be better suitable for videos.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import six
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_gan as tfgan
|
||||
import tensorflow_hub as hub
|
||||
|
||||
|
||||
def preprocess(videos, target_resolution):
    """Runs some preprocessing on the videos for I3D model.

    Args:
      videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
        preprocessed. We don't care about the specific dtype of the videos, it can
        be anything that tf.image.resize_bilinear accepts. Values are expected to
        be in the range 0-255.
      target_resolution: (width, height): target video resolution

    Returns:
      videos: <float32>[batch_size, num_frames, height, width, depth]
    """
    videos_shape = list(videos.shape)
    # Collapse batch and time so every frame can be resized in a single call.
    all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
    resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
    # Restore the (batch, time, H, W, 3) layout after resizing.
    target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
    output_videos = tf.reshape(resized_videos, target_shape)
    # Rescale [0, 255] -> [-1, 1], the range the I3D module expects.
    scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
    return scaled_videos
|
||||
|
||||
|
||||
def _is_in_graph(tensor_name):
    """Return True iff *tensor_name* already exists in the default TF graph."""
    graph = tf.get_default_graph()
    try:
        graph.get_tensor_by_name(tensor_name)
        return True
    except KeyError:
        # get_tensor_by_name raises KeyError for unknown tensors.
        return False
|
||||
|
||||
|
||||
def create_id3_embedding(videos,warmup=False,batch_size=16):
    """Embeds the given videos using the Inflated 3D Convolution network.

    Downloads the graph of the I3D from tf.hub and adds it to the graph on the
    first call.

    Args:
      videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
        Expected range is [-1, 1].
      warmup: if True, additionally validate that the static batch dimension
        of `videos` matches `batch_size` before importing the module.
      batch_size: expected per-batch video count (checked by a runtime assert).

    Returns:
      embedding: <float32>[batch_size, embedding_size]. embedding_size depends
        on the model used.

    Raises:
      ValueError: when a provided embedding_layer is not supported.
    """

    # batch_size = 16
    module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"


    # Making sure that we import the graph separately for
    # each different input video tensor.
    module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
        videos.name).replace(":", "_")



    # Graph-time checks: values must lie in [-1, 1] (with a small tolerance)
    # and the runtime batch size must match `batch_size`.
    assert_ops = [
        tf.Assert(
            tf.reduce_max(videos) <= 1.001,
            ["max value in frame is > 1", videos]),
        tf.Assert(
            tf.reduce_min(videos) >= -1.001,
            ["min value in frame is < -1", videos]),
        tf.assert_equal(
            tf.shape(videos)[0],
            batch_size, ["invalid frame batch size: ",
                         tf.shape(videos)],
            summarize=6),
    ]
    with tf.control_dependencies(assert_ops):
        videos = tf.identity(videos)

    module_scope = "%s_apply_default/" % module_name

    # To check whether the module has already been loaded into the graph, we look
    # for a given tensor name. If this tensor name exists, we assume the function
    # has been called before and the graph was imported. Otherwise we import it.
    # Note: in theory, the tensor could exist, but have wrong shapes.
    # This will happen if create_id3_embedding is called with a frames_placehoder
    # of wrong size/batch size, because even though that will throw a tf.Assert
    # on graph-execution time, it will insert the tensor (with wrong shape) into
    # the graph. This is why we need the following assert.
    if warmup:
        video_batch_size = int(videos.shape[0])
        assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"
    tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
    if not _is_in_graph(tensor_name):
        i3d_model = hub.Module(module_spec, name=module_name)
        i3d_model(videos)

    # gets the kinetics-i3d-400-logits layer
    tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
    tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
    return tensor
|
||||
|
||||
|
||||
def calculate_fvd(real_activations, generated_activations):
    """Return the FVD computed from two sets of I3D activations.

    Args:
      real_activations: <float32>[num_samples, embedding_size]
      generated_activations: <float32>[num_samples, embedding_size]

    Returns:
      A scalar that contains the requested FVD.
    """
    return tfgan.eval.frechet_classifier_distance_from_activations(
        real_activations, generated_activations
    )
|
||||
124
ldm/modules/evaluate/ssim.py
Executable file
124
ldm/modules/evaluate/ssim.py
Executable file
@@ -0,0 +1,124 @@
|
||||
# MIT Licence
|
||||
|
||||
# Methods to predict the SSIM, taken from
|
||||
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
|
||||
|
||||
from math import exp
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length *window_size*, normalized to sum 1."""
    center = window_size // 2
    weights = torch.Tensor(
        [
            exp(-((x - center) ** 2) / float(2 * sigma ** 2))
            for x in range(window_size)
        ]
    )
    return weights / weights.sum()
|
||||
|
||||
|
||||
def create_window(window_size, channel):
    """Build a (channel, 1, ws, ws) Gaussian window for grouped conv2d."""
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the 2-D window.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    expanded = kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
    return Variable(expanded)
|
||||
|
||||
|
||||
def _ssim(
|
||||
img1, img2, window, window_size, channel, mask=None, size_average=True
|
||||
):
|
||||
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = (
|
||||
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
|
||||
- mu1_sq
|
||||
)
|
||||
sigma2_sq = (
|
||||
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
|
||||
- mu2_sq
|
||||
)
|
||||
sigma12 = (
|
||||
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
|
||||
- mu1_mu2
|
||||
)
|
||||
|
||||
C1 = (0.01) ** 2
|
||||
C2 = (0.03) ** 2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
|
||||
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
|
||||
)
|
||||
|
||||
if not (mask is None):
|
||||
b = mask.size(0)
|
||||
ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
|
||||
ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
|
||||
dim=1
|
||||
).clamp(min=1)
|
||||
return ssim_map
|
||||
|
||||
import pdb
|
||||
|
||||
pdb.set_trace
|
||||
|
||||
if size_average:
|
||||
return ssim_map.mean()
|
||||
else:
|
||||
return ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
|
||||
class SSIM(torch.nn.Module):
    """Module wrapper around _ssim that lazily caches the Gaussian window.

    The window is rebuilt (and moved to the input's device/dtype) only when
    the channel count or tensor dtype of the inputs changes.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Cached window state; updated lazily in forward().
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2, mask=None):
        """Return SSIM(img1, img2); see _ssim for mask/size_average semantics."""
        (_, channel, _, _) = img1.size()

        # Reuse the cached window when it matches the input's channel count
        # and dtype; otherwise rebuild it and move it to the input's device.
        if (
            channel == self.channel
            and self.window.data.type() == img1.data.type()
        ):
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(
            img1,
            img2,
            window,
            self.window_size,
            channel,
            mask,
            self.size_average,
        )
|
||||
|
||||
|
||||
def ssim(img1, img2, window_size=11, mask=None, size_average=True):
    """Functional SSIM: build a window matching img1 and delegate to _ssim."""
    channel = img1.size(1)
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, mask, size_average)
|
||||
294
ldm/modules/evaluate/torch_frechet_video_distance.py
Executable file
294
ldm/modules/evaluate/torch_frechet_video_distance.py
Executable file
@@ -0,0 +1,294 @@
|
||||
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
|
||||
import os
|
||||
import numpy as np
|
||||
import io
|
||||
import re
|
||||
import requests
|
||||
import html
|
||||
import hashlib
|
||||
import urllib
|
||||
import urllib.request
|
||||
import scipy.linalg
|
||||
import multiprocessing as mp
|
||||
import glob
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
from typing import Any, List, Tuple, Union, Dict, Callable
|
||||
|
||||
from torchvision.io import read_video
|
||||
import torch; torch.set_grad_enabled(False)
|
||||
from einops import rearrange
|
||||
|
||||
from nitro.util import isvideo
|
||||
|
||||
def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float:
    """Frechet distance between two Gaussians given their means and covariances."""
    print('Calculate frechet distance...')
    mean_term = np.square(mu_sample - mu_ref).sum()
    covmean, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False)  # pylint: disable=no-member
    # Imaginary components of sqrtm are numerical noise; keep the real part.
    fid = np.real(mean_term + np.trace(sigma_sample + sigma_ref - covmean * 2))
    return float(fid)
|
||||
|
||||
|
||||
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean vector [d] and covariance matrix [d, d] of [n, d] features."""
    mean_vec = feats.mean(axis=0)
    cov_mat = np.cov(feats, rowvar=False)
    return mean_vec, cov_mat
|
||||
|
||||
|
||||
def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Local paths and file:// URLs are returned/opened directly. Remote URLs are
    retried up to `num_attempts` times, with special handling for Google Drive
    quota and virus-checker interstitial pages.

    Args:
        url: local path, file:// URL, or http(s) URL.
        num_attempts: number of download attempts before giving up.
        verbose: print progress to stdout.
        return_filename: return the local filename instead of a file object
            (only valid for local paths / file:// URLs).
    """
    assert num_attempts >= 1

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Tiny responses are likely HTML interstitials, not data.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except Exception:
                # Fixed: was a bare `except:`, which would also swallow
                # SystemExit/GeneratorExit; retry only on ordinary errors.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
|
||||
|
||||
def load_video(ip):
    """Read the video file at path *ip* and return its frames as a uint8 tensor (t, c, h, w)."""
    frames = read_video(ip)[0]
    # read_video yields (t, h, w, c); move channels in front of the spatial dims.
    return rearrange(frames, 't h w c -> t c h w').to(torch.uint8)
|
||||
|
||||
def get_data_from_str(input_str, nprc=None):
    """Load every ``*.mp4`` under *input_str* in parallel and stack them into one tensor.

    Args:
        input_str: path to a directory containing ``.mp4`` files.
        nprc: number of worker processes; defaults to ``mp.cpu_count()`` (1 if unavailable).

    Returns:
        Float tensor of shape (n_videos, t, c, h, w).
    """
    assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory'
    vid_filelist = glob.glob(os.path.join(input_str, '*.mp4'))
    print(f'Found {len(vid_filelist)} videos in dir {input_str}')

    if nprc is None:
        try:
            nprc = mp.cpu_count()
        except NotImplementedError:
            # Fix: typo "avlailable" in the original warning message.
            print('WARNING: cpu_count() not available, using only 1 cpu for video loading')
            nprc = 1

    # Fix: the original never closed/joined the pool (worker-process leak);
    # the context manager terminates it even if loading raises.
    with mp.Pool(processes=nprc) as pool:
        vids = list(tqdm(pool.imap_unordered(load_video, vid_filelist),
                         total=len(vid_filelist), desc='Loading videos...'))

    return torch.stack(vids, dim=0).float()
|
||||
|
||||
def get_stats(stats):
    """Load precomputed FVD statistics from an ``.npz`` file into a plain dict."""
    assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}'
    print(f'Using precomputed statistics under {stats}')
    archive = np.load(stats)
    return {key: archive[key] for key in archive.files}
|
||||
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
def compute_fvd(ref_input, sample_input, bs=32,
                ref_stats=None,
                sample_stats=None,
                nprc_load=None):
    """Compute the Frechet Video Distance between a reference and a sample set.

    Each side may be given either as raw input (video tensor / numpy array, or a
    directory of ``.mp4`` files) or as a path to precomputed ``.npz`` statistics.
    Returns a dict ``{'FVD': value}``.
    """
    if ref_stats is not None and sample_stats is not None:
        # Both sides precomputed: no feature extraction needed.
        stats = get_stats(sample_stats)
        stats.update(get_stats(ref_stats))
    else:
        only_ref = sample_stats is not None
        only_sample = ref_stats is not None

        # Directory inputs are expanded to stacked video tensors on demand.
        if not only_sample and isinstance(ref_input, str):
            ref_input = get_data_from_str(ref_input, nprc_load)
        if not only_ref and isinstance(sample_input, str):
            sample_input = get_data_from_str(sample_input, nprc_load)

        stats = compute_statistics(sample_input, ref_input,
                                   device='cuda' if torch.cuda.is_available() else 'cpu',
                                   bs=bs,
                                   only_ref=only_ref,
                                   only_sample=only_sample)

        # Merge in whichever side was supplied as precomputed statistics.
        if only_ref:
            stats.update(get_stats(sample_stats))
        elif only_sample:
            stats.update(get_stats(ref_stats))

    return {'FVD': compute_frechet_distance(**stats)}
|
||||
|
||||
|
||||
@torch.no_grad()
def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False, only_sample=False) -> Dict:
    """Extract I3D features and return (mu, sigma) statistics for FVD.

    Args:
        videos_fake: sampled videos — tensor (n, c, t, h, w), or numpy array
            (n, t, h, w, c) in [0, 255] (converted automatically).
        videos_real: reference videos, same conventions.
        device: torch device for the I3D detector.
        bs: batch size for feature extraction.
        only_ref / only_sample: compute statistics for only one side
            (mutually exclusive).

    Returns:
        Dict with 'mu_sample'/'sigma_sample' and/or 'mu_ref'/'sigma_ref'.
    """
    detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
    detector_kwargs = dict(rescale=True, resize=True, return_features=True)  # raw features before the softmax layer

    with open_url(detector_url, verbose=False) as f:
        detector = torch.jit.load(f).eval().to(device)

    assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive'

    ref_embed, sample_embed = [], []
    info = f'Computing I3D activations for FVD score with batch size {bs}'

    if only_ref:
        if not isvideo(videos_real):
            # if not a video we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0,255]
            videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
            print(videos_real.shape)

        if videos_real.shape[0] % bs == 0:
            n_secs = videos_real.shape[0] // bs
        else:
            n_secs = videos_real.shape[0] // bs + 1

        videos_real = torch.tensor_split(videos_real, n_secs, dim=0)

        for ref_v in tqdm(videos_real, total=len(videos_real), desc=info):
            feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
            ref_embed.append(feats_ref)

    elif only_sample:
        if not isvideo(videos_fake):
            # if not a video we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0,255]
            videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
            print(videos_fake.shape)

        if videos_fake.shape[0] % bs == 0:
            n_secs = videos_fake.shape[0] // bs
        else:
            n_secs = videos_fake.shape[0] // bs + 1

        # BUG FIX: the original split `videos_real` here (and sized the
        # progress bar with it), batching the wrong tensor — or crashing —
        # when only sample statistics are requested.
        videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)

        for sample_v in tqdm(videos_fake, total=len(videos_fake), desc=info):
            feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
            sample_embed.append(feats_sample)

    else:
        if not isvideo(videos_real):
            # if not a video we assume numpy arrays of shape (n_vids, t, h, w, c) in range [0,255]
            videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
        if not isvideo(videos_fake):
            videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()

        if videos_fake.shape[0] % bs == 0:
            n_secs = videos_fake.shape[0] // bs
        else:
            n_secs = videos_fake.shape[0] // bs + 1

        videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
        videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)

        for ref_v, sample_v in tqdm(zip(videos_real, videos_fake), total=len(videos_fake), desc=info):
            feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
            feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
            sample_embed.append(feats_sample)
            ref_embed.append(feats_ref)

    out = dict()
    if len(sample_embed) > 0:
        sample_embed = np.concatenate(sample_embed, axis=0)
        mu_sample, sigma_sample = compute_stats(sample_embed)
        out.update({'mu_sample': mu_sample,
                    'sigma_sample': sigma_sample})

    if len(ref_embed) > 0:
        ref_embed = np.concatenate(ref_embed, axis=0)
        mu_ref, sigma_ref = compute_stats(ref_embed)
        out.update({'mu_ref': mu_ref,
                    'sigma_ref': sigma_ref})

    return out
|
||||
2
ldm/modules/image_degradation/__init__.py
Executable file
2
ldm/modules/image_degradation/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
||||
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
|
||||
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
|
||||
730
ldm/modules/image_degradation/bsrgan.py
Executable file
730
ldm/modules/image_degradation/bsrgan.py
Executable file
@@ -0,0 +1,730 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
from scipy import ndimage
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
import albumentations
|
||||
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
    """Crop *img* so both leading spatial dims are divisible by scale factor *sf*.

    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Returns:
        cropped copy of the image
    """
    w, h = img.shape[:2]
    return np.copy(img)[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)."""
    k_size = k.shape[0]
    # Superimpose shifted, scaled copies of k onto a (3*k_size-2) grid.
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    for row in range(k_size):
        for col in range(k_size):
            big_k[2 * row:2 * row + k_size, 2 * col:2 * col + k_size] += k[row, col] * k
    # Trim very small border values (speeds up SR), then renormalise to unit mass.
    crop = k_size // 2
    trimmed = big_k[crop:-crop, crop:-crop]
    return trimmed / trimmed.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1,50], scaling of eigenvalues
        l2    : [0.1,l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    # Rotate the x-axis by theta to get the kernel's principal direction.
    rot = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    v = rot @ np.array([1., 0.])
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = V @ D @ np.linalg.inv(V)
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
    """Sample a Gaussian density on a size x size grid and normalise it to sum 1."""
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            # Offset of this tap from the kernel centre.
            k[y, x] = ss.multivariate_normal.pdf([x - center + 1, y - center + 1], mean=mean, cov=cov)
    return k / np.sum(k)
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv = np.arange(0, w, 1.0)
    yv = np.arange(0, h, 1.0)
    # Shift the sampling grid toward the chosen corner, clamped to the image.
    sign = 1.0 if upper_left else -1.0
    x1 = np.clip(xv + sign * shift, 0, w - 1)
    y1 = np.clip(yv + sign * shift, 0, h - 1)

    # NOTE(review): interp2d was removed from SciPy >= 1.14 — this path only
    # works with older SciPy releases.
    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for ch in range(x.shape[-1]):
            x[:, :, ch] = interp2d(xv, yv, x[:, :, ch])(x1, y1)

    return x
|
||||
|
||||
|
||||
def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    padded = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    # One depthwise filter per (batch, channel) pair via a grouped convolution.
    weight = k.repeat(1, c, 1, 1).view(-1, 1, k.shape[2], k.shape[3])
    flat = padded.view(1, -1, padded.shape[2], padded.shape[3])
    out = torch.nn.functional.conv2d(flat, weight, bias=None, stride=1, padding=0, groups=n * c)
    return out.view(n, c, out.shape[2], out.shape[3])
|
||||
|
||||
|
||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Sample a random anisotropic Gaussian kernel.

    modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    (variance of the Gaussian is sampled between min_var and max_var).
    """
    # Random eigenvalues (variances) and rotation angle for the covariance.
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from eigenvalues (LAMBDA) and rotation (Q).
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Expectation position (shifts the kernel so the image stays aligned).
    MU = k_size // 2 - 0.5 * (scale_factor - 1)
    MU = MU[None, None, :, None]

    # Evaluate the Gaussian on every grid position of the kernel.
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    ZZ = np.stack([X, Y], 2)[:, :, :, None] - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # Normalise to unit mass.
    return raw_kernel / np.sum(raw_kernel)
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
    """Matlab-style ``fspecial('gaussian')``: hsize x hsize kernel with std *sigma*, normalised to sum 1."""
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # Zero numerically negligible taps (matches MATLAB's fspecial).
    # Fix: the original called scipy.finfo, which is not part of the SciPy API
    # and raises AttributeError on current SciPy; np.finfo is the correct call.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
    """Matlab-style ``fspecial('laplacian')`` with shape parameter *alpha*, clamped to [0, 1]."""
    alpha = max([0, min([alpha, 1])])
    corner = alpha / (alpha + 1)
    edge = (1 - alpha) / (alpha + 1)
    return np.array([[corner, edge, corner],
                     [edge, -4 / (alpha + 1), edge],
                     [corner, edge, corner]])
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    # Dispatch table mirrors MATLAB's fspecial(type, ...). Unknown types
    # fall through and return None, as in the original.
    builders = {
        'gaussian': fspecial_gaussian,
        'laplacian': fspecial_laplacian,
    }
    builder = builders.get(filter_type)
    if builder is not None:
        return builder(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
    '''
    Args:
        x: HxWxC image, [0, 1]
        sf: down-scale factor
    Return:
        bicubicly downsampled LR image
    '''
    return util.imresize_np(x, scale=1 / sf)
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    # Fix: scipy.ndimage.filters was removed in SciPy 1.15; ndimage.convolve
    # is the same function under its supported name.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    # Fix: scipy.ndimage.filters was removed in SciPy 1.15; use ndimage.convolve.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    # Fix: scipy.ndimage.filters was removed in SciPy 1.15; use ndimage.convolve.
    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0  # downsampling offset
    return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening. borrowed from real-ESRGAN
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharp weight. Default: 1.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int):
    """
    if radius % 2 == 0:
        radius += 1  # GaussianBlur requires an odd kernel size
    blurred = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blurred
    # Binary mask of strongly-sharpened pixels, then softened by blurring.
    mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
    sharpened = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharpened + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
    """Blur *img* with a random (an)isotropic Gaussian kernel whose width scales with *sf*."""
    wd2 = 4.0 + sf   # width range for the anisotropic branch
    wd = 2.0 + 0.2 * sf  # width range for the isotropic branch
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    # Fix: scipy.ndimage.filters was removed in SciPy 1.15; use ndimage.convolve.
    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
    """Randomly rescale the image (up, down, or unchanged) with a random interpolation."""
    rnum = np.random.rand()
    if rnum > 0.8:       # upscale
        scale = random.uniform(1, 2)
    elif rnum < 0.7:     # downscale
        scale = random.uniform(0.5 / sf, 1)
    else:                # unchanged
        scale = 1.0
    img = cv2.resize(img, (int(scale * img.shape[1]), int(scale * img.shape[0])),
                     interpolation=random.choice([1, 2, 3]))
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add random Gaussian noise: colour, grayscale, or channel-correlated."""
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:
        # independent per-channel colour noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        # single-channel noise broadcast over all channels
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # noise with a random inter-channel covariance
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise with a random level in [noise_level1, noise_level2].

    Fix: uses out-of-place ``img = img + ...`` instead of ``+=``, matching
    add_Gaussian_noise and the module TODO about replacing in-place ops.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img = img + img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img = img + img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # channel-correlated noise via a random covariance matrix
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
    """Add Poisson (shot) noise, either per-channel or derived from luminance."""
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # scale in [1e2, 1e4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        # Noise derived from the luminance channel, added to all channels.
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
    """Round-trip the image through JPEG at a random quality factor in [30, 95]."""
    quality_factor = random.randint(30, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encimg = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encimg, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Crop an aligned random patch pair: lq_patchsize from *lq*, lq_patchsize*sf from *hq*."""
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]
    # Matching HQ window, scaled by the super-resolution factor.
    top_hq, left_hq = int(top * sf), int(left * sf)
    hq = hq[top_hq:top_hq + lq_patchsize * sf, left_hq:left_hq + lq_patchsize * sf, :]
    return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf  # the final crop uses the original scale factor

    h1, w1 = img.shape[:2]
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    # Optional extra x2 downscale when targeting x4 (pipeline then runs at x2).
    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    # Random order of the seven stages, but downsample3 always after downsample2.
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            img = add_blur(img, sf=sf)

        elif i == 1:
            img = add_blur(img, sf=sf)

        elif i == 2:
            a, b = img.shape[1], img.shape[0]  # remembered for the final resize in step 3
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # Fix: scipy.ndimage.filters was removed in SciPy 1.15.
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: resize to the size recorded in step 2 (always ran first)
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # aligned random crop at the original scale factor
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            image = add_blur(image, sf=sf)

        elif i == 1:
            image = add_blur(image, sf=sf)

        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                # Fix: scipy.ndimage.filters was removed in SciPy 1.15.
                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)

        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)

        # step 6 (processed camera sensor noise) is intentionally disabled
        # in this variant.

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example
|
||||
|
||||
|
||||
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
|
||||
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    use_shuffle: the degradation shuffle
    use_sharp: sharpening the img
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """

    h1, w1 = img.shape[:2]
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    # Either fully shuffle all 13 stages, or only shuffle the noise stages
    # locally (JPEG stays last within each half of the pipeline).
    if random.random() < shuffle_prob:
        order = random.sample(range(13), 13)
    else:
        order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        order[2:6] = random.sample(order[2:6], len(range(2, 6)))
        order[9:13] = random.sample(order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for op in order:
        if op == 0:
            img = add_blur(img, sf=sf)
        elif op == 1:
            img = add_resize(img, sf=sf)
        elif op == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif op == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif op == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif op == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif op == 6:
            img = add_JPEG_noise(img)
        elif op == 7:
            img = add_blur(img, sf=sf)
        elif op == 8:
            img = add_resize(img, sf=sf)
        elif op == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif op == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif op == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif op == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Demo: apply the BSRGAN-variant degradation 20 times and save comparison images.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = util.uint2single(img)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img  # Fix: img_hq was never defined in the original demo.
        # Fix: degradation_bsrgan_variant returns a dict; unwrap the image.
        img_lq = deg_fn(img)["image"]
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
|
||||
|
||||
|
||||
650
ldm/modules/image_degradation/bsrgan_light.py
Executable file
650
ldm/modules/image_degradation/bsrgan_light.py
Executable file
@@ -0,0 +1,650 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
from scipy import ndimage
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
import albumentations
|
||||
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
    """Crop a copy of `img` so both leading spatial dims are multiples of `sf`.

    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped copy of the image
    """
    d0, d1 = img.shape[0], img.shape[1]
    return np.copy(img)[:d0 - d0 % sf, :d1 - d1 % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)."""
    k_size = k.shape[0]
    # The "big" kernel is the X2 kernel re-applied at stride 2 over itself.
    big_size = 3 * k_size - 2
    big_k = np.zeros((big_size, big_size))
    for row in range(k_size):
        for col in range(k_size):
            big_k[2 * row:2 * row + k_size, 2 * col:2 * col + k_size] += k[row, col] * k
    # Trim the faint borders to keep SR runtime down.
    margin = k_size // 2
    core = big_k[margin:-margin, margin:-margin]
    # Renormalize so the kernel sums to one.
    return core / core.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """Generate an anisotropic Gaussian kernel.

    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 == l2 the kernel is isotropic.

    Returns:
        k : kernel
    """
    # First eigenvector: the x-axis rotated by theta.
    c, s = np.cos(theta), np.sin(theta)
    v = np.array([[c, -s], [s, c]]) @ np.array([1., 0.])
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    # Covariance built from the rotated eigenbasis and the two scales.
    Sigma = V @ D @ np.linalg.inv(V)
    return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
    """Build a blur kernel by sampling a 2-D Gaussian pdf on a size x size grid."""
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            # grid coordinate relative to the kernel center
            cx, cy = x - center + 1, y - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
    # normalize to a proper convolution kernel
    return k / np.sum(k)
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
    """Shift pixels by (sf - 1) / 2 for super-resolution alignment.

    Args:
        x: WxHxC or WxH numpy array
        sf: scale factor
        upper_left: shift direction (True shifts toward upper-left)

    Returns:
        the resampled array (3-dim inputs are modified channel-wise in place).
    """
    # BUG FIX: scipy.interpolate.interp2d was removed in SciPy >= 1.14;
    # RegularGridInterpolator with method='linear' is the bilinear equivalent.
    from scipy.interpolate import RegularGridInterpolator

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    # clamp sample positions so interpolation stays inside the image
    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    yy, xx = np.meshgrid(y1, x1, indexing='ij')
    pts = np.stack([yy, xx], axis=-1)

    def _resample(channel):
        # bilinear resampling, equivalent to legacy interp2d(kind='linear')
        return RegularGridInterpolator((yv, xv), channel, method='linear')(pts)

    if x.ndim == 2:
        x = _resample(x)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = _resample(x[:, :, i])

    return x
|
||||
|
||||
|
||||
def blur(x, k):
    """Batched blur: convolve each image in `x` with its own kernel.

    x: image batch, N x C x H x W
    k: kernels, N x 1 x h x w
    """
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    # one depthwise filter per (image, channel) pair
    k = k.repeat(1, c, 1, 1).view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    return x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
|
||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """Sample a random anisotropic Gaussian kernel.

    Modified from https://github.com/assafshocher/BlindSR_dataset_generator.
    The Gaussian's variances are drawn uniformly from [min_var, max_var]
    (e.g. 0.175*sf .. 2.5*sf) and its orientation uniformly from [0, pi).
    """
    # Random eigenvalues (variances) and rotation angle for the covariance.
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Covariance from the sampled eigen-decomposition; precompute its inverse.
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta), np.cos(theta)]])
    inv_sigma = np.linalg.inv(rot @ np.diag([lambda_1, lambda_2]) @ rot.T)[None, None, :, :]

    # Gaussian mean, shifted so the kernel aligns with the downsampled grid.
    mu = (k_size // 2 - 0.5 * (scale_factor - 1))[None, None, :, None]

    # Evaluate the (unnormalized) Gaussian at every kernel coordinate.
    grid_x, grid_y = np.meshgrid(range(k_size[0]), range(k_size[1]))
    coords = np.stack([grid_x, grid_y], 2)[:, :, :, None]
    delta = coords - mu
    raw = np.exp(-0.5 * np.squeeze(delta.transpose(0, 1, 3, 2) @ inv_sigma @ delta)) * (1 + noise)

    # Normalize to a proper convolution kernel.
    return raw / np.sum(raw)
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
    """Rotationally symmetric Gaussian lowpass filter (MATLAB fspecial('gaussian')).

    Args:
        hsize: kernel side length (kernel is hsize x hsize)
        sigma: Gaussian standard deviation
    Returns:
        normalized kernel summing to 1 (all-zero if everything underflows)
    """
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # BUG FIX: `scipy.finfo` does not exist — NumPy provides float machine limits.
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
    """3x3 discrete Laplacian filter (MATLAB fspecial('laplacian')); alpha clamped to [0, 1]."""
    alpha = min(max(alpha, 0), 1)
    corner = alpha / (alpha + 1)
    edge = (1 - alpha) / (alpha + 1)
    return np.array([[corner, edge, corner],
                     [edge, -4 / (alpha + 1), edge],
                     [corner, edge, corner]])
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
    """MATLAB-style fspecial dispatcher ('gaussian' or 'laplacian').

    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    """
    builders = {'gaussian': fspecial_gaussian, 'laplacian': fspecial_laplacian}
    builder = builders.get(filter_type)
    if builder is not None:
        return builder(*args, **kwargs)
    # unknown filter types fall through and return None, as before
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
    """Bicubicly downsample `x` (HxWxC, [0, 1]) by scale factor `sf`.

    Return:
        bicubicly downsampled LR image
    """
    return util.imresize_np(x, scale=1 / sf)
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
    """Blur with kernel `k`, then bicubicly downsample by `sf` (SRMD).

    Args:
        x: HxWxC image, [0, 1]
        k: hxw blur kernel, double
        sf: down-scale factor
    Return:
        downsampled LR image

    Reference: Zhang, Zuo, Zhang, "Learning a single convolutional
    super-resolution network for multiple degradations", CVPR 2018.
    """
    blurred = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    return bicubic_degradation(blurred, sf=sf)
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
    """Bicubicly downsample by `sf`, then blur with kernel `k` (DPSR).

    Args:
        x: HxWxC image, [0, 1]
        k: hxw blur kernel, double
        sf: down-scale factor
    Return:
        downsampled LR image

    Reference: Zhang, Zuo, Zhang, "Deep Plug-and-Play Super-Resolution for
    Arbitrary Blur Kernels", CVPR 2019.
    """
    down = bicubic_degradation(x, sf=sf)
    return ndimage.convolve(down, np.expand_dims(k, axis=2), mode='wrap')
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
    """Blur with kernel `k`, then nearest-downsample by `sf` starting at offset 0.

    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw blur kernel, double
        sf: down-scale factor
    Return:
        downsampled LR image
    """
    blurred = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    return blurred[::sf, ::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM (unsharp-mask) sharpening, borrowed from Real-ESRGAN.

    K = I + weight * (I - blur(I)); the sharpened value is kept only where the
    residual is strong (|I - blur(I)| * 255 > threshold), blended via a blurred
    (soft) mask: Out = Mask * K + (1 - Mask) * I.

    Args:
        img (Numpy array): input image, HWC, BGR; float32, [0, 1].
        weight (float): sharpening strength.
        radius (float): Gaussian blur kernel size (forced odd).
        threshold (int): residual threshold on the 0-255 scale.
    """
    if radius % 2 == 0:
        radius += 1  # cv2.GaussianBlur requires an odd kernel size
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    # hard mask where detail is strong, then soften its edges
    mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    sharp = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
    """Blur `img` with a random (an)isotropic Gaussian kernel whose width scales with `sf`."""
    wd2 = (4.0 + sf) / 4
    wd = (2.0 + 0.2 * sf) / 4

    if random.random() < 0.5:
        # anisotropic kernel: random axis scales, random size, random orientation
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        # isotropic kernel with random size and sigma
        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
    return ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
    """Randomly rescale `img`: ~20% upscale, ~70% downscale, otherwise keep size."""
    rnum = np.random.rand()
    if rnum > 0.8:  # upscale
        scale = random.uniform(1, 2)
    elif rnum < 0.7:  # downscale
        scale = random.uniform(0.5 / sf, 1)
    else:
        scale = 1.0
    img = cv2.resize(img, (int(scale * img.shape[1]), int(scale * img.shape[0])),
                     interpolation=random.choice([1, 2, 3]))
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    """Add Gaussian noise: per-channel (40%), shared grayscale (40%), or channel-correlated (20%).

    The input array is not modified; a clipped copy in [0, 1] is returned.
    """
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:
        # independent noise per channel
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        # the same noise plane applied to every channel
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # noise with a random inter-channel covariance
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.transpose(U) @ D @ U
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    """Add multiplicative (speckle) noise; branch structure mirrors add_Gaussian_noise."""
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)  # clip copies, so the caller's array is untouched
    rnum = random.random()
    if rnum > 0.6:
        # independent multiplicative noise per channel
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        # the same noise plane scaled by every channel
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        # channel-correlated multiplicative noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.transpose(U) @ D @ U
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
    """Add Poisson (shot) noise, either per-channel or on the luminance only."""
    # quantize to 8-bit levels first so the Poisson rates are well defined
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # exponent drawn from [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        # apply the noise only to the grayscale intensity, shared across channels
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    return np.clip(img, 0.0, 1.0)
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
    """Round-trip `img` through an in-memory JPEG at a random quality in [80, 95]."""
    quality_factor = random.randint(80, 95)
    bgr = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    decoded = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(util.uint2single(decoded), cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    """Take a random aligned crop pair: lq_patchsize² from `lq`, (lq_patchsize*sf)² from `hq`."""
    h, w = lq.shape[:2]
    top = random.randint(0, h - lq_patchsize)
    left = random.randint(0, w - lq_patchsize)
    lq = lq[top:top + lq_patchsize, left:left + lq_patchsize, :]

    # matching HQ window, scaled by sf
    top_hq, left_hq = int(top * sf), int(left * sf)
    hq = hq[top_hq:top_hq + lq_patchsize * sf, left_hq:left_hq + lq_patchsize * sf, :]
    return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf  # remember the requested scale; sf may be halved by downsample1 below

    h1, w1 = img.shape[:2]
    # NOTE(review): h1/w1 look swapped here — axis 0 (height) is cropped by the
    # width-based size and vice versa. Only correct for square inputs; confirm.
    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()  # the HQ target is the mod-cropped original

    if sf == 4 and random.random() < scale2_prob:  # downsample1: optional pre-halving for x4
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2  # remaining pipeline only needs to cover the other x2

    # Random order of the 7 degradation stages, except stage 2 (downsample2)
    # must come before stage 3 (downsample3).
    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:

        if i == 0:
            # first blur
            img = add_blur(img, sf=sf)

        elif i == 1:
            # second blur (same operator, drawn independently)
            img = add_blur(img, sf=sf)

        elif i == 2:
            # remember the pre-downsample2 size; stage 3 resizes relative to it
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                # random-scale resize with random interpolation (1/2/3 = cv2 codes)
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)

        elif i == 3:
            # downsample3: bring the image to the final LR size (relative to a, b)
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)

        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)

        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)

        elif i == 6:
            # add processed camera sensor noise (only when an ISP model is supplied)
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop (aligned LQ/HQ pair, using the original scale factor)
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
hq = image.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
# elif i == 1:
|
||||
# image = add_blur(image, sf=sf)
|
||||
|
||||
if i == 0:
|
||||
pass
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.8:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
#
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
example = {"image": image}
|
||||
return example
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: degrade a local test image 20 times and save
    # bicubic-LR / degraded-LR / HR comparison strips as <i>.png.
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    for i in range(20):
        print(i)
        img_hq = img
        # the variant returns a dict; "image" holds the degraded uint8 LR image
        img_lq = deg_fn(img)["image"]
        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
        print(img_lq)
        # bicubic reference downscale for visual comparison
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        # nearest-upscale (interpolation=0) both LR images back to HR size for concatenation
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')
|
||||
BIN
ldm/modules/image_degradation/utils/test.png
Executable file
BIN
ldm/modules/image_degradation/utils/test.png
Executable file
Binary file not shown.
|
After Width: | Height: | Size: 431 KiB |
916
ldm/modules/image_degradation/utils_image.py
Executable file
916
ldm/modules/image_degradation/utils_image.py
Executable file
@@ -0,0 +1,916 @@
|
||||
import os
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
from torchvision.utils import make_grid
|
||||
from datetime import datetime
|
||||
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# Kai Zhang (github: https://github.com/cszn)
|
||||
# 03/Mar/2019
|
||||
# --------------------------------------------
|
||||
# https://github.com/twhui/SRGAN-pyTorch
|
||||
# https://github.com/xinntao/BasicSR
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
|
||||
|
||||
|
||||
def is_image_file(filename):
    """Return True when `filename` ends with one of the recognized image extensions."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
||||
|
||||
|
||||
def get_timestamp():
    """Current local time as 'yymmdd-HHMMSS' (used for archive/run names)."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
|
||||
|
||||
|
||||
def imshow(x, title=None, cbar=False, figsize=None):
    # Debug helper: display `x` as a grayscale image with optional title/colorbar.
    # NOTE(review): `plt` is undefined in this module — the matplotlib import at
    # the top of the file is commented out, so calling imshow() raises NameError.
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
|
||||
|
||||
|
||||
def surf(Z, cmap='rainbow', figsize=None):
    # Debug helper: render a 2-D array as a 3-D surface plot.
    # NOTE(review): `plt` is undefined in this module — the matplotlib import at
    # the top of the file is commented out, so calling surf() raises NameError.
    plt.figure(figsize=figsize)
    ax3 = plt.axes(projection='3d')

    w, h = Z.shape[:2]
    xx = np.arange(0,w,1)
    yy = np.arange(0,h,1)
    X, Y = np.meshgrid(xx, yy)
    ax3.plot_surface(X,Y,Z,cmap=cmap)
    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
    plt.show()
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# get image paths
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def get_image_paths(dataroot):
    """Return a sorted list of image paths under `dataroot` (None if dataroot is None)."""
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
|
||||
|
||||
|
||||
def _get_paths_from_images(path):
    """Recursively collect every image file under `path` (sorted walk, sorted filenames)."""
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = [os.path.join(dirpath, fname)
              for dirpath, _, fnames in sorted(os.walk(path))
              for fname in sorted(fnames)
              if is_image_file(fname)]
    assert images, '{:s} has no valid image file'.format(path)
    return images
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# split large images into small images
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
    """Split a large image into overlapping p_size x p_size patches.

    Images whose height or width is <= p_max are returned unsplit as a
    single-element list; otherwise patch origins stride by (p_size - p_overlap)
    and the last row/column of patches is anchored to the image border.
    """
    w, h = img.shape[:2]
    patches = []
    if w > p_max and h > p_max:
        # BUG FIX: np.int was removed in NumPy >= 1.24; the builtin int is equivalent.
        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
        w1.append(w - p_size)  # ensure the image border is covered
        h1.append(h - p_size)
        for i in w1:
            for j in h1:
                patches.append(img[i:i + p_size, j:j + p_size, :])
    else:
        patches.append(img)

    return patches
|
||||
|
||||
|
||||
def imssave(imgs, img_path):
    """
    imgs: list, N images of size WxHxC

    Save each as <name>_s{idx:04d}.png next to `img_path` (RGB -> BGR for cv2).
    """
    img_name, ext = os.path.splitext(os.path.basename(img_path))
    for idx, img in enumerate(imgs):
        if img.ndim == 3:
            img = img[:, :, [2, 1, 0]]  # RGB -> BGR for cv2.imwrite
        out_name = '{}_s{:04d}.png'.format(img_name, idx)
        cv2.imwrite(os.path.join(os.path.dirname(img_path), out_name), img)
|
||||
|
||||
|
||||
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
    """
    Split the large images from original_dataroot into small overlapped images
    with size (p_size)x(p_size), and save them into taget_dataroot; only the
    images with larger size than (p_max)x(p_max) will be split.

    Args:
        original_dataroot: directory tree containing the source images
        taget_dataroot: output directory (note: "target" is misspelled in this
            public parameter name; kept for backward compatibility)
        n_channels: channel count to load each image with
        p_size: size of small images
        p_overlap: patch size in training is a good choice
        p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
    """
    paths = get_image_paths(original_dataroot)
    for img_path in paths:
        # img_name, ext = os.path.splitext(os.path.basename(img_path))
        img = imread_uint(img_path, n_channels=n_channels)
        patches = patches_from_image(img, p_size, p_overlap, p_max)
        # output files are named after the source image with a _s{idx} suffix
        imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
        #if original_dataroot == taget_dataroot:
        #del img_path
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# makedir
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def mkdir(path):
    """Create `path` (including parents) if it does not exist; no-op otherwise."""
    # exist_ok avoids the check-then-create race of the original os.path.exists guard
    os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def mkdirs(paths):
    """mkdir() a single path, or every path in an iterable of paths."""
    if isinstance(paths, str):
        paths = [paths]
    for path in paths:
        mkdir(path)
|
||||
|
||||
|
||||
def mkdir_and_rename(path):
    """Create `path`; if it already exists, archive the old directory by renaming it first."""
    if os.path.exists(path):
        archived = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(archived))
        os.rename(path, archived)
    os.makedirs(path)
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# read image from path
|
||||
# opencv is fast, but read BGR numpy image
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get uint8 image of size HxWxn_channels (RGB)
|
||||
# --------------------------------------------
|
||||
def imread_uint(path, n_channels=3):
    """Read an image as uint8 of size HxWxn_channels (RGB).

    Args:
        path: image file path
        n_channels: 1 -> grayscale HxWx1; 3 -> RGB HxWx3 (grayscale inputs
            are replicated across channels)
    Raises:
        FileNotFoundError: if cv2 cannot read `path` (previously this surfaced
            as a cryptic downstream error on the None return).
    """
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        if img is None:
            raise FileNotFoundError('cannot read image: {}'.format(path))
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img is None:
            raise FileNotFoundError('cannot read image: {}'.format(path))
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# matlab's imwrite
|
||||
# --------------------------------------------
|
||||
def imsave(img, img_path):
    """Write `img` to `img_path` (matlab-style); 3-dim inputs are treated as RGB and converted to BGR."""
    out = np.squeeze(img)
    if out.ndim == 3:
        out = out[:, :, [2, 1, 0]]  # RGB -> BGR for cv2
    cv2.imwrite(img_path, out)
|
||||
|
||||
def imwrite(img, img_path):
    """Write `img` to `img_path`; 3-dim inputs are treated as RGB and converted to BGR.

    NOTE: duplicates imsave() above; kept as a separate name for compatibility.
    """
    out = np.squeeze(img)
    if out.ndim == 3:
        out = out[:, :, [2, 1, 0]]  # RGB -> BGR for cv2
    cv2.imwrite(img_path, out)
|
||||
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get single image of size HxWxn_channels (BGR)
|
||||
# --------------------------------------------
|
||||
def read_img(path):
    """Read an image via cv2 as float32 HWC BGR in [0, 1].

    Grayscale images get a trailing channel axis; a 4th (alpha) channel is dropped.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels (alpha); keep only the first three
    return img[:, :, :3] if img.shape[2] > 3 else img
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# image format conversion
|
||||
# --------------------------------------------
|
||||
# numpy(single) <---> numpy(uint)
|
||||
# numpy(single) <---> tensor
|
||||
# numpy(uint) <---> tensor
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) [0, 1] <---> numpy(uint)
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
def uint2single(img):
    """uint8 [0, 255] -> float32 [0, 1]."""
    scaled = img / 255.
    return np.float32(scaled)
|
||||
|
||||
|
||||
def single2uint(img):
    """float [0, 1] -> uint8 [0, 255], clipping out-of-range values."""
    clipped = img.clip(0, 1)
    return np.uint8((clipped * 255.).round())
|
||||
|
||||
|
||||
def uint162single(img):
    """uint16 [0, 65535] -> float32 [0, 1]."""
    scaled = img / 65535.
    return np.float32(scaled)
|
||||
|
||||
|
||||
def single2uint16(img):
    """Map a float image in [0, 1] (clipped) to uint16 [0, 65535]."""
    scaled = (img.clip(0, 1) * 65535.).round()
    return np.uint16(scaled)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(unit) (HxWxC or HxW) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert uint to 4-dimensional torch tensor
|
||||
def uint2tensor4(img):
    """uint8 HxW(xC) numpy image -> 1xCxHxW float tensor scaled to [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float().div(255.).unsqueeze(0)
|
||||
|
||||
|
||||
# convert uint to 3-dimensional torch tensor
|
||||
def uint2tensor3(img):
    """uint8 HxW(xC) numpy image -> CxHxW float tensor scaled to [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float().div(255.)
|
||||
|
||||
|
||||
# convert 2/3/4-dimensional torch tensor to uint
|
||||
def tensor2uint(img):
    """Clamp a 2/3/4-D tensor to [0, 1] and convert to a uint8 numpy image (HWC or HW)."""
    arr = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
    return np.uint8((arr * 255.0).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) (HxWxC) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert single (HxWxC) to 3-dimensional torch tensor
|
||||
def single2tensor3(img):
    """float HxWxC numpy image -> CxHxW float tensor."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float()
|
||||
|
||||
|
||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
||||
def single2tensor4(img):
    """float HxWxC numpy image -> 1xCxHxW float tensor."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float().unsqueeze(0)
|
||||
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single(img):
    """Torch tensor -> squeezed float32 numpy array (HWC when 3-D after squeeze)."""
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
    return arr
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single3(img):
    """Like tensor2single, but guarantees a trailing channel axis (HxWx1 for 2-D)."""
    arr = img.data.squeeze().float().cpu().numpy()
    if arr.ndim == 3:
        arr = np.transpose(arr, (1, 2, 0))
    elif arr.ndim == 2:
        arr = np.expand_dims(arr, axis=2)
    return arr
|
||||
|
||||
|
||||
def single2tensor5(img):
    """4-D numpy array -> 5-D float tensor: axes (2,0,1,3) plus a leading batch dim."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1, 3).float().unsqueeze(0)
|
||||
|
||||
|
||||
def single32tensor5(img):
    """3-D numpy array -> 5-D float tensor with two leading singleton dims."""
    t = torch.from_numpy(np.ascontiguousarray(img)).float()
    return t.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
|
||||
def single42tensor4(img):
    """4-D numpy array -> 4-D float tensor with axes permuted to (2,0,1,3)."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1, 3).float()
|
||||
|
||||
|
||||
# from skimage.io import imread, imsave
|
||||
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert a torch tensor (RGB channel order, any range) to a numpy image (BGR).

    Accepts 4D (B,C,H,W), 3D (C,H,W) or 2D (H,W) input; output is HWC (or HW),
    scaled to [0, 255] and rounded when out_type is np.uint8.
    """
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    low, high = min_max
    tensor = (tensor - low) / (high - low)  # normalise to [0, 1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # tile the batch into one grid image, then RGB CHW -> BGR HWC
        grid = make_grid(tensor, nrow=int(math.sqrt(len(tensor))), normalize=False).numpy()
        img_np = np.transpose(grid[[2, 1, 0], :, :], (1, 2, 0))
    elif n_dim == 3:
        img_np = np.transpose(tensor.numpy()[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # numpy's astype truncates rather than rounds, so round explicitly first.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# Augmentation, flipe and/or rotate
|
||||
# --------------------------------------------
|
||||
# The following two are enough.
|
||||
# (1) augmet_img: numpy image of WxHxC or WxH
|
||||
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def augment_img(img, mode=0):
    """Flip/rotate a numpy image by one of 8 dihedral modes (Kai Zhang, github.com/cszn)."""
    ops = {
        0: lambda a: a,
        1: lambda a: np.flipud(np.rot90(a)),
        2: np.flipud,
        3: lambda a: np.rot90(a, k=3),
        4: lambda a: np.flipud(np.rot90(a, k=2)),
        5: np.rot90,
        6: lambda a: np.rot90(a, k=2),
        7: lambda a: np.flipud(np.rot90(a, k=3)),
    }
    op = ops.get(mode)
    # modes outside 0-7 fall through to None, matching the original if/elif chain
    return op(img) if op is not None else None
|
||||
|
||||
|
||||
def augment_img_tensor4(img, mode=0):
    """Flip/rotate a BxCxHxW tensor by one of 8 dihedral modes (Kai Zhang, github.com/cszn)."""
    if mode == 0:
        return img
    # mode -> (number of 90-degree rotations over dims [2,3], flip dim 2 afterwards)
    spec = {1: (1, True), 2: (0, True), 3: (3, False), 4: (2, True),
            5: (1, False), 6: (2, False), 7: (3, True)}.get(mode)
    if spec is None:
        return None  # matches the original's implicit None for unknown modes
    k, do_flip = spec
    out = img if k == 0 else img.rot90(k, [2, 3])
    return out.flip([2]) if do_flip else out
|
||||
|
||||
|
||||
def augment_img_tensor(img, mode=0):
    """Apply augment_img's flip/rotate mode to a 3-D/4-D tensor via a numpy round-trip
    (Kai Zhang, github.com/cszn)."""
    orig_size = img.size()
    arr = img.data.cpu().numpy()
    # move channel/batch axes last so augment_img operates on the spatial dims
    if len(orig_size) == 3:
        arr = np.transpose(arr, (1, 2, 0))
    elif len(orig_size) == 4:
        arr = np.transpose(arr, (2, 3, 1, 0))
    arr = augment_img(arr, mode=mode)
    out = torch.from_numpy(np.ascontiguousarray(arr))
    # restore the original axis order
    if len(orig_size) == 3:
        out = out.permute(2, 0, 1)
    elif len(orig_size) == 4:
        out = out.permute(3, 2, 0, 1)
    return out.type_as(img)
|
||||
|
||||
|
||||
def augment_img_np3(img, mode=0):
    """Flip/transpose an HWC numpy image by one of 8 dihedral modes."""
    if mode == 0:
        return img
    if mode not in (1, 2, 3, 4, 5, 6, 7):
        return None  # matches the original's implicit None for unknown modes
    # Decompose each mode into: reverse width, reverse height, swap spatial axes.
    if mode in (4, 5, 6, 7):
        img = img[:, ::-1, :]
    if mode in (2, 3, 6, 7):
        img = img[::-1, :, :]
    if mode in (1, 3, 5, 7):
        img = img.transpose(1, 0, 2)
    return img
|
||||
|
||||
|
||||
def augment_imgs(img_list, hflip=True, rot=True):
    """Randomly flip and/or rotate a list of HWC images (one draw shared by all)."""
    # horizontal flip OR rotate
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_rot90 = rot and random.random() < 0.5

    def _apply(im):
        if do_hflip:
            im = im[:, ::-1, :]
        if do_vflip:
            im = im[::-1, :, :]
        if do_rot90:
            im = im.transpose(1, 0, 2)
        return im

    return [_apply(im) for im in img_list]
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# modcrop and shave
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def modcrop(img_in, scale):
    """Crop an HW or HWC image so both spatial sizes are multiples of `scale`."""
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[0], img.shape[1]
    # Ellipsis keeps any trailing channel axis while slicing only H and W.
    return img[:H - H % scale, :W - W % scale, ...]
|
||||
|
||||
|
||||
def shave(img_in, border=0):
    """Trim `border` pixels from every spatial edge of an HW or HWC image."""
    img = np.copy(img_in)
    h, w = img.shape[:2]
    return img[border:h - border, border:w - border]
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# image processing process on numpy image
|
||||
# channel_convert(in_c, tar_type, img_list):
|
||||
# rgb2ycbcr(img, only_y=True):
|
||||
# bgr2ycbcr(img, only_y=True):
|
||||
# ycbcr2rgb(img):
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the same dtype as the input.
    '''
    in_img_type = img.dtype
    # BUGFIX: astype() returns a new array -- the original discarded the result,
    # so float inputs were then scaled *in place*, mutating the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (ITU-R BT.601 coefficients, matlab convention)
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the same dtype as the input.
    '''
    in_img_type = img.dtype
    # BUGFIX: astype() returns a new array -- the original discarded the result,
    # so float inputs were then scaled *in place*, mutating the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (inverse BT.601 matrix, matlab convention)
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the same dtype as the input.
    '''
    in_img_type = img.dtype
    # BUGFIX: astype() returns a new array -- the original discarded the result,
    # so float inputs were then scaled *in place*, mutating the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (same BT.601 coefficients as rgb2ycbcr, with B and R swapped)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, single-channel gray and Y; pass through otherwise."""
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        grays = (cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list)
        return [np.expand_dims(g, axis=2) for g in grays]
    if in_c == 3 and tar_type == 'y':  # BGR to y
        ys = (bgr2ycbcr(img, only_y=True) for img in img_list)
        return [np.expand_dims(y, axis=2) for y in ys]
    if in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    return img_list
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# metric, PSNR and SSIM
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# PSNR
|
||||
# --------------------------------------------
|
||||
def calculate_psnr(img1, img2, border=0):
    """PSNR (dB) between two images in [0, 255], optionally ignoring a border."""
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    a = img1[border:h - border, border:w - border].astype(np.float64)
    b = img2[border:h - border, border:w - border].astype(np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# SSIM
|
||||
# --------------------------------------------
|
||||
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]; multi-channel images are averaged over channels.
    '''
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim != 3:
        raise ValueError('Wrong input image dimensions.')
    if img1.shape[2] == 3:
        scores = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]
        return np.array(scores).mean()
    if img1.shape[2] == 1:
        return ssim(np.squeeze(img1), np.squeeze(img2))
    # NOTE(review): mirrors the original fall-through, which returns None for
    # 3-D inputs whose channel count is neither 1 nor 3.
    return None
|
||||
|
||||
|
||||
def ssim(img1, img2):
    """Single-channel SSIM with an 11x11 Gaussian window, matching MATLAB's output."""
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    g = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(g, g.transpose())

    # keep only the 'valid' convolution region (trim 5 px on each side)
    mu1 = cv2.filter2D(a, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(a * a, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(b * b, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# matlab 'imresize' function, now only support 'bicubic'
|
||||
def cubic(x):
    """Bicubic interpolation kernel (a = -0.5), as used by MATLAB's imresize."""
    absx = torch.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    # piece for |x| <= 1
    inner = (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx))
    # piece for 1 < |x| <= 2; zero elsewhere
    outer = (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * (absx <= 2)).type_as(absx))
    return inner + outer
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Compute per-output-pixel interpolation weights and source indices for a
    MATLAB-style 1-D bicubic resize.

    Returns (weights, indices, sym_len_s, sym_len_e) where `indices` already
    accounts for the symmetric padding of `sym_len_s` pixels at the start and
    `sym_len_e` at the end.  `kernel` is accepted for API symmetry but only the
    cubic kernel is implemented (see the `cubic` helper).
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel (pre-scaled when antialiasing a downscale)
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # symmetric padding lengths needed at the start/end of the input signal
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # shift indices into the padded coordinate system (0-based)
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for tensor image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize for a torch tensor.

    NOTE(review): when given a 2-D input, `unsqueeze_` mutates the caller's
    tensor in place (it is squeezed back before returning).
    """
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad the H axis by mirroring the top/bottom rows
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # resize along H: each output row is a weighted combination of input rows
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: pad the W axis the same way
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # resize along W
    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for numpy image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize_np(img, scale, antialiasing=True):
    """MATLAB-compatible bicubic resize for a numpy image (HWC or HW, [0, 1])."""
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying: pad the H axis by mirroring the top/bottom rows
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # resize along H: each output row is a weighted combination of input rows
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: pad the W axis the same way
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    # resize along W
    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test; the commented lines sketch a bicubic-resize check.
    print('---')
    # img = imread_uint('test.bmp', 3)
    # img = uint2single(img)
    # img_bicubic = imresize_np(img, 1/4)
|
||||
1
ldm/modules/losses/__init__.py
Executable file
1
ldm/modules/losses/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
|
||||
111
ldm/modules/losses/contperceptual.py
Executable file
111
ldm/modules/losses/contperceptual.py
Executable file
@@ -0,0 +1,111 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
|
||||
|
||||
|
||||
class LPIPSWithDiscriminator(nn.Module):
    """Autoencoder training loss: pixel L1 + LPIPS perceptual term + KL
    regularisation, plus an adversarial term from a patch discriminator.

    Designed for a two-optimizer training loop: ``forward`` with
    ``optimizer_idx == 0`` returns the generator/autoencoder loss, and with
    ``optimizer_idx == 1`` the discriminator loss.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        """disc_start: global step at which the adversarial term switches on."""
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance (learned scalar used to scale the NLL term)
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the adversarial term against the reconstruction term by the
        ratio of their gradient norms at the decoder's last layer."""
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): falls back to self.last_layer[0], which is not set in
            # __init__ -- presumably assigned externally by the training wrapper.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Return (loss, log_dict) for the generator (optimizer_idx == 0) or
        the discriminator (optimizer_idx == 1) pass."""
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        # NLL under a learned Gaussian with per-model log variance
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights*nll_loss
        # sums are divided by batch size only (not by element count)
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # autograd.grad raises when no graph is available (eval mode)
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            # adversarial term is zeroed until discriminator_iter_start
            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
|
||||
|
||||
167
ldm/modules/losses/vqperceptual.py
Executable file
167
ldm/modules/losses/vqperceptual.py
Executable file
@@ -0,0 +1,167 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
from taming.modules.losses.lpips import LPIPS
|
||||
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
|
||||
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss with a per-example weight on each batch element."""
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    # mean hinge penalty per example, over the spatial/channel dims
    per_real = F.relu(1. - logits_real).mean(dim=[1, 2, 3])
    per_fake = F.relu(1. + logits_fake).mean(dim=[1, 2, 3])
    w_sum = weights.sum()
    real_term = (weights * per_real).sum() / w_sum
    fake_term = (weights * per_fake).sum() / w_sum
    return 0.5 * (real_term + fake_term)
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return `weight` once `global_step` has reached `threshold`, else `value`."""
    return weight if global_step >= threshold else value
|
||||
|
||||
|
||||
def measure_perplexity(predicted_indices, n_embed):
    """Codebook-usage statistics for a batch of predicted code indices.

    Returns (perplexity, cluster_use); perplexity equals n_embed when every
    code is used equally often.
    src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    """
    one_hot = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = one_hot.mean(0)
    entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum()
    perplexity = entropy.exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use
|
||||
|
||||
def l1(x, y):
    """Element-wise absolute error."""
    return (x - y).abs()
|
||||
|
||||
|
||||
def l2(x, y):
    """Element-wise squared error."""
    return (x - y) ** 2
|
||||
|
||||
|
||||
class VQLPIPSWithDiscriminator(nn.Module):
    """VQGAN training loss.

    Combines a pixel reconstruction loss (L1 or L2), an LPIPS perceptual
    term, a codebook commitment term, and a patch-discriminator GAN term
    whose weight is balanced adaptively from last-layer gradients.

    Fix vs. original: `forward` previously guarded `codebook_loss` with an
    `exists(...)` helper that is neither defined nor imported in this module
    (NameError at runtime); replaced with an explicit `is None` check.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        # NOTE(review): the assert above admits "clips"/"dists" but only
        # "lpips" is actually implemented here.
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance NLL vs. GAN gradients at the generator's last layer.

        d_weight = ||grad(nll)|| / ||grad(g)||, clamped, then scaled by
        the configured discriminator weight.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): self.last_layer is never assigned in this class;
            # callers are expected to set it externally — confirm.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute the generator (optimizer_idx==0) or discriminator
        (optimizer_idx==1) loss plus a logging dict."""
        if codebook_loss is None:  # fix: original used undefined helper `exists`
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # gradients unavailable (e.g. eval mode) — disable the GAN term
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
|
||||
641
ldm/modules/x_transformer.py
Executable file
641
ldm/modules/x_transformer.py
Executable file
@@ -0,0 +1,641 @@
|
||||
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
from collections import namedtuple
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# constants
|
||||
|
||||
# Default per-head dimension for attention layers.
DEFAULT_DIM_HEAD = 64

# Per-attention-call tensors kept for inspection / residual attention.
Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

# Per-forward-pass collection of hidden states and attention intermediates.
# Fix: typename was 'Intermediates', duplicating the class above; two
# distinct namedtuple classes shared the same repr name.
LayerIntermediates = namedtuple('LayerIntermediates', [
    'hiddens',
    'attn_intermediates'
])
|
||||
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, returned as (1, seq_len, dim)."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        # small init, matching the token-embedding convention
        nn.init.normal_(self.emb.weight, std=0.02)

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device)
        return self.emb(positions).unsqueeze(0)
|
||||
|
||||
|
||||
class FixedPositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding (non-learned), shape (1, seq_len, dim)."""

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        # offset lets cached-memory setups start queries at a later position
        positions = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
        angles = torch.einsum('i,j->ij', positions, self.inv_freq)
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return emb.unsqueeze(0)
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
    """Return True when *val* is not None."""
    if val is None:
        return False
    return True
|
||||
|
||||
|
||||
def default(val, d):
    """Return *val* if it is not None, else the default *d* (called if it is a function)."""
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
||||
|
||||
|
||||
def always(val):
    """Factory for a function that ignores its arguments and returns *val*."""
    def _const(*_args, **_kwargs):
        return val
    return _const
|
||||
|
||||
|
||||
def not_equals(val):
    """Factory for a predicate testing inequality with *val*."""
    return lambda x: x != val
|
||||
|
||||
|
||||
def equals(val):
    """Factory for a predicate testing equality with *val*."""
    return lambda x: x == val
|
||||
|
||||
|
||||
def max_neg_value(tensor):
    """Most negative finite value representable in the tensor's dtype."""
    info = torch.finfo(tensor.dtype)
    return -info.max
|
||||
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
    """Remove *keys* from dict *d* (mutating it) and return them as a new dict."""
    return {key: d.pop(key) for key in keys}
|
||||
|
||||
|
||||
def group_dict_by_key(cond, d):
    """Split dict *d* into (matching, non-matching) by predicate *cond* on keys."""
    matched, unmatched = {}, {}
    for key, value in d.items():
        target = matched if cond(key) else unmatched
        target[key] = value
    return matched, unmatched
|
||||
|
||||
|
||||
def string_begins_with(prefix, text):
    """Return True if *text* starts with *prefix*.

    Positional wrapper around str.startswith for use with functools.partial.
    Fix: the second parameter was named `str`, shadowing the builtin; it is
    only ever called positionally in this module, so the rename is safe.
    """
    return text.startswith(prefix)
|
||||
|
||||
|
||||
def group_by_key_prefix(prefix, d):
    """Split dict *d* into (keys starting with *prefix*, the rest)."""
    return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
    """Extract keys starting with *prefix*, strip the prefix, and return
    (stripped_dict, remaining_dict)."""
    with_prefix, rest = group_dict_by_key(partial(string_begins_with, prefix), d)
    trimmed = {key[len(prefix):]: value for key, value in with_prefix.items()}
    return trimmed, rest
|
||||
|
||||
|
||||
# classes
|
||||
class Scale(nn.Module):
    """Wrap *fn*, scaling the first element of its returned tuple by *value*."""

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out, *extras = self.fn(x, **kwargs)
        return (out * self.value, *extras)
|
||||
|
||||
|
||||
class Rezero(nn.Module):
    """ReZero gate: scale *fn*'s first output by a learned scalar initialised to 0."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        out, *extras = self.fn(x, **kwargs)
        return (out * self.g, *extras)
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
    """Normalise by the (scaled) L2 norm of the last dim, with one learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        denom = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / denom.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """RMS normalisation over the last dim with a per-channel learned gain."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        denom = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / denom.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class Residual(nn.Module):
    """Plain additive skip connection."""

    def forward(self, x, residual):
        return residual + x
|
||||
|
||||
|
||||
class GRUGating(nn.Module):
    """Gate the residual connection with a GRU cell applied token-wise."""

    def __init__(self, dim):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, residual):
        # flatten (batch, seq) into one axis so the cell sees one vector per token
        dim = x.shape[-1]
        gated = self.gru(x.reshape(-1, dim), residual.reshape(-1, dim))
        return gated.reshape_as(x)
|
||||
|
||||
|
||||
# feedforward
|
||||
|
||||
class GEGLU(nn.Module):
    """GELU-gated linear unit: project to 2*dim_out, gate one half with GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Transformer MLP block: expand by *mult*, optional GEGLU gate, dropout, project out."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, hidden),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim)
        )

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
# attention.
|
||||
class Attention(nn.Module):
    """Multi-head attention with optional causality, cross-attention context,
    talking heads, memory key/values, and sparse top-k masking.

    Returns (output, Intermediates) where Intermediates holds the pre- and
    post-softmax attention maps.
    """

    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            mask=None,
            talking_heads=False,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False
    ):
        super().__init__()
        if use_entmax15:
            raise NotImplementedError("Check out entmax activation instead of softmax activation!")
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.causal = causal
        self.mask = mask

        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # talking heads: learned mixing across heads before/after softmax
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # explicit top-k sparse attention
        self.sparse_topk = sparse_topk

        # plain softmax (entmax not supported here)
        self.attn_fn = F.softmax

        # learned memory key/values prepended to every sequence
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention-on-attention output head
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            rel_pos=None,
            sinusoidal_emb=None,
            prev_attn=None,
            mem=None
    ):
        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
        kv_source = x if context is None else context

        q_input = x
        k_input = kv_source
        v_input = kv_source

        if mem is not None:
            # prepend cached memory to keys/values
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if sinusoidal_emb is not None:
            # shortformer-style: queries start at an offset past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        input_mask = None
        if mask is not None or context_mask is not None:
            q_mask = mask if mask is not None else torch.ones((b, n), device=device).bool()
            k_mask = q_mask if context is None else context_mask
            if k_mask is None:
                k_mask = torch.ones((b, k.shape[-2]), device=device).bool()
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if input_mask is not None:
                # memory positions are always attendable
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        mask_value = max_neg_value(dots)

        if prev_attn is not None:
            # residual attention: add previous layer's pre-softmax logits
            dots = dots + prev_attn

        pre_softmax_attn = dots

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if rel_pos is not None:
            dots = rel_pos(dots)

        if input_mask is not None:
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            causal_mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            causal_mask = F.pad(causal_mask, (j - i, 0), value=False)
            dots.masked_fill_(causal_mask, mask_value)
            del causal_mask

        if self.sparse_topk is not None and self.sparse_topk < dots.shape[-1]:
            # keep only the top-k logits per query, mask out the rest
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            sparse_mask = dots < vk
            dots.masked_fill_(sparse_mask, mask_value)
            del sparse_mask

        attn = self.attn_fn(dots, dim=-1)
        post_softmax_attn = attn

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
|
||||
|
||||
|
||||
class AttentionLayers(nn.Module):
    """Configurable stack of self-attention ('a'), cross-attention ('c') and
    feedforward ('f') blocks, each wrapped in (norm, layer, residual).

    The layer schedule comes from, in priority order: `custom_layers`,
    `par_ratio` (PAR paper), `sandwich_coef`, or plain repetition of the
    default block.
    """

    def __init__(
            self,
            dim,
            depth,
            heads=8,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rmsnorm=False,
            use_rezero=False,
            rel_pos_num_buckets=32,
            rel_pos_max_distance=128,
            position_infused_attn=False,
            custom_layers=None,
            sandwich_coef=None,
            par_ratio=None,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            **kwargs
    ):
        super().__init__()
        # route 'ff_'/'attn_' prefixed kwargs to the sub-modules
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        self.has_pos_emb = position_infused_attn
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
        self.rotary_pos_emb = always(None)

        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
        self.rel_pos = None

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn

        # pick the norm flavour; rezero replaces it entirely with identity
        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        if custom_layers is not None:
            layer_types = custom_layers
        elif par_ratio is not None:
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            # 2 / 3 attention layer cutoff suggested by PAR paper
            depth_cut = par_depth * 2 // 3
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif sandwich_coef is not None:
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                # macaron blocks halve each feedforward's contribution
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if isinstance(layer, Attention) and branch_fn is not None:
                layer = branch_fn(layer)

            residual_fn = GRUGating(dim) if gate_residual else Residual()

            self.layers.append(nn.ModuleList([
                norm_fn(),
                layer,
                residual_fn
            ]))

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            mems=None,
            return_hiddens=False
    ):
        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        # one memory slot per attention layer; copy so pops don't mutate caller state
        mems = mems.copy() if mems is not None else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == 'a':
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
                                   prev_attn=prev_attn, mem=layer_mem)
            elif layer_type == 'c':
                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
            elif layer_type == 'f':
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ('a', 'c'):
                intermediates.append(inter)

            # propagate pre-softmax logits for residual attention
            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates
            )
            return x, intermediates

        return x
|
||||
|
||||
|
||||
class Encoder(AttentionLayers):
    """Non-causal AttentionLayers stack (bidirectional attention)."""

    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)
|
||||
|
||||
|
||||
|
||||
class TransformerWrapper(nn.Module):
    """Token-level transformer: embeddings + positional encoding around an
    AttentionLayers stack, with optional memory tokens (per the Memory
    Transformers paper) and weight-tied output logits.
    """

    def __init__(
            self,
            *,
            num_tokens,
            max_seq_len,
            attn_layers,
            emb_dim=None,
            max_mem_len=0.,
            emb_dropout=0.,
            num_memory_tokens=None,
            tie_embedding=False,
            use_pos_emb=True
    ):
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        dim = attn_layers.dim
        emb_dim = dim if emb_dim is None else emb_dim

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.num_tokens = num_tokens

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # skip absolute positions when the stack injects its own
        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
                use_pos_emb and not attn_layers.has_pos_emb) else always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        self.init_()

        # tie logits to the token embedding matrix if requested
        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = 0 if num_memory_tokens is None else num_memory_tokens
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

            # let funnel encoder know number of memory tokens, if specified
            if hasattr(attn_layers, 'num_memory_tokens'):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        nn.init.normal_(self.token_emb.weight, std=0.02)

    def forward(
            self,
            x,
            return_embeddings=False,
            mask=None,
            return_mems=False,
            return_attn=False,
            mems=None,
            **kwargs
    ):
        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
        x = self.token_emb(x)
        x += self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if num_mem > 0:
            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
            x = torch.cat((mem, x), dim=1)

            # auto-handle masking after appending memory tokens
            if mask is not None:
                mask = F.pad(mask, (num_mem, 0), value=True)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        mem, x = x[:, :num_mem], x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if mems is not None else hiddens
            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
            return out, new_mems

        if return_attn:
            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
            return out, attn_maps

        return out
|
||||
|
||||
121
ldm/thirdp/psp/helpers.py
Executable file
121
ldm/thirdp/psp/helpers.py
Executable file
@@ -0,0 +1,121 @@
|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
||||
|
||||
"""
|
||||
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Flatten(Module):
    """Flatten all dims after the batch dim."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)
|
||||
|
||||
|
||||
def l2_norm(input, axis=1):
    """L2-normalise *input* along *axis*."""
    magnitude = torch.norm(input, 2, axis, True)
    return input / magnitude
|
||||
|
||||
|
||||
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """A named tuple describing one ResNet bottleneck unit
    (input channels, output depth, stride)."""
|
||||
|
||||
|
||||
def get_block(in_channel, depth, num_units, stride=2):
    """One ResNet stage: a strided entry unit plus (num_units - 1) stride-1 units."""
    units = [Bottleneck(in_channel, depth, stride)]
    units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return units
|
||||
|
||||
|
||||
def get_blocks(num_layers):
    """Stage configuration for ResNet-50/100/152 style IR backbones."""
    units_per_stage = {
        50: (3, 4, 14, 3),
        100: (3, 13, 30, 3),
        152: (3, 8, 36, 3),
    }
    if num_layers not in units_per_stage:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    u1, u2, u3, u4 = units_per_stage[num_layers]
    return [
        get_block(in_channel=64, depth=64, num_units=u1),
        get_block(in_channel=64, depth=128, num_units=u2),
        get_block(in_channel=128, depth=256, num_units=u3),
        get_block(in_channel=256, depth=512, num_units=u4),
    ]
|
||||
|
||||
|
||||
class SEModule(Module):
    """Squeeze-and-excitation: global-average-pool, bottleneck MLP, sigmoid
    channel gate applied multiplicatively to the input."""

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)
        gate = self.fc1(gate)
        gate = self.relu(gate)
        gate = self.fc2(gate)
        gate = self.sigmoid(gate)
        return x * gate
|
||||
|
||||
|
||||
class bottleneck_IR(Module):
    """ArcFace IR residual unit: BN -> 3x3 conv -> PReLU -> strided 3x3 conv -> BN,
    added to an identity (max-pool) or 1x1-conv shortcut."""

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            # channel counts match: shortcut only needs the stride
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
        )

    def forward(self, x):
        identity = self.shortcut_layer(x)
        transformed = self.res_layer(x)
        return transformed + identity
|
||||
|
||||
|
||||
class bottleneck_IR_SE(Module):
    """IR residual unit with a squeeze-and-excitation gate on the residual branch."""

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            # channel counts match: shortcut only needs the stride
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        identity = self.shortcut_layer(x)
        transformed = self.res_layer(x)
        return transformed + identity
|
||||
23
ldm/thirdp/psp/id_loss.py
Executable file
23
ldm/thirdp/psp/id_loss.py
Executable file
@@ -0,0 +1,23 @@
|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
import torch
|
||||
from torch import nn
|
||||
from ldm.thirdp.psp.model_irse import Backbone
|
||||
|
||||
|
||||
class IDFeatures(nn.Module):
    """Frozen ArcFace (IR-SE-50) identity feature extractor.

    Loads pretrained weights from *model_path*; forward returns the
    L2-normalised identity embedding of a (batch of) face image(s).
    """

    def __init__(self, model_path):
        super(IDFeatures, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def forward(self, x, crop=False):
        # Not sure of the image range here
        if crop:
            # resize then take the fixed face crop used by pixel2style2pixel
            x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
            x = x[:, :, 35:223, 32:220]
        x = self.face_pool(x)
        return self.facenet(x)
|
||||
86
ldm/thirdp/psp/model_irse.py
Executable file
86
ldm/thirdp/psp/model_irse.py
Executable file
@@ -0,0 +1,86 @@
|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
|
||||
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
||||
from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
||||
|
||||
"""
|
||||
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Backbone(Module):
    """IR / IR-SE face-recognition backbone producing L2-normalised 512-d embeddings.

    Modified Backbone implementation from TreB1eN/InsightFace_Pytorch.
    """

    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        # Pick the residual unit type: plain IR or squeeze-and-excitation IR.
        unit_module = {'ir': bottleneck_IR, 'ir_se': bottleneck_IR_SE}[mode]
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False),
            BatchNorm2d(64),
            PReLU(64),
        )
        # Spatial side after the body: 7 for 112-px inputs, 14 for 224-px inputs.
        feat_side = 7 if input_size == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(512),
            Dropout(drop_ratio),
            Flatten(),
            Linear(512 * feat_side * feat_side, 512),
            BatchNorm1d(512, affine=affine),
        )
        # Flatten the per-stage block specs into one sequential body.
        self.body = Sequential(*[
            unit_module(spec.in_channel, spec.depth, spec.stride)
            for stage in get_blocks(num_layers)
            for spec in stage
        ])

    def forward(self, x):
        return l2_norm(self.output_layer(self.body(self.input_layer(x))))
|
||||
|
||||
|
||||
def IR_50(input_size):
    """Build a 50-layer IR backbone (no SE blocks, non-affine output BN)."""
    return Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
||||
|
||||
|
||||
def IR_101(input_size):
    """Build a 100-layer IR backbone (no SE blocks, non-affine output BN)."""
    return Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
||||
|
||||
|
||||
def IR_152(input_size):
    """Build a 152-layer IR backbone (no SE blocks, non-affine output BN)."""
    return Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
||||
|
||||
|
||||
def IR_SE_50(input_size):
    """Build a 50-layer IR-SE backbone (squeeze-and-excitation variant)."""
    return Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
|
||||
|
||||
def IR_SE_101(input_size):
    """Build a 100-layer IR-SE backbone (squeeze-and-excitation variant)."""
    return Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
|
||||
|
||||
def IR_SE_152(input_size):
    """Build a 152-layer IR-SE backbone (squeeze-and-excitation variant)."""
    return Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
227
ldm/util.py
Executable file
227
ldm/util.py
Executable file
@@ -0,0 +1,227 @@
|
||||
import importlib
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
from torch import optim
|
||||
import numpy as np
|
||||
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import torch
|
||||
import time
|
||||
import cv2
|
||||
|
||||
import PIL
|
||||
|
||||
def pil_rectangle_crop(im):
    """Center-crop an image to a square of side min(width, height).

    Args:
        im: a PIL.Image (or any object exposing ``size`` and ``crop``).

    Returns:
        The center-cropped square image.
    """
    width, height = im.size  # Get dimensions

    if width <= height:
        # Portrait (or square): keep full width, trim top/bottom equally.
        left = 0
        right = width
        top = (height - width) / 2
        bottom = (height + width) / 2
    else:
        # Landscape: keep full height, trim left/right equally.
        top = 0
        bottom = height
        left = (width - height) / 2
        # BUG FIX: the original assigned `bottom` a second time here and never
        # set `right`, raising NameError for any landscape input.
        right = (width + height) / 2

    # Crop the center of the image
    im = im.crop((left, top, right, bottom))
    return im
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
    """Render each caption in ``xc`` as black text on a white canvas of size ``wh``.

    Args:
        wh: (width, height) of each output image.
        xc: list of caption strings, one per batch element.
        size: font size in points.

    Returns:
        Tensor of shape (len(xc), 3, H, W) with values scaled to [-1, 1].
    """
    # Characters per wrapped line scale with canvas width (40 chars at 256 px).
    chars_per_line = int(40 * (wh[0] / 256))
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    rendered = []
    for caption in xc:
        canvas = Image.new("RGB", wh, color="white")
        drawer = ImageDraw.Draw(canvas)
        wrapped = "\n".join(
            caption[pos:pos + chars_per_line]
            for pos in range(0, len(caption), chars_per_line)
        )
        try:
            drawer.text((0, 0), wrapped, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")
        # HWC uint8 -> CHW float in [-1, 1].
        rendered.append(np.array(canvas).transpose(2, 0, 1) / 127.5 - 1.0)
    return torch.tensor(np.stack(rendered))
|
||||
|
||||
|
||||
def ismap(x):
    """Return True for 4-D tensors whose channel dim exceeds 3 (feature maps)."""
    return isinstance(x, torch.Tensor) and x.ndim == 4 and x.shape[1] > 3
|
||||
|
||||
|
||||
def isimage(x):
    """Return True for 4-D tensors with 1 or 3 channels (image batches)."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 4 and x.shape[1] in (1, 3)
|
||||
|
||||
|
||||
def exists(x):
    """Return True when ``x`` is anything other than None."""
    return not (x is None)
|
||||
|
||||
|
||||
def default(val, d):
    """Return ``val`` unless it is None, else the fallback ``d``.

    When ``d`` is a plain function, it is called lazily to produce the fallback.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
||||
|
||||
|
||||
def mean_flat(tensor):
    """Take the mean over all non-batch dimensions.

    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
    """Return the total number of elements across all parameters of ``model``.

    When ``verbose`` is set, also print the count in millions.
    """
    total_params = 0
    for p in model.parameters():
        total_params += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
    """Instantiate the object a config dict describes: ``target`` class + ``params``.

    The sentinel strings '__is_first_stage__' and "__is_unconditional__" yield
    None instead of an object; any other config without a ``target`` key raises.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return cls(**kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like 'pkg.mod.Name' to the named module attribute.

    With ``reload`` set, the containing module is re-imported first so code
    changes on disk are picked up.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    mod = importlib.import_module(module_path, package=None)
    return getattr(mod, attr_name)
|
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
    """AdamW optimizer that additionally maintains an exponential moving
    average (EMA) of every parameter in ``state['param_exp_avg']``."""
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters.

        Args mirror torch.optim.AdamW plus:
            ema_decay: cap on the per-step EMA decay rate, in [0, 1].
            ema_power: exponent of the warmup schedule 1 - step**(-ema_power).
            param_names: stored in the param-group defaults; not read here.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        # Keep unpickled optimizers loadable even if saved before 'amsgrad' existed.
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Gather per-parameter tensors so the functional AdamW kernel can
            # process the whole group in one call.
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []  # NOTE(review): collected nowhere below — appears unused.
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    # (kept in float32 regardless of the parameter dtype).
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            # NOTE(review): optim._functional.adamw is a private PyTorch API and
            # its signature has changed across releases — pin the torch version
            # or migrate to torch.optim.adamw if this breaks.
            optim._functional.adamw(params_with_grad,
                    grads,
                    exp_avgs,
                    exp_avg_sqs,
                    max_exp_avg_sqs,
                    state_steps,
                    amsgrad=amsgrad,
                    beta1=beta1,
                    beta2=beta2,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    eps=group['eps'],
                    maximize=False)

            # EMA decay warms up as 1 - step**(-ema_power), capped at ema_decay.
            # NOTE(review): `state` here is the loop variable left over from the
            # last parameter iterated, so every parameter's EMA uses that one
            # step count — fine when all params step together, but worth noting.
            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
|
||||
630
main.py
Normal file
630
main.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import torch
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import pandas as pd
|
||||
|
||||
from nerf.provider import NeRFDataset, generate_grid_points
|
||||
from nerf.utils import *
|
||||
|
||||
import yaml
|
||||
from easydict import EasyDict as edict
|
||||
import dnnultis
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The first arg parser parses out only the --config argument, this argument is used to
|
||||
# load a yaml file containing key-values that override the defaults for the main parser below
|
||||
config_parser = parser = argparse.ArgumentParser(
|
||||
description='Training Config', add_help=False)
|
||||
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
|
||||
help='YAML config file specifying default arguments')
|
||||
parser = argparse.ArgumentParser(description='3D AIGC Training')
|
||||
parser.add_argument('--workspace', type=str, default='', help='path to log')
|
||||
parser.add_argument('--text', default=None, help="text prompt")
|
||||
parser.add_argument('--negative', default='', type=str,
|
||||
help="negative text prompt")
|
||||
parser.add_argument('--dir_texts_neg', action='store_true',
|
||||
help="enable negative directional text")
|
||||
parser.add_argument('--check_prompt', action='store_true', help="check prompt")
|
||||
parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray")
|
||||
parser.add_argument('-O2', action='store_true',
|
||||
help="equals --backbone vanilla")
|
||||
parser.add_argument('--test', action='store_true', help="test mode")
|
||||
parser.add_argument('--six_views', action='store_true',
|
||||
help="six_views mode: save the images of the six views")
|
||||
parser.add_argument('--eval_interval', type=int, default=1,
|
||||
help="evaluate on the valid set every interval epochs")
|
||||
parser.add_argument('--test_interval', type=int, default=50,
|
||||
help="test on the test set every interval epochs")
|
||||
parser.add_argument('--seed', type=int, default=101)
|
||||
parser.add_argument('--log_every', type=int, default=20,
|
||||
help="log losses every X iterations")
|
||||
parser.add_argument('--use_wandb', action='store_true',
|
||||
help="log online into wandb")
|
||||
|
||||
# guidance
|
||||
parser.add_argument('--guidance', type=str, nargs='*',
|
||||
default=['SD'], help='guidance model')
|
||||
parser.add_argument('--guidance_scale', type=float, nargs='*', default=[100],
|
||||
help="diffusion model classifier-free guidance scale")
|
||||
parser.add_argument('--gudiance_spatial_weighting',
|
||||
action='store_true', help="add spatial weighting to guidance")
|
||||
parser.add_argument('--save_train_every', type=int,
|
||||
default=-1, help="save sds guidance")
|
||||
|
||||
# clip guidance
|
||||
# lambda_clip, set to 1 if use clip loss outside sds
|
||||
parser.add_argument('--lambda_clip', type=float, default=0,
|
||||
help="loss scale for clip loss outside sds")
|
||||
# set to 100 if use clip guidance in sds
|
||||
parser.add_argument('--clip_version', type=str,
|
||||
default='large', help="clip version, large is ued in stable diffusion")
|
||||
parser.add_argument('--clip_guidance', type=float, default=0,
|
||||
help="diffusion model classifier-free guidance scale")
|
||||
parser.add_argument('--clip_t', type=float, default=0.4,
|
||||
help="time step thresh started to use clip")
|
||||
parser.add_argument('--clip_iterative', action='store_true',
|
||||
help="use clipd iteratively with sds")
|
||||
parser.add_argument('--clip_image_loss', action='store_true',
|
||||
help="use image as reference in clip")
|
||||
parser.add_argument('--save_guidance_every', type=int,
|
||||
default=-1, help="save sds guidance")
|
||||
|
||||
# 3D prior: Shap-E. Does not work.
|
||||
parser.add_argument('--use_shape', action='store_true',
|
||||
help="enable shap-e initization")
|
||||
parser.add_argument('--shape_guidance', type=float, default=3,
|
||||
help="guidance scaling for shap-e prior")
|
||||
parser.add_argument('--shape_radius', type=float, default=4,
|
||||
help="camera raidus for shap-e prior")
|
||||
parser.add_argument('--shape_fovy', type=float, default=40,
|
||||
help="fov for shap-e prior")
|
||||
parser.add_argument('--shape_no_color', action='store_false',
|
||||
dest='shape_init_color', help="do not use shap-E color for initization")
|
||||
parser.add_argument('--shape_rpst', type=str, default='sdf',
|
||||
help="use sdf to init NeRF/mesh by default")
|
||||
|
||||
# image options.
|
||||
parser.add_argument('--image', default=None, help="image prompt")
|
||||
parser.add_argument('--image_config', default=None, help="image config csv")
|
||||
parser.add_argument('--learned_embeds_path', type=str,
|
||||
default=None, help="path to learned embeds of the given image")
|
||||
parser.add_argument('--known_iters', type=int, default=100,
|
||||
help="loss scale for alpha entropy")
|
||||
parser.add_argument('--known_view_interval', type=int, default=4,
|
||||
help="do reconstruction every X iterations to save on compute")
|
||||
parser.add_argument('--bg_color_known', type=str,
|
||||
default=None, help='pixelnoise, noise, None') # pixelnoise
|
||||
parser.add_argument('--known_shading', type=str, default='lambertian')
|
||||
|
||||
# DMTet and Mesh options
|
||||
parser.add_argument('--save_mesh', action='store_true',
|
||||
help="export an obj mesh with texture")
|
||||
parser.add_argument('--mcubes_resolution', type=int, default=256,
|
||||
help="mcubes resolution for extracting mesh")
|
||||
parser.add_argument('--decimate_target', type=int, default=5e4,
|
||||
help="target face number for mesh decimation")
|
||||
parser.add_argument('--dmtet', action='store_true',
|
||||
help="use dmtet finetuning")
|
||||
parser.add_argument('--tet_mlp', action='store_true',
|
||||
help="use tet_mlp finetuning")
|
||||
parser.add_argument('--base_mesh', default=None,
|
||||
help="base mesh for dmtet init")
|
||||
parser.add_argument('--tet_grid_size', type=int,
|
||||
default=256, help="tet grid size")
|
||||
parser.add_argument('--init_ckpt', type=str, default='',
|
||||
help="ckpt to init dmtet")
|
||||
parser.add_argument('--lock_geo', action='store_true',
|
||||
help="disable dmtet to learn geometry")
|
||||
|
||||
# training options
|
||||
parser.add_argument('--iters', type=int, default=5000, help="training iters")
|
||||
parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
|
||||
parser.add_argument('--lr_scale_nerf', type=float,
|
||||
default=1, help="max learning rate")
|
||||
parser.add_argument('--lr_scale_texture', type=float,
|
||||
default=1, help="max learning rate")
|
||||
parser.add_argument('--ckpt', type=str, default='latest')
|
||||
parser.add_argument('--cuda_ray', action='store_true',
|
||||
help="use CUDA raymarching instead of pytorch")
|
||||
parser.add_argument('--taichi_ray', action='store_true',
|
||||
help="use taichi raymarching")
|
||||
parser.add_argument('--max_steps', type=int, default=1024,
|
||||
help="max num steps sampled per ray (only valid when using --cuda_ray)")
|
||||
parser.add_argument('--num_steps', type=int, default=64,
|
||||
help="num steps sampled per ray (only valid when not using --cuda_ray)")
|
||||
parser.add_argument('--upsample_steps', type=int, default=32,
|
||||
help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
|
||||
parser.add_argument('--update_extra_interval', type=int, default=16,
|
||||
help="iter interval to update extra status (only valid when using --cuda_ray)")
|
||||
parser.add_argument('--max_ray_batch', type=int, default=4096,
|
||||
help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
|
||||
parser.add_argument('--latent_iter_ratio', type=float, default=0.0,
|
||||
help="training iters that only use latent normal shading")
|
||||
parser.add_argument('--normal_iter_ratio', type=float, default=0.0,
|
||||
help="training iters that only use normal shading")
|
||||
parser.add_argument('--textureless_iter_ratio', type=float, default=0.0,
|
||||
help="training iters that only use textureless shading")
|
||||
parser.add_argument('--albedo_iter_ratio', type=float, default=0,
|
||||
help="training iters that only use albedo shading")
|
||||
parser.add_argument('--warmup_bg_color', type=str, default=None,
|
||||
help="bg color [None | pixelnoise | noise | white]")
|
||||
parser.add_argument('--bg_color', type=str, default=None)
|
||||
parser.add_argument('--bg_color_test', default='white')
|
||||
parser.add_argument('--ema_decay', type=float, default=0.95,
|
||||
help="exponential moving average of model weights")
|
||||
parser.add_argument('--jitter_pose', action='store_true',
|
||||
help="add jitters to the randomly sampled camera poses")
|
||||
parser.add_argument('--jitter_center', type=float, default=0.2,
|
||||
help="amount of jitter to add to sampled camera pose's center (camera location)")
|
||||
parser.add_argument('--jitter_target', type=float, default=0.2,
|
||||
help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
|
||||
parser.add_argument('--jitter_up', type=float, default=0.02,
|
||||
help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
|
||||
parser.add_argument('--uniform_sphere_rate', type=float, default=0.5,
|
||||
help="likelihood of sampling camera location uniformly on the sphere surface area")
|
||||
parser.add_argument('--grad_clip', type=float, default=-1,
|
||||
help="clip grad of all grad to this limit, negative value disables it")
|
||||
parser.add_argument('--grad_clip_rgb', type=float, default=-1,
|
||||
help="clip grad of rgb space grad to this limit, negative value disables it")
|
||||
parser.add_argument('--grid_levels_mask', type=int, default=8,
|
||||
help="the number of levels in the feature grid to mask (to disable use 0)")
|
||||
parser.add_argument('--grid_levels_mask_iters', type=int, default=3000,
|
||||
help="the number of iterations for feature grid masking (to disable use 0)")
|
||||
|
||||
# model options
|
||||
parser.add_argument('--bg_radius', type=float, default=1.4,
|
||||
help="if positive, use a background model at sphere(bg_radius)")
|
||||
parser.add_argument('--density_activation', type=str, default='exp',
|
||||
choices=['softplus', 'exp', 'relu'], help="density activation function")
|
||||
parser.add_argument('--density_thresh', type=float, default=10,
|
||||
help="threshold for density grid to be occupied")
|
||||
# add more strength to the center, believe the center is more likely to have objects.
|
||||
parser.add_argument('--blob_density', type=float, default=10,
|
||||
help="max (center) density for the density blob")
|
||||
parser.add_argument('--blob_radius', type=float, default=0.2,
|
||||
help="control the radius for the density blob")
|
||||
# network backbone
|
||||
parser.add_argument('--backbone', type=str, default='grid',
|
||||
choices=['grid', 'vanilla', 'grid_taichi'], help="nerf backbone")
|
||||
parser.add_argument('--grid_type', type=str,
|
||||
default='hashgrid', help="grid type")
|
||||
parser.add_argument('--hidden_dim_bg', type=int, default=32,
|
||||
help="channels for background network")
|
||||
parser.add_argument('--optim', type=str, default='adam',
|
||||
choices=['adan', 'adam'], help="optimizer")
|
||||
parser.add_argument('--sd_version', type=str, default='1.5',
|
||||
choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
|
||||
parser.add_argument('--hf_key', type=str, default=None,
|
||||
help="hugging face Stable diffusion model key")
|
||||
# try this if CUDA OOM
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="use float16 for training")
|
||||
parser.add_argument('--vram_O', action='store_true',
|
||||
help="optimization for low VRAM usage")
|
||||
# rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
|
||||
parser.add_argument('--w', type=int, default=128,
|
||||
help="render width for NeRF in training")
|
||||
parser.add_argument('--h', type=int, default=128,
|
||||
help="render height for NeRF in training")
|
||||
parser.add_argument('--known_view_scale', type=float, default=1.5,
|
||||
help="multiply --h/w by this for known view rendering")
|
||||
parser.add_argument('--known_view_noise_scale', type=float, default=1e-3,
|
||||
help="random camera noise added to rays_o and rays_d")
|
||||
parser.add_argument('--noise_known_camera_annealing', action='store_true',
|
||||
help="anneal the noise to zero over the coarse of training")
|
||||
parser.add_argument('--dmtet_reso_scale', type=float, default=8,
|
||||
help="multiply --h/w by this for dmtet finetuning")
|
||||
parser.add_argument('--rm_edge', action='store_true',
|
||||
help="remove edge (ideally only enale for high resolution cases)")
|
||||
parser.add_argument('--edge_threshold', type=float, default=0.1,
|
||||
help="remove edges with value > threshold")
|
||||
parser.add_argument('--edge_width', type=float, default=5,
|
||||
help="edge width")
|
||||
parser.add_argument('--batch_size', type=int, default=1,
|
||||
help="images to render per batch using NeRF")
|
||||
|
||||
# dataset options
|
||||
parser.add_argument('--bound', type=float, default=1.0,
|
||||
help="assume the scene is bounded in box(-bound, bound)")
|
||||
parser.add_argument('--dt_gamma', type=float, default=0,
|
||||
help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
|
||||
parser.add_argument('--min_near', type=float, default=0.1,
|
||||
help="minimum near distance for camera")
|
||||
|
||||
parser.add_argument('--radius_range', type=float, nargs='*',
|
||||
default=[1.8, 1.8], help="training camera radius range")
|
||||
parser.add_argument('--theta_range', type=float, nargs='*',
|
||||
default=[45, 135], help="training camera elevation/polar range, 90 is front")
|
||||
parser.add_argument('--phi_range', type=float, nargs='*',
|
||||
default=[-180, 180], help="training camera azimuth range")
|
||||
parser.add_argument('--fovy_range', type=float, nargs='*',
|
||||
default=[40, 40], help="training camera fovy range")
|
||||
|
||||
parser.add_argument('--default_radius', type=float, default=1.8,
|
||||
help="radius for the default view")
|
||||
parser.add_argument('--default_polar', type=float,
|
||||
default=90, help="polar for the default view")
|
||||
parser.add_argument('--default_azimuth', type=float,
|
||||
default=0, help="azimuth for the default view")
|
||||
parser.add_argument('--default_fovy', type=float, default=40,
|
||||
help="fovy for the default view")
|
||||
|
||||
parser.add_argument('--progressive_view', action='store_true',
|
||||
help="progressively expand view sampling range from default to full")
|
||||
parser.add_argument('--progressive_level', action='store_true',
|
||||
help="progressively increase gridencoder's max_level")
|
||||
|
||||
parser.add_argument('--angle_overhead', type=float, default=30,
|
||||
help="[0, angle_overhead] is the overhead region")
|
||||
parser.add_argument('--angle_front', type=float, default=60,
|
||||
help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
|
||||
parser.add_argument('--t_range', type=float, nargs='*',
|
||||
default=[0.2, 0.6], help="stable diffusion time steps range")
|
||||
|
||||
# regularizations
|
||||
parser.add_argument('--lambda_entropy', type=float, default=1e-3,
|
||||
help="loss scale for alpha entropy, favors 0 or 1")
|
||||
# Try increasing/decreasing lambda_opacity if your scene is stuffed with floaters/becoming empty.
|
||||
parser.add_argument('--lambda_opacity', type=float, default=0.,
|
||||
help="loss scale for alpha value, avoid uncessary filling")
|
||||
# Try increasing/decreasing lambda_orient if you object is foggy/over-smoothed.
|
||||
parser.add_argument('--lambda_orient', type=float,
|
||||
default=1e-2, help="loss scale for orientation")
|
||||
parser.add_argument('--lambda_tv', type=float, default=0,
|
||||
help="loss scale for total variation of grad")
|
||||
parser.add_argument('--lambda_wd', type=float, default=0,
|
||||
help="loss scale for weight decay of grad")
|
||||
parser.add_argument('--lambda_normal_smooth', type=float, default=0.5,
|
||||
help="loss scale for first-order 2D normal image smoothness")
|
||||
parser.add_argument('--lambda_normal_smooth2d', type=float, default=0.5,
|
||||
help="loss scale for second-order 2D normal image smoothness")
|
||||
parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0.0,
|
||||
help="loss scale for second-order 2D normal image smoothness")
|
||||
parser.add_argument('--lambda_guidance', type=float, nargs='*',
|
||||
default=[1], help="loss scale for SDS")
|
||||
parser.add_argument('--lambda_rgb', type=float,
|
||||
default=5, help="loss scale for RGB")
|
||||
parser.add_argument('--lambda_mask', type=float, default=0.5,
|
||||
help="loss scale for mask (alpha)")
|
||||
parser.add_argument('--lambda_depth', type=float, default=0.01,
|
||||
help="loss scale for relative depth of the known view")
|
||||
parser.add_argument('--lambda_normal', type=float,
|
||||
default=0.0, help="loss scale for normals of the known view")
|
||||
parser.add_argument('--lambda_depth_mse', type=float, default=0.0,
|
||||
help="loss scale for depth of the known view")
|
||||
parser.add_argument('--no_normalize_depth', action='store_false', dest='normalize_depth', help="normalize depth")
|
||||
|
||||
# for DMTet
|
||||
parser.add_argument('--lambda_mesh_normal', type=float,
|
||||
default=0.1, help="loss scale for mesh normal smoothness")
|
||||
parser.add_argument('--lambda_mesh_lap', type=float,
|
||||
default=0.1, help="loss scale for mesh laplacian")
|
||||
|
||||
# GUI options
|
||||
parser.add_argument('--gui', action='store_true', help="start a GUI")
|
||||
parser.add_argument('--W', type=int, default=800, help="GUI width")
|
||||
parser.add_argument('--H', type=int, default=800, help="GUI height")
|
||||
parser.add_argument('--radius', type=float, default=1.8,
|
||||
help="default GUI camera radius from center")
|
||||
parser.add_argument('--fovy', type=float, default=40,
|
||||
help="default GUI camera fovy")
|
||||
parser.add_argument('--light_theta', type=float, default=60,
|
||||
help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
|
||||
parser.add_argument('--light_phi', type=float, default=0,
|
||||
help="default GUI light direction in [0, 360), azimuth")
|
||||
parser.add_argument('--max_spp', type=int, default=1,
|
||||
help="GUI rendering max sample per pixel")
|
||||
parser.add_argument('--zero123_config', type=str,
|
||||
default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123")
|
||||
parser.add_argument('--zero123_ckpt', type=str,
|
||||
default='./pretrained/zero123/105000.ckpt', help="ckpt for zero123")
|
||||
parser.add_argument('--zero123_grad_scale', type=str, default='angle',
|
||||
help="whether to scale the gradients based on 'angle' or 'None'")
|
||||
|
||||
parser.add_argument('--dataset_size_train', type=int, default=100,
|
||||
help="Length of train dataset i.e. # of iterations per epoch")
|
||||
parser.add_argument('--dataset_size_valid', type=int, default=8,
|
||||
help="# of frames to render in the turntable video in validation")
|
||||
parser.add_argument('--dataset_size_test', type=int, default=100,
|
||||
help="# of frames to render in the turntable video at test time")
|
||||
|
||||
|
||||
def _parse_args():
    """Parse CLI arguments, letting an optional YAML config override defaults.

    Returns:
        (args, args_text): the parsed namespace and a YAML dump of it, the
        latter intended for saving alongside the run's outputs.
    """
    config_args, remaining_argv = config_parser.parse_known_args()
    if config_args.config:
        with open(config_args.config, 'r') as f:
            # config-file values become parser defaults ...
            parser.set_defaults(**yaml.safe_load(f))

    # ... so anything given explicitly on the command line still wins.
    parsed = parser.parse_args(remaining_argv)

    # Cache the args as a YAML string to save them in the output dir later.
    return parsed, yaml.safe_dump(parsed.__dict__, default_flow_style=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Entry point: parse/massage options, build the NeRF model, then either
    # render (six_views/test) or run guided training.
    args, args_text = _parse_args()
    opt = edict(vars(args))

    # -O / -O2 convenience presets
    if opt.O:
        opt.fp16 = True
        opt.cuda_ray = True

    elif opt.O2:
        opt.fp16 = True
        opt.backbone = 'vanilla'

    # per-reference-view lists; filled from --image and/or --image_config
    opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], []
    opt.default_zero123_w = 1

    # parameters for image-conditioned generation
    if opt.image is not None or opt.image_config is not None:
        if 'zero123' in opt.guidance:
            # fix fov as zero123 doesn't support changing fov
            opt.fovy_range = [opt.default_fovy, opt.default_fovy]
        else:
            opt.known_view_interval = 2

        if 'SD' in opt.guidance:
            opt.t_range = [0.2, 0.6]
            opt.bg_radius = -1

        # latent warmup is not needed
        opt.latent_iter_ratio = 0
        opt.albedo_iter_ratio = 0

        # single reference image from --image
        if opt.image is not None:
            opt.images += [opt.image]
            opt.ref_radii += [opt.default_radius]
            opt.ref_polars += [opt.default_polar]
            opt.ref_azimuths += [opt.default_azimuth]
            opt.zero123_ws += [opt.default_zero123_w]

        if opt.image_config is not None:
            # for multiview (zero123): CSV with image/radius/polar/azimuth/zero123_weight columns
            conf = pd.read_csv(opt.image_config, skipinitialspace=True)
            opt.images += list(conf.image)
            opt.ref_radii += list(conf.radius)
            opt.ref_polars += list(conf.polar)
            opt.ref_azimuths += list(conf.azimuth)
            opt.zero123_ws += list(conf.zero123_weight)
            if opt.image is None:
                # first CSV row defines the default camera when no --image given
                opt.default_radius = opt.ref_radii[0]
                opt.default_polar = opt.ref_polars[0]
                opt.default_azimuth = opt.ref_azimuths[0]
                opt.default_zero123_w = opt.zero123_ws[0]

    # reset to None
    if len(opt.images) == 0:
        opt.images = None

    # default parameters for finetuning
    if opt.dmtet:
        opt.h = int(opt.h * opt.dmtet_reso_scale)
        opt.w = int(opt.w * opt.dmtet_reso_scale)
        opt.known_view_scale = 1
        opt.grid_levels_mask = -1  # disable coarse nerf (fine to keep, not necessary)
        opt.t_range = [0.02, 0.50]  # ref: magic3D

        if opt.images is not None:
            opt.lambda_normal = 0
            opt.lambda_depth = 0

        # assume finetuning
        opt.latent_iter_ratio = 0
        opt.textureless_iter_ratio = 0
        opt.albedo_iter_ratio = 0
        opt.normal_iter_ratio = 0
        opt.progressive_view = False
        opt.progressive_level = False

    # record full range for progressive view expansion
    if opt.progressive_view:
        # disable as they disturb progressive view
        opt.jitter_pose = False
        opt.uniform_sphere_rate = 0
        # back up full range
        opt.full_radius_range = opt.radius_range
        opt.full_theta_range = opt.theta_range
        opt.full_phi_range = opt.phi_range
        opt.full_fovy_range = opt.fovy_range

    opt.use_clip = opt.clip_guidance > 0 or opt.lambda_clip > 0
    # Do not support Shap-E for NeRF yet.
    opt.use_shape = False if not opt.dmtet else opt.use_shape

    # workspace prepare
    setup_workspace(opt)
    # NOTE(review): 'dnnultis' looks like a misspelling of a utils module —
    # confirm it matches the import at the top of this file.
    dnnultis.setup_logging(opt.log_path)

    # negative seed => pick one at random so runs are reproducible once logged
    if opt.seed < 0:
        opt.seed = random.randint(0, 10000)
    seed_everything(int(opt.seed))

    # backbone selection; taichi backbone replaces cuda_ray and needs ti.init
    if opt.backbone == 'vanilla':
        from nerf.network import NeRFNetwork
    elif opt.backbone == 'grid':
        from nerf.network_grid import NeRFNetwork
    elif opt.backbone == 'grid_tcnn':
        from nerf.network_grid_tcnn import NeRFNetwork
    elif opt.backbone == 'grid_taichi':
        opt.cuda_ray = False
        opt.taichi_ray = True
        import taichi as ti
        from nerf.network_grid_taichi import NeRFNetwork
        taichi_half2_opt = True
        taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0}
        if taichi_half2_opt:
            taichi_init_args["half2_vectorization"] = True
        ti.init(**taichi_init_args)
    else:
        raise NotImplementedError(
            f'--backbone {opt.backbone} is not implemented!')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    opt.device = device
    model = NeRFNetwork(opt).to(device)

    if opt.init_ckpt != '':
        if not os.path.exists(opt.init_ckpt):
            logger.warning(f'ckpt {opt.init_ckpt} is not found')
        else:
            # load pretrained weights to init dmtet
            state_dict = torch.load(opt.init_ckpt, map_location=device)
            model.load_state_dict(state_dict['model'], strict=False)
            if opt.cuda_ray:
                model.mean_density = state_dict['mean_density']
            logger.info(f'init from {opt.init_ckpt}...')
            # if init ckpt is provided, we assume the color network is well learned and do not need base_mesh init
            opt.shape_init_color = False
            opt.base_mesh = None

    if opt.use_shape and opt.dmtet:
        # now only supports shape for dmtet init
        from guidance.shape_utils import get_shape_from_image

        # query Shap-E either on a regular grid (NeRF) or on the tet vertices (dmtet)
        opt.points = generate_grid_points(
            128, device=device) if not opt.dmtet else model.dmtet.verts
        opt.rpsts, opt.colors = get_shape_from_image(
            opt.image.replace('rgba', 'rgb'),
            opt.points,
            rpst_type=opt.shape_rpst,
            get_color=opt.shape_init_color,
            shape_guidance=opt.shape_guidance, device=device)
        # rescale the Shap-E shape to match our camera radius/fov convention
        scale = opt.default_radius / opt.shape_radius * \
            np.tan(np.deg2rad(opt.default_fovy / 2)) / \
            np.tan(np.deg2rad(opt.shape_fovy / 2))
        if opt.dmtet:
            model.dmtet.reset_tet_scale(scale)
        else:
            opt.points *= scale
        logger.info(f'Got sdf from Shap-E init...')

    logger.info(model)

    if opt.six_views:
        # render the six canonical views only; no optimization
        guidance = None  # no need to load guidance model at test

        trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device,
                          workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)

        test_loader = NeRFDataset(
            opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1)
        trainer.test(test_loader, write_video=False)

        if opt.save_mesh:
            trainer.save_mesh()

    elif opt.test:
        guidance = None  # no need to load guidance model at test
        trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace), opt, model, guidance,
                          device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
        if opt.gui:
            from nerf.gui import NeRFGUI
            gui = NeRFGUI(opt, trainer)
            gui.render()

        else:
            test_loader = NeRFDataset(
                opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
            trainer.test(test_loader)
            trainer.test(test_loader, shading='normal')  # save normal
            if opt.save_mesh:
                # best-effort: mesh extraction failures shouldn't kill the run
                try:
                    trainer.save_mesh()
                except:
                    pass
    else:
        # training path
        train_loader = NeRFDataset(
            opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader()

        if opt.optim == 'adan':
            from optimizer import Adan
            # Adan usually requires a larger LR

            def optimizer(model): return Adan(model.get_params(
                5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
        else:  # adam
            def optimizer(model): return torch.optim.Adam(
                model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

        if opt.backbone == 'vanilla':
            # exponential decay down to 0.1x over opt.iters
            def scheduler(optimizer): return optim.lr_scheduler.LambdaLR(
                optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
        else:
            def scheduler(optimizer): return optim.lr_scheduler.LambdaLR(
                optimizer, lambda iter: 1)  # fixed

        # build each requested guidance model; per-guidance weight/scale fall
        # back to the last provided value when fewer values than guidances
        guidance = nn.ModuleDict()
        lambda_guidance, guidance_scale = {}, {}
        for idx, guidance_type in enumerate(opt.guidance):
            lambda_guidance[guidance_type] = opt.lambda_guidance[idx] if idx < len(
                opt.lambda_guidance) else opt.lambda_guidance[-1]
            guidance_scale[guidance_type] = opt.guidance_scale[idx] if idx < len(
                opt.guidance_scale) else opt.guidance_scale[-1]
            if 'SD' == guidance_type:
                from guidance.sd_utils import StableDiffusion, token_replace
                guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range,
                                                 learned_embeds_path=opt.learned_embeds_path,
                                                 use_clip=opt.use_clip, clip_t=opt.clip_t, clip_iterative=opt.clip_iterative, clip_version=opt.clip_version,
                                                 )
                if opt.learned_embeds_path is not None and os.path.exists(opt.learned_embeds_path):  # add textual inversion tokens to model
                    opt.text, opt.negative = token_replace(
                        opt.text, opt.negative, opt.learned_embeds_path)
                    logger.info(
                        f'prompt: {opt.text}, negative: {opt.negative}')
                    if opt.check_prompt:
                        guidance['SD'].check_prompt(opt)
                else:
                    # fall back to replacing <token> with the image folder name
                    opt.text = opt.text.replace('<token>', os.path.basename(os.path.dirname(opt.image)))
                    logger.warning('No learned_embeds_path provided, using the folowing pure text prompt with degraded performance: ' + opt.text)

            if 'IF' == guidance_type:
                from guidance.if_utils import IF
                guidance['IF'] = IF(device, opt.vram_O, opt.t_range)

            if 'zero123' == guidance_type:
                from guidance.zero123_utils import Zero123
                guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config,
                                              ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt)

            if 'clip' == guidance_type:
                from guidance.clip_utils import CLIP
                guidance['clip'] = CLIP(device)
        # replace the list-valued options with per-guidance dicts
        opt.lambda_guidance = lambda_guidance
        opt.guidance_scale = guidance_scale

        logger.info(opt)
        trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace), opt, model,
                          guidance,
                          device=device, workspace=opt.workspace, optimizer=optimizer,
                          ema_decay=opt.ema_decay, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True)
        trainer.default_view_data = train_loader._data.get_default_view_data()

        if opt.gui:
            from nerf.gui import NeRFGUI
            gui = NeRFGUI(opt, trainer, train_loader)
            gui.render()

        else:
            valid_loader = NeRFDataset(
                opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader()
            test_loader = NeRFDataset(
                opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
            max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)

            trainer.train(train_loader, valid_loader, test_loader, max_epoch)

            trainer.test(test_loader)
            trainer.test(test_loader, shading='normal')  # save normal
            if opt.save_mesh:
                # best-effort: mesh extraction failures shouldn't kill the run
                try:
                    trainer.save_mesh()
                except:
                    pass
|
||||
117
meshutils.py
Normal file
117
meshutils.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import numpy as np
|
||||
import pymeshlab as pml
|
||||
|
||||
def poisson_mesh_reconstruction(points, normals=None, visualize=True):
    """Reconstruct a mesh from a point cloud via screened Poisson reconstruction.

    Args:
        points: [N, 3] np.ndarray of point positions.
        normals: optional [N, 3] np.ndarray of per-point normals; estimated
            from the cloud when not provided.
        visualize: open interactive Open3D viewers for the cloud and the mesh.
            Defaults to True (the original behavior); pass False for
            headless/batch use, where the blocking GUI windows would hang.

    Returns:
        (vertices, triangles): np.ndarrays of the reconstructed mesh.
    """
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # outlier removal; `ind` maps surviving points back to the input order
    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)

    # normals
    if normals is None:
        pcd.estimate_normals()
    else:
        # keep only the normals of the surviving (inlier) points
        pcd.normals = o3d.utility.Vector3dVector(normals[ind])

    # visualize
    if visualize:
        o3d.visualization.draw_geometries([pcd], point_show_normal=False)

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)
    # drop the 10% lowest-density (poorly supported) vertices to remove
    # spurious Poisson "blobs"
    vertices_to_remove = densities < np.quantile(densities, 0.1)
    mesh.remove_vertices_by_mask(vertices_to_remove)

    # visualize
    if visualize:
        o3d.visualization.draw_geometries([mesh])

    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}')

    return vertices, triangles
|
||||
|
||||
|
||||
def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
    """Decimate a triangle mesh down to roughly `target` faces.

    Args:
        verts: [V, 3] vertex positions.
        faces: [F, 3] triangle indices.
        target: desired face count after decimation.
        backend: 'pymeshlab' (default, quadric edge collapse) or 'pyfqmr'.
        remesh: additionally run isotropic remeshing (pymeshlab backend only).
        optimalplacement: default is True, but for flat mesh must turn False
            to prevent spike artifect.

    Returns:
        (verts, faces) of the decimated mesh.
    """
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    if backend == 'pyfqmr':
        import pyfqmr
        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
        verts, faces, normals = solver.getMesh()
    else:

        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, 'mesh')  # will copy!

        # filters
        # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
        ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)

        if remesh:
            # ms.apply_coord_taubin_smoothing()
            ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces
|
||||
|
||||
|
||||
def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01):
    """Clean a triangle mesh: merge close vertices, drop degenerate faces and
    tiny connected components, optionally repair non-manifold geometry and
    run isotropic remeshing.

    Args:
        verts: [N, 3] vertex positions.
        faces: [N, 3] triangle indices.
        v_pct: merge vertices closer than this percentage of the bounding-box
            diagonal (0 disables).
        min_f: drop connected components with fewer faces than this (0 disables).
        min_d: drop components whose diameter is below this percentage of the
            bounding-box diagonal (0 disables).
        repair: fix non-manifold edges and vertices.
        remesh: run isotropic remeshing with absolute target edge length
            `remesh_size`.

    Returns:
        (verts, faces) of the cleaned mesh.
    """
    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    m = pml.Mesh(verts, faces)
    ms = pml.MeshSet()
    ms.add_mesh(m, 'mesh')  # will copy!

    # filters
    ms.meshing_remove_unreferenced_vertices()  # verts not refed by any faces

    if v_pct > 0:
        ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct))  # 1/10000 of bounding box diagonal

    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts
    ms.meshing_remove_null_faces()  # faces with area == 0

    if min_d > 0:
        ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))

    if min_f > 0:
        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)

    if repair:
        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
        ms.meshing_repair_non_manifold_edges(method=0)
        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)

    if remesh:
        # ms.apply_coord_taubin_smoothing()
        ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))

    # extract mesh
    m = ms.current_mesh()
    verts = m.vertex_matrix()
    faces = m.face_matrix()

    print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')

    return verts, faces
|
||||
1
midas/__init__.py
Normal file
1
midas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .model_loader import load_model, default_models
|
||||
196
midas/backbones/beit.py
Normal file
196
midas/backbones/beit.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import timm
|
||||
import torch
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import forward_adapted_unflatten, make_backbone_default
|
||||
from timm.models.beit import gen_relative_position_index
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def forward_beit(pretrained, x):
    # Run the (patched) BEiT trunk and return the hooked features reshaped to
    # spatial maps via the shared adapter helper.
    return forward_adapted_unflatten(pretrained, x, "forward_features")
|
||||
|
||||
|
||||
def patch_embed_forward(self, x):
    """Patch-embedding forward pass without fixed-size assumptions.

    Replacement for timm's PatchEmbed.forward so arbitrary input resolutions
    (and therefore arbitrary window sizes) are accepted.
    """
    out = self.proj(x)
    if self.flatten:
        # BCHW -> BNC token layout
        out = out.flatten(2).transpose(1, 2)
    return self.norm(out)
|
||||
|
||||
|
||||
def _get_rel_pos_bias(self, window_size):
    """
    Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
    """
    # size of the relative-position table the model was trained with
    old_height = 2 * self.window_size[0] - 1
    old_width = 2 * self.window_size[1] - 1

    # size required for the requested window
    new_height = 2 * window_size[0] - 1
    new_width = 2 * window_size[1] - 1

    old_relative_position_bias_table = self.relative_position_bias_table

    old_num_relative_distance = self.num_relative_distance
    # the "+3" entries are the special cls-token relative positions
    new_num_relative_distance = new_height * new_width + 3

    # spatial part of the table (everything except the 3 cls entries)
    old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]

    # NOTE(review): the reshape orders dims as (width, height) while
    # F.interpolate's size is (height, width); this matches upstream MiDaS and
    # is harmless for square windows — confirm before using non-square ones.
    old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
    new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear")
    new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

    # resized spatial entries + untouched cls entries
    new_relative_position_bias_table = torch.cat(
        [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]])

    # index tables are cached per window size (built lazily)
    key = str(window_size[1]) + "," + str(window_size[0])
    if key not in self.relative_position_indices.keys():
        self.relative_position_indices[key] = gen_relative_position_index(window_size)

    relative_position_bias = new_relative_position_bias_table[
        self.relative_position_indices[key].view(-1)].view(
        window_size[0] * window_size[1] + 1,
        window_size[0] * window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    return relative_position_bias.unsqueeze(0)
|
||||
|
||||
|
||||
def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
    """
    Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.

    `resolution` is the (H, W) of the raw input image, used to derive the
    current window size for the relative-position bias.
    """
    B, N, C = x.shape

    # BEiT's k projection has no bias; q/k/v biases are concatenated so a
    # single fused linear can compute qkv at once
    qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
    qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    if self.relative_position_bias_table is not None:
        # window size in patches; patch size is fixed at 16 here
        window_size = tuple(np.array(resolution) // 16)
        attn = attn + self._get_rel_pos_bias(window_size)
    if shared_rel_pos_bias is not None:
        attn = attn + shared_rel_pos_bias

    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
|
||||
|
||||
|
||||
def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
    """Transformer block forward that threads `resolution` through to attention.

    Replacement for timm's Beit Block.forward so arbitrary window sizes work;
    gamma_1/gamma_2 are the optional layer-scale parameters.
    """
    attn_out = self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)
    if self.gamma_1 is None:
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
    else:
        x = x + self.drop_path(self.gamma_1 * attn_out)
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
    return x
|
||||
|
||||
|
||||
def beit_forward_features(self, x):
    """
    Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
    """
    # spatial resolution of the raw input; the patched blocks need it to build
    # relative-position biases for arbitrary window sizes
    resolution = x.shape[2:]

    x = self.patch_embed(x)
    x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    if self.pos_embed is not None:
        x = x + self.pos_embed
    x = self.pos_drop(x)

    rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
    for blk in self.blocks:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # bug fix: the patched block_forward requires `resolution`; the
            # original call omitted it, so enabling gradient checkpointing
            # raised a missing-argument error
            x = checkpoint(blk, x, resolution, shared_rel_pos_bias=rel_pos_bias)
        else:
            x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
    x = self.norm(x)
    return x
|
||||
|
||||
|
||||
def _make_beit_backbone(
        model,
        features=[96, 192, 384, 768],
        size=[384, 384],
        hooks=[0, 4, 8, 11],
        vit_features=768,
        use_readout="ignore",
        start_index=1,
        start_index_readout=1,
):
    """Wrap a timm BEiT model as a backbone and monkey-patch its forward
    passes so arbitrary input resolutions (window sizes) are supported."""
    backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
                                     start_index_readout)

    # swap in the window-size-agnostic forward implementations defined above
    backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed)
    backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model)

    for block in backbone.model.blocks:
        attn = block.attn
        attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
        attn.forward = types.MethodType(attention_forward, attn)
        # per-window-size cache of relative position index tables
        attn.relative_position_indices = {}

        block.forward = types.MethodType(block_forward, block)

    return backbone
|
||||
|
||||
|
||||
def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
    """BEiT-Large/16 @ 512 backbone (hooks default to blocks [5, 11, 17, 23])."""
    trunk = timm.create_model("beit_large_patch16_512", pretrained=pretrained)

    if hooks is None:
        hooks = [5, 11, 17, 23]

    return _make_beit_backbone(
        trunk,
        features=[256, 512, 1024, 1024],
        size=[512, 512],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
|
||||
|
||||
|
||||
def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
    """BEiT-Large/16 @ 384 backbone (hooks default to blocks [5, 11, 17, 23])."""
    trunk = timm.create_model("beit_large_patch16_384", pretrained=pretrained)

    if hooks is None:
        hooks = [5, 11, 17, 23]
    return _make_beit_backbone(
        trunk,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
|
||||
|
||||
|
||||
def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
    """BEiT-Base/16 @ 384 backbone (hooks default to blocks [2, 5, 8, 11])."""
    trunk = timm.create_model("beit_base_patch16_384", pretrained=pretrained)

    if hooks is None:
        hooks = [2, 5, 8, 11]
    return _make_beit_backbone(
        trunk,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
    )
|
||||
106
midas/backbones/levit.py
Normal file
106
midas/backbones/levit.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .utils import activations, get_activation, Transpose
|
||||
|
||||
|
||||
def forward_levit(pretrained, x):
    """Run the LeViT trunk and return its three hooked, post-processed feature maps."""
    # forward_features fills pretrained.activations via the registered hooks
    pretrained.model.forward_features(x)

    outputs = (
        pretrained.act_postprocess1(pretrained.activations["1"]),
        pretrained.act_postprocess2(pretrained.activations["2"]),
        pretrained.act_postprocess3(pretrained.activations["3"]),
    )
    return outputs
|
||||
|
||||
|
||||
def _make_levit_backbone(
        model,
        hooks=(3, 11, 21),
        patch_grid=(14, 14)
):
    """Wrap a timm LeViT model: hook three blocks for feature capture and add
    post-processing that reshapes BNC tokens back to BCHW feature maps.

    Args:
        model: timm LeViT model.
        hooks: indices of the three blocks whose outputs are captured.
            (Defaults are tuples rather than lists to avoid the shared
            mutable-default-argument pitfall; behavior is unchanged.)
        patch_grid: token grid (H, W) of the first hooked stage; later stages
            use 1/2 and 1/4 of it.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # capture the intermediate activations of the three hooked blocks
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))

    pretrained.activations = activations

    patch_grid_size = np.array(patch_grid, dtype=int)

    # tokens -> feature maps; each later stage halves the spatial grid
    pretrained.act_postprocess1 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
    )
    pretrained.act_postprocess2 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
    )
    pretrained.act_postprocess3 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
    )

    return pretrained
|
||||
|
||||
|
||||
class ConvTransposeNorm(nn.Sequential):
|
||||
"""
|
||||
Modification of
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
|
||||
such that ConvTranspose2d is used instead of Conv2d.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
|
||||
groups=1, bn_weight_init=1):
|
||||
super().__init__()
|
||||
self.add_module('c',
|
||||
nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
|
||||
self.add_module('bn', nn.BatchNorm2d(out_chs))
|
||||
|
||||
nn.init.constant_(self.bn.weight, bn_weight_init)
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
m = nn.ConvTranspose2d(
|
||||
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
|
||||
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
def stem_b4_transpose(in_chs, out_chs, activation):
    """Upsampling stem built from two stride-2 ConvTransposeNorm stages.

    Modeled on timm levit.py's stem_b16, but with transposed convolutions
    (upsampling) and the stem reduced to half depth; the second stage halves
    the channel count.
    """
    first = ConvTransposeNorm(in_chs, out_chs, 3, 2, 1)
    second = ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1)
    return nn.Sequential(first, activation(), second, activation())
|
||||
|
||||
|
||||
def _make_pretrained_levit_384(pretrained, hooks=None):
    """Create a LeViT-384 backbone (hooks default to blocks [3, 11, 21])."""
    model = timm.create_model("levit_384", pretrained=pretrained)

    # PEP 8: compare to None with `is`, not `==` (avoids custom __eq__ surprises)
    hooks = [3, 11, 21] if hooks is None else hooks
    return _make_levit_backbone(
        model,
        hooks=hooks
    )
|
||||
39
midas/backbones/next_vit.py
Normal file
39
midas/backbones/next_vit.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import timm
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from pathlib import Path
|
||||
from .utils import activations, forward_default, get_activation
|
||||
|
||||
from ..external.next_vit.classification.nextvit import *
|
||||
|
||||
|
||||
def forward_next_vit(pretrained, x):
    """Run Next-ViT via its plain `forward`; features are gathered by hooks."""
    return forward_default(pretrained, x, function_name="forward")
|
||||
|
||||
|
||||
def _make_next_vit_backbone(
        model,
        hooks=[2, 6, 36, 39],
):
    """Wrap a Next-ViT model: register one activation hook per feature stage."""
    pretrained = nn.Module()

    pretrained.model = model
    # capture the outputs of the four hooked feature stages
    pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))

    # shared module-level activation store from .utils
    pretrained.activations = activations

    return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_next_vit_large_6m(hooks=None):
    """Create a Next-ViT-Large backbone (hooks default to features [2, 6, 36, 39])."""
    model = timm.create_model("nextvit_large")

    # PEP 8: compare to None with `is`, not `==`
    hooks = [2, 6, 36, 39] if hooks is None else hooks
    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )
|
||||
13
midas/backbones/swin.py
Normal file
13
midas/backbones/swin.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import timm
|
||||
|
||||
from .swin_common import _make_swin_backbone
|
||||
|
||||
|
||||
def _make_pretrained_swinl12_384(pretrained, hooks=None):
    """Create a Swin-Large window-12 @ 384 backbone (hooks default to [1, 1, 17, 1])."""
    model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)

    # PEP 8: compare to None with `is`, not `==`
    hooks = [1, 1, 17, 1] if hooks is None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks
    )
|
||||
34
midas/backbones/swin2.py
Normal file
34
midas/backbones/swin2.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import timm
|
||||
|
||||
from .swin_common import _make_swin_backbone
|
||||
|
||||
|
||||
def _make_pretrained_swin2l24_384(pretrained, hooks=None):
    """Create a SwinV2-Large 12→24 @ 384 backbone (hooks default to [1, 1, 17, 1])."""
    model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)

    # PEP 8: compare to None with `is`, not `==`
    hooks = [1, 1, 17, 1] if hooks is None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks
    )
|
||||
|
||||
|
||||
def _make_pretrained_swin2b24_384(pretrained, hooks=None):
    """Create a SwinV2-Base 12→24 @ 384 backbone (hooks default to [1, 1, 17, 1])."""
    model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)

    # PEP 8: compare to None with `is`, not `==`
    hooks = [1, 1, 17, 1] if hooks is None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks
    )
|
||||
|
||||
|
||||
def _make_pretrained_swin2t16_256(pretrained, hooks=None):
    """Create a SwinV2-Tiny window-16 @ 256 backbone (hooks default to [1, 1, 5, 1])."""
    model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)

    # PEP 8: compare to None with `is`, not `==`
    hooks = [1, 1, 5, 1] if hooks is None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks,
        patch_grid=[64, 64]
    )
|
||||
52
midas/backbones/swin_common.py
Normal file
52
midas/backbones/swin_common.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .utils import activations, forward_default, get_activation, Transpose
|
||||
|
||||
|
||||
def forward_swin(pretrained, x):
    # Features are captured by the hooks registered in _make_swin_backbone;
    # the generic helper runs the model and returns the hooked activations.
    return forward_default(pretrained, x)
|
||||
|
||||
|
||||
def _make_swin_backbone(
        model,
        hooks=[1, 1, 17, 1],
        patch_grid=[96, 96]
):
    """Wrap a timm Swin(V2) model: hook one block per stage and add
    post-processing that reshapes BNC token activations to BCHW feature maps.

    Args:
        model: timm Swin model.
        hooks: per-stage block indices whose outputs are captured.
        patch_grid: stage-1 token grid (H, W); fallback when the model does
            not expose its own `patch_grid`. Later stages use 1/2, 1/4, 1/8.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # one activation hook per Swin stage
    pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    # prefer the model's own patch grid when it exposes one
    if hasattr(model, "patch_grid"):
        used_patch_grid = model.patch_grid
    else:
        used_patch_grid = patch_grid

    patch_grid_size = np.array(used_patch_grid, dtype=int)

    # tokens -> feature maps; each stage halves the spatial grid
    pretrained.act_postprocess1 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
    )
    pretrained.act_postprocess2 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
    )
    pretrained.act_postprocess3 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
    )
    pretrained.act_postprocess4 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
    )

    return pretrained
|
||||
249
midas/backbones/utils.py
Normal file
249
midas/backbones/utils.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Slice(nn.Module):
    """Drop the leading readout token(s) of a (batch, tokens, dim) sequence."""

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        patches = x[:, self.start_index:]
        return patches
|
||||
|
||||
|
||||
class AddReadout(nn.Module):
    """Fold the readout token(s) back into every patch token by addition."""

    def __init__(self, start_index=1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x):
        # Two readout tokens (cls + dist): average them; otherwise use the
        # single leading token.
        readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)
|
||||
|
||||
|
||||
class ProjectReadout(nn.Module):
|
||||
def __init__(self, in_features, start_index=1):
|
||||
super(ProjectReadout, self).__init__()
|
||||
self.start_index = start_index
|
||||
|
||||
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
||||
|
||||
def forward(self, x):
|
||||
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
|
||||
features = torch.cat((x[:, self.start_index:], readout), -1)
|
||||
|
||||
return self.project(features)
|
||||
|
||||
|
||||
class Transpose(nn.Module):
    """nn.Module wrapper around ``Tensor.transpose`` usable in a Sequential."""

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)
|
||||
|
||||
|
||||
# Module-level store shared by every backbone wrapper: forward hooks write
# intermediate feature maps here under string keys ("1".."4").
activations = {}


def get_activation(name):
    """Build a forward hook that records a layer's output under *name*."""

    def hook(module, args, output):
        activations[name] = output

    return hook
|
||||
|
||||
|
||||
def forward_default(pretrained, x, function_name="forward_features"):
    """Run the wrapped backbone and return the four hooked feature maps.

    Args:
        pretrained: wrapper module holding ``model``, ``activations`` and the
            optional per-stage ``act_postprocessN`` heads.
        x (tensor): input image batch.
        function_name (str): name of the backbone method to invoke.

    Returns:
        tuple: the four (optionally post-processed) intermediate activations.
    """
    # getattr instead of exec: same dynamic dispatch without string-eval and
    # with normal AttributeError reporting on a bad method name.
    getattr(pretrained.model, function_name)(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Post-processing heads are optional; apply only the ones present.
    if hasattr(pretrained, "act_postprocess1"):
        layer_1 = pretrained.act_postprocess1(layer_1)
    if hasattr(pretrained, "act_postprocess2"):
        layer_2 = pretrained.act_postprocess2(layer_2)
    if hasattr(pretrained, "act_postprocess3"):
        layer_3 = pretrained.act_postprocess3(layer_3)
    if hasattr(pretrained, "act_postprocess4"):
        layer_4 = pretrained.act_postprocess4(layer_4)

    return layer_1, layer_2, layer_3, layer_4
|
||||
|
||||
|
||||
def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
    """Forward pass for ViT-style backbones, reshaping tokens into maps.

    Runs the backbone (for its hook side effects), pulls the four hooked
    activations, applies the first part of each ``act_postprocessN`` head,
    unflattens any remaining 3-D token sequences into (B, C, H, W) maps at
    the input's token grid, then applies the rest of each head.

    Args:
        pretrained: wrapper with ``model``, ``activations`` and four
            ``act_postprocessN`` Sequentials (element 2 is the static
            Unflatten that this function replaces with a dynamic one).
        x (tensor): input image batch of shape (B, C, H, W).
        function_name (str): backbone method to invoke.

    Returns:
        tuple: four feature maps.
    """
    b, c, h, w = x.shape

    # The old `exec(f"glob = ...")` assigned to a throwaway local inside the
    # exec scope; a plain getattr call keeps the hook side effects and drops
    # both the dead variable and the string-eval.
    getattr(pretrained.model, function_name)(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    # Readout handling + transpose only (elements 0-1 of each head).
    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    # Dynamic unflatten sized by the *current* input, replacing the static
    # element 2 of the heads (which assumes the training resolution).
    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    # Remainder of each head (element 3 onward), skipping the static Unflatten.
    layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4
|
||||
|
||||
|
||||
def get_readout_oper(vit_features, features, use_readout, start_index=1):
    """Build one readout-handling module per feature stage.

    Args:
        vit_features (int): transformer embedding width (used by "project").
        features (list[int]): one entry per stage; only its length is used.
        use_readout (str): "ignore", "add", or "project".
        start_index (int): index of the first patch token.

    Returns:
        list[nn.Module]: readout operators, one per stage. For "ignore" and
        "add" the (stateless) module instance is shared across stages; for
        "project" each stage gets its own module with learned parameters.

    Raises:
        ValueError: if *use_readout* is not one of the supported modes.
    """
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        # A real exception instead of `assert False`: asserts are stripped
        # under `python -O`, which would let a bad mode fall through silently.
        raise ValueError(
            "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
        )

    return readout_oper
|
||||
|
||||
|
||||
def make_backbone_default(
        model,
        features=[96, 192, 384, 768],
        size=[384, 384],
        hooks=[2, 5, 8, 11],
        vit_features=768,
        use_readout="ignore",
        start_index=1,
        start_index_readout=1,
):
    """Wrap a ViT-like model as a four-stage dense-prediction backbone.

    Hooks the transformer blocks listed in *hooks*, then builds the four
    ``act_postprocessN`` heads that turn the hooked (B, N, C) token sequences
    into feature maps. All heads unflatten to the 1/16 token grid; stage 1 is
    then upsampled x4, stage 2 x2, stage 3 stays, stage 4 is downsampled x2.

    Args:
        model: transformer with an indexable ``blocks`` sequence.
        features (list[int]): output channels of the four stages.
        size (list[int]): (H, W) training resolution fixing the token grid.
        hooks (list[int]): block index captured for each stage.
        vit_features (int): transformer embedding width.
        use_readout (str): readout handling ("ignore"/"add"/"project").
        start_index (int): first patch token, stored on the model.
        start_index_readout (int): first patch token as seen by the readout ops.

    Returns:
        nn.Module: wrapper with ``model``, ``activations`` and the four heads.
    """
    pretrained = nn.Module()

    pretrained.model = model
    # Forward hooks fill the shared `activations` dict under keys "1".."4".
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        # x4 upsample: stage 1 sits at 1/4 of the input resolution.
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        # x2 upsample: stage 2 sits at 1/8 of the input resolution.
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    # Stage 3 stays at the native 1/16 token-grid resolution.
    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        # Strided conv: stage 4 is downsampled to 1/32 of the input.
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    # Consumed later by forward_flex / _resize_pos_embed.
    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    return pretrained
|
||||
221
midas/backbones/vit.py
Normal file
221
midas/backbones/vit.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
import types
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper,
|
||||
make_backbone_default, Transpose)
|
||||
|
||||
|
||||
def forward_vit(pretrained, x):
    """Forward a ViT backbone through its injected ``forward_flex`` method."""
    return forward_adapted_unflatten(pretrained, x, function_name="forward_flex")
|
||||
|
||||
|
||||
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
||||
posemb_tok, posemb_grid = (
|
||||
posemb[:, : self.start_index],
|
||||
posemb[0, self.start_index:],
|
||||
)
|
||||
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
||||
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
|
||||
return posemb
|
||||
|
||||
|
||||
def forward_flex(self, x):
    """ViT forward pass that tolerates arbitrary input resolutions.

    Injected onto timm VisionTransformer instances: the position embedding is
    resized on the fly to the token grid implied by the input shape before
    the transformer blocks run.
    """
    b, c, h, w = x.shape

    # Interpolate the stored position embedding to the current token grid.
    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    # Hybrid models embed patches through a CNN backbone first.
    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    # DeiT-style models carry a distillation token next to the cls token.
    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        if self.no_embed_class:
            # no_embed_class: add the embedding before the cls token is
            # prepended, so the cls token itself stays unembedded.
            x = x + pos_embed
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    if not self.no_embed_class:
        x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x
|
||||
|
||||
|
||||
def _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        size=[384, 384],
        hooks=[2, 5, 8, 11],
        vit_features=768,
        use_readout="ignore",
        start_index=1,
        start_index_readout=1,
):
    """Wrap a plain ViT/16 as a four-stage dense-prediction backbone."""
    pretrained = make_backbone_default(
        model, features, size, hooks, vit_features, use_readout, start_index,
        start_index_readout
    )

    # Bind the flexible-resolution forward and the position-embedding resizer
    # onto the model instance instead of modifying the library source.
    vit_model = pretrained.model
    vit_model.forward_flex = types.MethodType(forward_flex, vit_model)
    vit_model._resize_pos_embed = types.MethodType(_resize_pos_embed, vit_model)

    return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    """Create a ViT-L/16 backbone at 384x384 resolution.

    Args:
        pretrained (bool): load pretrained weights via timm.
        use_readout (str): readout-token handling mode.
        hooks (list[int] | None): custom hook positions; defaults to [5, 11, 17, 23].
    """
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    # `is None` instead of `== None`: identity check is the correct idiom.
    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )
|
||||
|
||||
|
||||
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    """Create a ViT-B/16 backbone at 384x384 resolution.

    Args:
        pretrained (bool): load pretrained weights via timm.
        use_readout (str): readout-token handling mode.
        hooks (list[int] | None): custom hook positions; defaults to [2, 5, 8, 11].
    """
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    # `is None` instead of `== None`: identity check is the correct idiom.
    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )
|
||||
|
||||
|
||||
def _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=[0, 1, 8, 11],
        vit_features=768,
        patch_size=[16, 16],
        number_stages=2,
        use_vit_only=False,
        use_readout="ignore",
        start_index=1,
):
    """Wrap a hybrid ViT + CNN-stem model as a four-stage backbone.

    The first ``number_stages`` stages can come straight from the
    convolutional patch-embedding backbone (unless *use_vit_only* is set);
    the remaining stages are hooked transformer blocks.

    Args:
        model: hybrid VisionTransformer with ``patch_embed.backbone.stages``.
        features (list[int]): output channels per stage.
        size (list[int]): training resolution fixing the 1/16 token grid.
        hooks (list[int]): transformer block index per hooked stage.
        vit_features (int): transformer embedding width.
        patch_size (list[int]): patch size stored on the model for later use.
        number_stages (int): how many leading stages come from the CNN stem.
        use_vit_only (bool): ignore the CNN stem, hook only ViT blocks.
        use_readout (str): readout handling, see :func:`get_readout_oper`.
        start_index (int): index of the first patch token.

    Returns:
        nn.Module: wrapper with model, activations and reassembly heads.
    """
    pretrained = nn.Module()

    pretrained.model = model

    used_number_stages = 0 if use_vit_only else number_stages
    for s in range(used_number_stages):
        pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
            get_activation(str(s + 1))
        )
    for s in range(used_number_stages, 4):
        pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # CNN-stem stages are already spatial maps; no reassembly needed.
    # setattr instead of exec: same dynamic attribute name, no string eval.
    for s in range(used_number_stages):
        value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
        setattr(pretrained, f"act_postprocess{s + 1}", value)
    for s in range(used_number_stages, 4):
        if s < number_stages:
            # Transformer features standing in for early stages get upsampled.
            final_layer = nn.ConvTranspose2d(
                in_channels=features[s],
                out_channels=features[s],
                kernel_size=4 // (2 ** s),
                stride=4 // (2 ** s),
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            )
        elif s > number_stages:
            # Deepest stage is downsampled by a strided conv.
            final_layer = nn.Conv2d(
                in_channels=features[3],
                out_channels=features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            )
        else:
            final_layer = None

        layers = [
            readout_oper[s],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[s],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        ]
        if final_layer is not None:
            layers.append(final_layer)

        setattr(pretrained, f"act_postprocess{s + 1}", nn.Sequential(*layers))

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = patch_size

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_vitb_rn50_384(
        pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    """Create the hybrid ViT-B/16 + ResNet-50 backbone at 384x384.

    Args:
        pretrained (bool): load pretrained weights via timm.
        use_readout (str): readout-token handling mode.
        hooks (list[int] | None): custom hook positions; defaults to [0, 1, 8, 11].
        use_vit_only (bool): hook transformer blocks only, skip the CNN stem.
    """
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    # `is None` instead of `== None`: identity check is the correct idiom.
    hooks = [0, 1, 8, 11] if hooks is None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
|
||||
16
midas/base_model.py
Normal file
16
midas/base_model.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import torch
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
    """Base class adding checkpoint loading on top of ``torch.nn.Module``."""

    def load(self, path):
        """Load model weights from file.

        Args:
            path (str): file path
        """
        checkpoint = torch.load(path, map_location=torch.device('cpu'))

        # Full training checkpoints nest the weights under "model".
        if "optimizer" in checkpoint:
            checkpoint = checkpoint["model"]

        self.load_state_dict(checkpoint)
|
||||
439
midas/blocks.py
Normal file
439
midas/blocks.py
Normal file
@@ -0,0 +1,439 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .backbones.beit import (
|
||||
_make_pretrained_beitl16_512,
|
||||
_make_pretrained_beitl16_384,
|
||||
_make_pretrained_beitb16_384,
|
||||
forward_beit,
|
||||
)
|
||||
from .backbones.swin_common import (
|
||||
forward_swin,
|
||||
)
|
||||
from .backbones.swin2 import (
|
||||
_make_pretrained_swin2l24_384,
|
||||
_make_pretrained_swin2b24_384,
|
||||
_make_pretrained_swin2t16_256,
|
||||
)
|
||||
from .backbones.swin import (
|
||||
_make_pretrained_swinl12_384,
|
||||
)
|
||||
from .backbones.levit import (
|
||||
_make_pretrained_levit_384,
|
||||
forward_levit,
|
||||
)
|
||||
from .backbones.vit import (
|
||||
_make_pretrained_vitb_rn50_384,
|
||||
_make_pretrained_vitl16_384,
|
||||
_make_pretrained_vitb16_384,
|
||||
forward_vit,
|
||||
)
|
||||
|
||||
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
                  use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
    """Instantiate the named backbone plus its matching 'scratch' convolutions.

    Args:
        backbone (str): backbone identifier (e.g. "beitl16_512", "vitb16_384").
        features (int): output channel count of the scratch convolutions.
        use_pretrained (bool): load pretrained backbone weights.
        groups (int): conv groups forwarded to the scratch layers.
        expand (bool): let scratch channel counts grow per stage.
        exportable (bool): ONNX-export flag (only used by efficientnet_lite3).
        hooks (list[int] | None): custom hook positions for the backbone.
        use_vit_only (bool): vitb_rn50 only — skip the CNN-stem hooks.
        use_readout (str): readout-token handling for ViT/BEiT backbones.
        in_features (list[int]): per-stage channels for Next-ViT variants.

    Returns:
        tuple: (pretrained backbone wrapper, scratch module).
    """
    if backbone == "beitl16_512":
        pretrained = _make_pretrained_beitl16_512(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # BEiT_512-L (backbone)
    elif backbone == "beitl16_384":
        pretrained = _make_pretrained_beitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # BEiT_384-L (backbone)
    elif backbone == "beitb16_384":
        pretrained = _make_pretrained_beitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # BEiT_384-B (backbone)
    elif backbone == "swin2l24_384":
        pretrained = _make_pretrained_swin2l24_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [192, 384, 768, 1536], features, groups=groups, expand=expand
        )  # Swin2-L/12to24 (backbone)
    elif backbone == "swin2b24_384":
        pretrained = _make_pretrained_swin2b24_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [128, 256, 512, 1024], features, groups=groups, expand=expand
        )  # Swin2-B/12to24 (backbone)
    elif backbone == "swin2t16_256":
        pretrained = _make_pretrained_swin2t16_256(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # Swin2-T/16 (backbone)
    elif backbone == "swinl12_384":
        pretrained = _make_pretrained_swinl12_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [192, 384, 768, 1536], features, groups=groups, expand=expand
        )  # Swin-L/12 (backbone)
    elif backbone == "next_vit_large_6m":
        # Imported lazily: the Next-ViT dependency is optional.
        from .backbones.next_vit import _make_pretrained_next_vit_large_6m
        pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
        scratch = _make_scratch(
            in_features, features, groups=groups, expand=expand
        )  # Next-ViT-L on ImageNet-1K-6M (backbone)
    elif backbone == "levit_384":
        pretrained = _make_pretrained_levit_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [384, 512, 768], features, groups=groups, expand=expand
        )  # LeViT 384 (backbone)
    elif backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 + ResNet-50 hybrid (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # ResNeXt-101 32x8d WSL (backbone)
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        # NOTE(review): `assert False` disappears under `python -O`; a raised
        # ValueError would be safer here.
        assert False

    return pretrained, scratch
|
||||
|
||||
|
||||
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||
scratch = nn.Module()
|
||||
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape
|
||||
|
||||
if expand:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape*2
|
||||
out_shape3 = out_shape*4
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape*8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
if len(in_shape) >= 4:
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    """Download tf_efficientnet_lite3 via torch.hub and wrap it as a backbone.

    Args:
        use_pretrained (bool): load pretrained weights.
        exportable (bool): request the export-friendly model variant.

    Returns:
        nn.Module: wrapper exposing ``layer1``..``layer4`` feature stages.
    """
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)
|
||||
|
||||
|
||||
def _make_efficientnet_backbone(effnet):
|
||||
pretrained = nn.Module()
|
||||
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
||||
)
|
||||
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
||||
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
||||
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_resnet_backbone(resnet):
|
||||
pretrained = nn.Module()
|
||||
pretrained.layer1 = nn.Sequential(
|
||||
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
||||
)
|
||||
|
||||
pretrained.layer2 = resnet.layer2
|
||||
pretrained.layer3 = resnet.layer3
|
||||
pretrained.layer4 = resnet.layer4
|
||||
|
||||
return pretrained
|
||||
|
||||
|
||||
def _make_pretrained_resnext101_wsl(use_pretrained):
    """Download ResNeXt-101 32x8d WSL via torch.hub and wrap it as a backbone.

    Note: *use_pretrained* is accepted for signature consistency but unused —
    the hub call always loads the WSL weights.
    """
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)
|
||||
|
||||
|
||||
|
||||
class Interpolate(nn.Module):
    """Interpolation module.

    Thin nn.Module wrapper around ``nn.functional.interpolate`` so that
    resampling can be placed inside a Sequential head.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super().__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resample *x* by the configured scale factor and mode.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
    """Residual convolution module.

    Two pre-activated 3x3 convolutions with an identity skip connection.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )
        # In-place ReLU, as in the original implementation.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        residual = self.conv1(self.relu(x))
        residual = self.conv2(self.relu(residual))
        return residual + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Merges an upper decoder path with an optional skip connection and
    upsamples the result by a factor of two.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        fused = xs[0]

        # Optional second input: refine the skip connection and add it
        # in place, matching the original `+=` semantics.
        if len(xs) == 2:
            fused += self.resConfUnit1(xs[1])

        fused = self.resConfUnit2(fused)
        return nn.functional.interpolate(
            fused, scale_factor=2, mode="bilinear", align_corners=True
        )
|
||||
|
||||
|
||||
|
||||
|
||||
class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.

    Pre-activation residual block with optional batch norm, using a
    quantization-friendly FloatFunctional for the skip addition.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.groups = 1

        conv_kwargs = dict(
            kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.conv1(self.activation(x))
        if self.bn == True:
            out = self.bn1(out)

        out = self.conv2(self.activation(out))
        if self.bn == True:
            out = self.bn2(out)

        # groups is fixed to 1 above, so this merge branch is currently inert.
        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)
|
||||
|
||||
# return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.

    Decoder stage: optionally refines and adds a skip connection, applies a
    residual unit, resamples to the target resolution and projects channels.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features // 2 if self.expand == True else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        # Optional skip connection from the encoder.
        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        # Target resolution: a call-time size beats the configured one; the
        # default is a plain x2 upsample.
        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )

        return self.out_conv(output)
|
||||
|
||||
166
midas/dpt_depth.py
Normal file
166
midas/dpt_depth.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import (
|
||||
FeatureFusionBlock_custom,
|
||||
Interpolate,
|
||||
_make_encoder,
|
||||
forward_beit,
|
||||
forward_swin,
|
||||
forward_levit,
|
||||
forward_vit,
|
||||
)
|
||||
from .backbones.levit import stem_b4_transpose
|
||||
from timm.models.layers import get_act_layer
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size=None):
    """Construct one decoder fusion stage with the settings DPT uses."""
    activation = nn.ReLU(False)
    return FeatureFusionBlock_custom(
        features,
        activation,
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
|
||||
|
||||
|
||||
class DPT(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
head,
|
||||
features=256,
|
||||
backbone="vitb_rn50_384",
|
||||
readout="project",
|
||||
channels_last=False,
|
||||
use_bn=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super(DPT, self).__init__()
|
||||
|
||||
self.channels_last = channels_last
|
||||
|
||||
# For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
|
||||
# hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
|
||||
hooks = {
|
||||
"beitl16_512": [5, 11, 17, 23],
|
||||
"beitl16_384": [5, 11, 17, 23],
|
||||
"beitb16_384": [2, 5, 8, 11],
|
||||
"swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
|
||||
"swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
|
||||
"swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
|
||||
"swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
|
||||
"next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
|
||||
"levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
|
||||
"vitb_rn50_384": [0, 1, 8, 11],
|
||||
"vitb16_384": [2, 5, 8, 11],
|
||||
"vitl16_384": [5, 11, 17, 23],
|
||||
}[backbone]
|
||||
|
||||
if "next_vit" in backbone:
|
||||
in_features = {
|
||||
"next_vit_large_6m": [96, 256, 512, 1024],
|
||||
}[backbone]
|
||||
else:
|
||||
in_features = None
|
||||
|
||||
# Instantiate backbone and reassemble blocks
|
||||
self.pretrained, self.scratch = _make_encoder(
|
||||
backbone,
|
||||
features,
|
||||
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
||||
groups=1,
|
||||
expand=False,
|
||||
exportable=False,
|
||||
hooks=hooks,
|
||||
use_readout=readout,
|
||||
in_features=in_features,
|
||||
)
|
||||
|
||||
self.number_layers = len(hooks) if hooks is not None else 4
|
||||
size_refinenet3 = None
|
||||
self.scratch.stem_transpose = None
|
||||
|
||||
if "beit" in backbone:
|
||||
self.forward_transformer = forward_beit
|
||||
elif "swin" in backbone:
|
||||
self.forward_transformer = forward_swin
|
||||
elif "next_vit" in backbone:
|
||||
from .backbones.next_vit import forward_next_vit
|
||||
self.forward_transformer = forward_next_vit
|
||||
elif "levit" in backbone:
|
||||
self.forward_transformer = forward_levit
|
||||
size_refinenet3 = 7
|
||||
self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
|
||||
else:
|
||||
self.forward_transformer = forward_vit
|
||||
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
|
||||
if self.number_layers >= 4:
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||
|
||||
self.scratch.output_conv = head
|
||||
|
||||
|
||||
def forward(self, x):
    """Run the transformer backbone and the fusion decoder on an image batch.

    Args:
        x (tensor): input image batch, shape (N, C, H, W).

    Returns:
        tensor: output of ``self.scratch.output_conv`` (the prediction head).
    """
    if self.channels_last:
        # Bug fix: Tensor.contiguous() is NOT in-place; the converted tensor
        # must be rebound, otherwise the channels-last hint was discarded.
        x = x.contiguous(memory_format=torch.channels_last)

    layers = self.forward_transformer(self.pretrained, x)
    if self.number_layers == 3:
        layer_1, layer_2, layer_3 = layers
    else:
        layer_1, layer_2, layer_3, layer_4 = layers

    layer_1_rn = self.scratch.layer1_rn(layer_1)
    layer_2_rn = self.scratch.layer2_rn(layer_2)
    layer_3_rn = self.scratch.layer3_rn(layer_3)
    if self.number_layers >= 4:
        layer_4_rn = self.scratch.layer4_rn(layer_4)

    # Fuse top-down; each refinenet upsamples to the size of the next skip.
    if self.number_layers == 3:
        path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
    else:
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
    path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
    path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

    # Only set for the LeViT backbone (see __init__).
    if self.scratch.stem_transpose is not None:
        path_1 = self.scratch.stem_transpose(path_1)

    out = self.scratch.output_conv(path_1)

    return out
|
||||
|
||||
|
||||
class DPTDepthModel(DPT):
    """DPT variant with a monocular-depth regression head (single channel)."""

    def __init__(self, path=None, non_negative=True, **kwargs):
        """Init.

        Args:
            path (str, optional): checkpoint to load after construction.
            non_negative (bool, optional): clamp predictions to >= 0 via ReLU.
            **kwargs: forwarded to DPT; ``head_features_1``/``head_features_2``
                are consumed here to size the head.
        """
        features = kwargs.get("features", 256)
        # Pop so the head-only options are not forwarded to DPT.
        head_features_1 = kwargs.pop("head_features_1", features)
        head_features_2 = kwargs.pop("head_features_2", 32)

        head = nn.Sequential(
            nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        """Predict depth; squeezes the singleton channel -> (N, H, W)."""
        return super().forward(x).squeeze(dim=1)
|
||||
76
midas/midas_net.py
Normal file
76
midas/midas_net.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
||||
|
||||
|
||||
class MidasNet(BaseModel):
    """Network for monocular depth estimation (ResNeXt-101 WSL encoder)."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): clamp output to >= 0 via ReLU.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        # ImageNet weights are only fetched when a checkpoint will be loaded.
        use_pretrained = path is not None

        self.pretrained, self.scratch = _make_encoder(
            backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
        )

        for idx in (4, 3, 2, 1):
            setattr(self.scratch, "refinenet{}".format(idx), FeatureFusionBlock(features))

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        # Encoder: collect the four skip features.
        feats = x
        skips = []
        for stage in (self.pretrained.layer1, self.pretrained.layer2,
                      self.pretrained.layer3, self.pretrained.layer4):
            feats = stage(feats)
            skips.append(feats)

        rn = [
            self.scratch.layer1_rn(skips[0]),
            self.scratch.layer2_rn(skips[1]),
            self.scratch.layer3_rn(skips[2]),
            self.scratch.layer4_rn(skips[3]),
        ]

        # Decoder: fuse top-down through the refinenets.
        fused = self.scratch.refinenet4(rn[3])
        fused = self.scratch.refinenet3(fused, rn[2])
        fused = self.scratch.refinenet2(fused, rn[1])
        fused = self.scratch.refinenet1(fused, rn[0])

        out = self.scratch.output_conv(fused)

        return torch.squeeze(out, dim=1)
|
||||
128
midas/midas_net_custom.py
Normal file
128
midas/midas_net_custom.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
||||
|
||||
|
||||
class MidasNet_small(BaseModel):
    """Network for monocular depth estimation with a small (EfficientNet) encoder."""

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
                 blocks=None):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 64.
            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
            non_negative (bool, optional): clamp output to >= 0 via ReLU.
            exportable (bool, optional): forwarded to _make_encoder.
            channels_last (bool, optional): use channels-last memory format in forward.
            align_corners (bool, optional): forwarded to the fusion blocks.
            blocks (dict, optional): block options; defaults to {'expand': True}.
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        # Bug fix: avoid a shared mutable default argument; behavior for
        # callers that omit `blocks` is unchanged.
        if blocks is None:
            blocks = {'expand': True}

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1 = features
        features2 = features
        features3 = features
        features4 = features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] == True:
            # With expansion the refinenet widths double per level.
            self.expand = True
            features1 = features
            features2 = features * 2
            features3 = features * 4
            features4 = features * 8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last == True:
            print("self.channels_last = ", self.channels_last)
            # Bug fix: contiguous() is not in-place; rebind so the
            # channels-last memory format actually takes effect.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)
|
||||
|
||||
|
||||
|
||||
def fuse_model(m):
    """Fuse Conv2d -> BatchNorm2d(-> ReLU) runs in *m* in place for quantization.

    Walks the module tree in definition order with a two-module lookbehind and
    fuses every Conv/BN/ReLU triple (or Conv/BN pair) via
    ``torch.quantization.fuse_modules``.

    Args:
        m (nn.Module): model to fuse; modified in place.
    """
    # None sentinels: the first two iterations can never complete a run.
    prev_previous_type = None
    prev_previous_name = ''
    previous_type = None
    previous_name = ''
    for name, module in m.named_modules():
        current_type = type(module)
        # Idiom fix: compare type objects with `is`, not `==`.
        if prev_previous_type is nn.Conv2d and previous_type is nn.BatchNorm2d and current_type is nn.ReLU:
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type is nn.Conv2d and previous_type is nn.BatchNorm2d:
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = current_type
        previous_name = name
|
||||
242
midas/model_loader.py
Normal file
242
midas/model_loader.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from midas.dpt_depth import DPTDepthModel
|
||||
from midas.midas_net import MidasNet
|
||||
from midas.midas_net_custom import MidasNet_small
|
||||
from midas.transforms import Resize, NormalizeImage, PrepareForNet
|
||||
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
# Default checkpoint path per supported model type.  PyTorch weights follow
# the pattern weights/<model_type>.pt; the OpenVINO variant ships as an
# .xml intermediate representation instead.
default_models = {
    name: "weights/{}.{}".format(name, ext)
    for name, ext in (
        ("dpt_beit_large_512", "pt"),
        ("dpt_beit_large_384", "pt"),
        ("dpt_beit_base_384", "pt"),
        ("dpt_swin2_large_384", "pt"),
        ("dpt_swin2_base_384", "pt"),
        ("dpt_swin2_tiny_256", "pt"),
        ("dpt_swin_large_384", "pt"),
        ("dpt_next_vit_large_384", "pt"),
        ("dpt_levit_224", "pt"),
        ("dpt_large_384", "pt"),
        ("dpt_hybrid_384", "pt"),
        ("midas_v21_384", "pt"),
        ("midas_v21_small_256", "pt"),
        ("openvino_midas_v21_small_256", "xml"),
    )
}
|
||||
|
||||
|
||||
def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
    """Load the specified network.

    Args:
        device (device): the torch device used
        model_path (str): path to saved model
        model_type (str): the type of the model to be loaded
        optimize (bool): optimize the model to half-floats on CUDA?
        height (int): inference encoder image height
        square (bool): resize to a square resolution?

    Returns:
        The loaded network, the transform which prepares images as input to the
        network and the dimensions of the network input.

    Raises:
        ValueError: if *model_type* is not a known model type.
        SystemExit: if half-float optimization is requested for OpenVINO.
    """
    if "openvino" in model_type:
        from openvino.runtime import Core

    keep_aspect_ratio = not square

    # DPT family: (timm backbone, square input size, force square input,
    # extra DPTDepthModel kwargs).  All DPT variants share resize mode
    # "minimal" and mean/std 0.5 normalization.
    # Note: dpt_levit_224 maps to the timm name levit_384, where 224 is the
    # input resolution and 384 the first embed_dim entry, see _cfg and
    # model_cfgs of
    # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
    # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
    dpt_configs = {
        "dpt_beit_large_512": ("beitl16_512", 512, False, {}),
        "dpt_beit_large_384": ("beitl16_384", 384, False, {}),
        "dpt_beit_base_384": ("beitb16_384", 384, False, {}),
        "dpt_swin2_large_384": ("swin2l24_384", 384, True, {}),
        "dpt_swin2_base_384": ("swin2b24_384", 384, True, {}),
        "dpt_swin2_tiny_256": ("swin2t16_256", 256, True, {}),
        "dpt_swin_large_384": ("swinl12_384", 384, True, {}),
        "dpt_next_vit_large_384": ("next_vit_large_6m", 384, False, {}),
        "dpt_levit_224": ("levit_384", 224, True, {"head_features_1": 64, "head_features_2": 8}),
        "dpt_large_384": ("vitl16_384", 384, False, {}),
        "dpt_hybrid_384": ("vitb_rn50_384", 384, False, {}),
    }

    if model_type in dpt_configs:
        backbone, size, force_square, extra = dpt_configs[model_type]
        model = DPTDepthModel(path=model_path, backbone=backbone, non_negative=True, **extra)
        net_w, net_h = size, size
        if force_square:
            # Swin/LeViT backbones require a fixed square input.
            keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21_384":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small_256":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "openvino_midas_v21_small_256":
        ie = Core()
        uncompiled_model = ie.read_model(model=model_path)
        model = ie.compile_model(uncompiled_model, "CPU")
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        # Bug fix: was `assert False`, which is stripped under `python -O`.
        raise ValueError(f"model_type '{model_type}' not implemented")

    if not "openvino" in model_type:
        print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
    else:
        print("Model loaded, optimized with OpenVINO")

    if "openvino" in model_type:
        keep_aspect_ratio = False

    if height is not None:
        net_w, net_h = height, height

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=keep_aspect_ratio,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    if not "openvino" in model_type:
        model.eval()

    if optimize and (device == torch.device("cuda")):
        if not "openvino" in model_type:
            model = model.to(memory_format=torch.channels_last)
            model = model.half()
        else:
            print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
            # Equivalent to the former bare exit(); SystemExit needs no import.
            raise SystemExit

    if not "openvino" in model_type:
        model.to(device)

    return model, transform, net_w, net_h
|
||||
234
midas/transforms.py
Normal file
234
midas/transforms.py
Normal file
@@ -0,0 +1,234 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
|
||||
|
||||
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample so both dimensions reach *size*; keeps aspect ratio.

    Args:
        sample (dict): sample with "image", "disparity" and "mask" entries
        size (tuple): minimum (height, width)

    Returns:
        The unchanged sample when it is already large enough, otherwise the
        new (height, width) tuple.  NOTE(review): the two branches return
        different types (dict vs tuple) -- callers must handle both.
    """
    cur_h, cur_w = sample["disparity"].shape[0], sample["disparity"].shape[1]

    if cur_h >= size[0] and cur_w >= size[1]:
        return sample

    # Single uniform factor that brings both dimensions up to the minimum.
    factor = max(size[0] / cur_h, size[1] / cur_w)
    new_h = math.ceil(factor * cur_h)
    new_w = math.ceil(factor * cur_w)

    # cv2.resize expects (width, height).
    sample["image"] = cv2.resize(
        sample["image"], (new_w, new_h), interpolation=image_interpolation_method
    )
    sample["disparity"] = cv2.resize(
        sample["disparity"], (new_w, new_h), interpolation=cv2.INTER_NEAREST
    )
    mask = cv2.resize(
        sample["mask"].astype(np.float32),
        (new_w, new_h),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = mask.astype(bool)

    return (new_h, new_w)
|
||||
|
||||
|
||||
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample; the output may
                then deviate from (width, height) depending on 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height are constrained to be multiples of this
                parameter. Defaults to 1.
            resize_method (str, optional):
                "lower_bound": output at least as large as the given size.
                "upper_bound": output at most as large as the given size.
                "minimal": scale as little as possible.
                Defaults to "lower_bound".
        """
        self._width = width
        self._height = height

        self._resize_target = resize_target
        self._keep_aspect_ratio = keep_aspect_ratio
        self._multiple_of = ensure_multiple_of
        self._resize_method = resize_method
        self._interpolation = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round *x* to a multiple of the configured step within [min_val, max_val]."""
        step = self._multiple_of
        y = (np.round(x / step) * step).astype(int)

        # Nearest multiple overshoots the bound -> round down instead.
        if max_val is not None and y > max_val:
            y = (np.floor(x / step) * step).astype(int)

        # Still below the lower bound -> round up.
        if y < min_val:
            y = (np.ceil(x / step) * step).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output (width, height) for an input of the given size."""
        scale_width = self._width / width
        scale_height = self._height / height

        if self._keep_aspect_ratio:
            if self._resize_method == "lower_bound":
                # Larger factor so both dimensions reach at least the target.
                scale_width = scale_height = max(scale_width, scale_height)
            elif self._resize_method == "upper_bound":
                # Smaller factor so neither dimension exceeds the target.
                scale_width = scale_height = min(scale_width, scale_height)
            elif self._resize_method == "minimal":
                # Whichever factor changes the image the least.
                if abs(1 - scale_width) < abs(1 - scale_height):
                    scale_height = scale_width
                else:
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self._resize_method} not implemented"
                )

        if self._resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self._height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self._width
            )
        elif self._resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self._height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self._width
            )
        elif self._resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self._resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        new_w, new_h = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"], (new_w, new_h), interpolation=self._interpolation
        )

        if self._resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"], (new_w, new_h),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (new_w, new_h), interpolation=cv2.INTER_NEAREST
                )

            mask = cv2.resize(
                sample["mask"].astype(np.float32),
                (new_w, new_h),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = mask.astype(bool)

        return sample
|
||||
|
||||
|
||||
class NormalizeImage(object):
    """Normalize the "image" entry of a sample by fixed mean and std."""

    def __init__(self, mean, std):
        # Per-channel statistics applied in __call__.
        self._mean = mean
        self._std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self._mean) / self._std
        return sample
|
||||
|
||||
|
||||
class PrepareForNet(object):
    """Prepare a sample for network input: CHW float32 contiguous arrays."""

    def __init__(self):
        pass

    def __call__(self, sample):
        # HWC -> CHW for the image, contiguous float32 as the network expects.
        chw = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)

        # Optional targets get the same dtype/contiguity treatment.
        for key in ("mask", "disparity", "depth"):
            if key in sample:
                sample[key] = np.ascontiguousarray(sample[key].astype(np.float32))

        return sample
|
||||
127
nerf/clip.py
Normal file
127
nerf/clip.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
# import clip
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
|
||||
from torchvision import transforms
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def spherical_dist_loss(x, y):
    """Squared great-circle distance between the directions of x and y.

    Both inputs are L2-normalized along the last dimension first, so only the
    angle between them matters; the chord length is converted to an arc length
    on the unit sphere, then squared.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = (x_unit - y_unit).norm(dim=-1)
    return chord.div(2).arcsin().pow(2).mul(2)
|
||||
|
||||
|
||||
class CLIP(nn.Module):
    """Text/image embedding losses backed by a HuggingFace CLIP model."""

    def __init__(self, device,
                 # clip_name = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
                 clip_name = 'openai/clip-vit-large-patch14'
                 ):
        """Init.

        Args:
            device: torch device (stored; the model itself is moved via .cuda()).
            clip_name (str): HuggingFace hub id of the CLIP checkpoint.
        """
        super().__init__()

        self.device = device

        self.clip_model = CLIPModel.from_pretrained(clip_name).cuda()
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        # Bug fix: these preprocessing helpers were commented out although
        # get_img_embeds/train_step use self.aug and text_loss/img_loss use
        # self.resize/self.normalize, so those methods raised AttributeError.
        # The statistics below are the standard OpenAI CLIP mean/std.
        self.normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        self.resize = transforms.Resize(224)

        # image augmentation used before feeding renders to CLIP
        self.aug = T.Compose([
            T.Resize((224, 224)),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ])

    def get_text_embeds(self, prompt, neg_prompt=None, dir=None):
        """Return L2-normalized CLIP text features for *prompt*.

        `neg_prompt` and `dir` are accepted for interface compatibility but
        are unused here.
        """
        clip_text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.cuda()
        text_z = self.clip_model.get_text_features(clip_text_input)
        text_z = text_z / text_z.norm(dim=-1, keepdim=True)

        return text_z

    def set_epoch(self, epoch):
        # No epoch-dependent behavior; kept for trainer interface compatibility.
        pass

    def get_img_embeds(self, img):
        """Return L2-normalized CLIP image features for an image batch."""
        img = self.aug(img)
        image_z = self.clip_model.get_image_features(img)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features
        return image_z

    def train_step(self, text_z, pred_rgb):
        """Negative cosine-similarity loss between a rendered image and text."""
        pred_rgb = self.aug(pred_rgb)

        # Bug fix: transformers' CLIPModel has no encode_image(); the HF API
        # is get_image_features() (as already used in get_img_embeds above).
        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = - (image_z * text_z).sum(-1).mean()
        return loss

    def text_loss(self, text_z, pred_rgb):
        """Spherical-distance loss between image features and text features."""
        pred_rgb = self.resize(pred_rgb)
        pred_rgb = self.normalize(pred_rgb)

        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = spherical_dist_loss(image_z, text_z)

        return loss

    def img_loss(self, img_ref_z, pred_rgb):
        """Spherical-distance loss between image features and a reference embedding."""
        pred_rgb = self.resize(pred_rgb)
        pred_rgb = self.normalize(pred_rgb)

        image_z = self.clip_model.get_image_features(pred_rgb)
        image_z = image_z / image_z.norm(dim=-1, keepdim=True)  # normalize features

        loss = spherical_dist_loss(image_z, img_ref_z)

        return loss
|
||||
485
nerf/gui.py
Normal file
485
nerf/gui.py
Normal file
@@ -0,0 +1,485 @@
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
import dearpygui.dearpygui as dpg
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from nerf.utils import *
|
||||
|
||||
|
||||
class OrbitCamera:
    """Orbit camera: a rotation around a look-at center at a fixed radius."""

    def __init__(self, W, H, r=2, fovy=60):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from center
        self.fovy = fovy  # vertical field of view, in degrees
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look-at point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # world up, unit length
        self.near = 0.001
        self.far = 1000

    @property
    def pose(self):
        """Camera-to-world matrix: push out to radius, rotate, recenter."""
        translate = np.eye(4, dtype=np.float32)
        translate[2, 3] = self.radius
        rotate = np.eye(4, dtype=np.float32)
        rotate[:3, :3] = self.rot.as_matrix()
        cam2world = rotate @ translate
        cam2world[:3, 3] -= self.center
        return cam2world

    @property
    def intrinsics(self):
        """[fx, fy, cx, cy] for a pinhole camera with the configured fovy."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2])

    @property
    def mvp(self):
        """Projection @ world-to-camera matrix, shape [4, 4]."""
        focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
        projection = np.array([
            [2 * focal / self.W, 0, 0, 0],
            [0, -2 * focal / self.H, 0, 0],
            [0, 0, -(self.far + self.near) / (self.far - self.near),
             -(2 * self.far * self.near) / (self.far - self.near)],
            [0, 0, -1, 0],
        ], dtype=np.float32)

        return projection @ np.linalg.inv(self.pose)  # [4, 4]

    def orbit(self, dx, dy):
        """Rotate around the world-up and camera-side axes by mouse deltas."""
        side = self.rot.as_matrix()[:3, 0]  # camera x axis, already unit length
        yaw = R.from_rotvec(self.up * np.deg2rad(-0.1 * dx))
        pitch = R.from_rotvec(side * np.deg2rad(-0.1 * dy))
        self.rot = yaw * pitch * self.rot

    def scale(self, delta):
        """Dolly in/out: shrink radius for positive delta, grow for negative."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Shift the look-at center in camera coordinates (low sensitivity)."""
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz])
|
||||
|
||||
|
||||
class NeRFGUI:
    """Interactive dearpygui front-end around a NeRF Trainer.

    Renders the scene into a raw texture every frame, with orbit/pan/zoom
    camera control, optional live training, dynamic render resolution, and
    progressive sample accumulation (spp) while the camera is still.
    """

    def __init__(self, opt, trainer, loader=None, debug=True):
        # opt: shared with the trainer's opt to support in-place modification
        # of rendering parameters from the GUI sliders.
        self.opt = opt
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.debug = debug
        self.bg_color = torch.ones(3, dtype=torch.float32)  # default white bg
        self.training = False
        self.step = 0  # training step

        self.trainer = trainer
        self.loader = loader
        # NOTE(review): buffer is allocated (W, H, 3) while images are usually
        # (H, W, 3) — confirm trainer.test_gui's output layout matches.
        self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # camera moved, should reset accumulation
        self.spp = 1  # sample per pixel (accumulated frames while static)
        self.light_dir = np.array([opt.light_theta, opt.light_phi])
        self.ambient_ratio = 1.0
        self.mode = 'image'  # choose from ['image', 'depth']
        self.shading = 'albedo'

        # dynamic resolution is disabled for DMTet rendering
        self.dynamic_resolution = True if not self.opt.dmtet else False
        self.downscale = 1
        self.train_steps = 16  # training iterations folded into one GUI frame

        dpg.create_context()
        self.register_dpg()
        self.test_step()

    def __del__(self):
        # tear down the dearpygui context created in __init__
        dpg.destroy_context()

    def train_step(self):
        """Run `train_steps` training iterations and update the HUD timings."""
        # CUDA events time the GPU work of this chunk of training
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()

        outputs = self.trainer.train_gui(self.loader, step=self.train_steps)

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.step += self.train_steps
        self.need_update = True  # model changed, so the render is stale

        dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
        dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

        # dynamic train steps
        # max allowed train time per-frame is 500 ms
        full_t = t / self.train_steps * 16
        train_steps = min(16, max(4, int(16 * 500 / full_t)))
        # only adopt the new step count if it differs by more than 20%
        if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
            self.train_steps = train_steps

    def prepare_buffer(self, outputs):
        """Convert trainer outputs to a float32 RGB array for the texture."""
        if self.mode == 'image':
            return outputs['image'].astype(np.float32)
        else:
            # depth view: min-max normalize and replicate to 3 channels
            depth = outputs['depth'].astype(np.float32)
            depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
            return np.expand_dims(depth, -1).repeat(3, -1)

    def test_step(self):
        """Render one frame if the camera moved or spp accumulation continues."""
        if self.need_update or self.spp < self.opt.max_spp:

            starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
            starter.record()

            outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)

            ender.record()
            torch.cuda.synchronize()
            t = starter.elapsed_time(ender)

            # update dynamic resolution
            if self.dynamic_resolution:
                # max allowed infer time per-frame is 200 ms
                full_t = t / (self.downscale ** 2)
                downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
                # only adopt the new scale if it differs by more than 20%
                if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
                    self.downscale = downscale

            if self.need_update:
                # camera moved: restart accumulation from this frame
                self.render_buffer = self.prepare_buffer(outputs)
                self.spp = 1
                self.need_update = False
            else:
                # camera still: running average of frames (progressive spp)
                self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
                self.spp += 1

            dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
            dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
            dpg.set_value("_log_spp", self.spp)
            dpg.set_value("_texture", self.render_buffer)

    def register_dpg(self):
        """Build the dearpygui UI: texture, windows, widgets, input handlers."""

        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")

        ### register window

        # the rendered image, as the primary window
        with dpg.window(tag="_primary_window", width=self.W, height=self.H):

            # add the texture
            dpg.add_image("_texture")

        dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(label="Control", tag="_control_window", width=400, height=300):

            # text prompt
            if self.opt.text is not None:
                dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")

            if self.opt.negative != '':
                dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")

            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # time
            if not self.opt.test:
                with dpg.group(horizontal=True):
                    dpg.add_text("Train time: ")
                    dpg.add_text("no data", tag="_log_train_time")

            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            with dpg.group(horizontal=True):
                dpg.add_text("SPP: ")
                dpg.add_text("1", tag="_log_spp")

            # train button
            if not self.opt.test:
                with dpg.collapsing_header(label="Train", default_open=True):
                    with dpg.group(horizontal=True):
                        dpg.add_text("Train: ")

                        def callback_train(sender, app_data):
                            # toggle live training; the button label reflects state
                            if self.training:
                                self.training = False
                                dpg.configure_item("_button_train", label="start")
                            else:
                                self.training = True
                                dpg.configure_item("_button_train", label="stop")

                        dpg.add_button(label="start", tag="_button_train", callback=callback_train)
                        dpg.bind_item_theme("_button_train", theme_button)

                        def callback_reset(sender, app_data):
                            # re-initialize every module that exposes reset_parameters
                            @torch.no_grad()
                            def weight_reset(m: nn.Module):
                                reset_parameters = getattr(m, "reset_parameters", None)
                                if callable(reset_parameters):
                                    m.reset_parameters()
                            self.trainer.model.apply(fn=weight_reset)
                            self.trainer.model.reset_extra_state()  # for cuda_ray density_grid and step_counter
                            self.need_update = True

                        dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
                        dpg.bind_item_theme("_button_reset", theme_button)

                    with dpg.group(horizontal=True):
                        dpg.add_text("Checkpoint: ")

                        def callback_save(sender, app_data):
                            self.trainer.save_checkpoint(full=True, best=False)
                            dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="save", tag="_button_save", callback=callback_save)
                        dpg.bind_item_theme("_button_save", theme_button)

                        dpg.add_text("", tag="_log_ckpt")

                    # save mesh
                    with dpg.group(horizontal=True):
                        dpg.add_text("Marching Cubes: ")

                        def callback_mesh(sender, app_data):
                            self.trainer.save_mesh()
                            dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
                            self.trainer.epoch += 1  # use epoch to indicate different calls.

                        dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
                        dpg.bind_item_theme("_button_mesh", theme_button)

                        dpg.add_text("", tag="_log_mesh")

                    with dpg.group(horizontal=True):
                        dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Options", default_open=True):

                # dynamic rendering resolution
                with dpg.group(horizontal=True):

                    def callback_set_dynamic_resolution(sender, app_data):
                        if self.dynamic_resolution:
                            self.dynamic_resolution = False
                            self.downscale = 1
                        else:
                            self.dynamic_resolution = True
                        self.need_update = True

                    dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
                    dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

                # bg_color picker
                def callback_change_bg(sender, app_data):
                    self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32)  # only need RGB in [0, 1]
                    self.need_update = True

                dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = app_data
                    self.need_update = True

                dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)

                # dt_gamma slider
                def callback_set_dt_gamma(sender, app_data):
                    self.opt.dt_gamma = app_data
                    self.need_update = True

                dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

                # max_steps slider
                def callback_set_max_steps(sender, app_data):
                    self.opt.max_steps = app_data
                    self.need_update = True

                dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

                # aabb slider
                def callback_set_aabb(sender, app_data, user_data):
                    # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
                    self.trainer.model.aabb_infer[user_data] = app_data

                    # also change train aabb ? [better not...]
                    #self.trainer.model.aabb_train[user_data] = app_data

                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Axis-aligned bounding box:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
                    dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)

                # light dir
                def callback_set_light_dir(sender, app_data, user_data):
                    # user_data selects theta (0) or phi (1)
                    self.light_dir[user_data] = app_data
                    self.need_update = True

                dpg.add_separator()
                dpg.add_text("Plane Light Direction:")

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)

                with dpg.group(horizontal=True):
                    dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)

                # ambient ratio
                def callback_set_abm_ratio(sender, app_data):
                    self.ambient_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)

                # shading mode
                def callback_change_shading(sender, app_data):
                    self.shading = app_data
                    self.need_update = True

                dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)

            # debug info
            if self.debug:
                with dpg.collapsing_header(label="Debug"):
                    # pose
                    dpg.add_separator()
                    dpg.add_text("Camera Pose:")
                    dpg.add_text(str(self.cam.pose), tag="_log_pose")

        ### register camera handler

        def callback_camera_drag_rotate(sender, app_data):
            # only react when the render window has focus
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

            if self.debug:
                dpg.set_value("_log_pose", str(self.cam.pose))

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)

        dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)

        # TODO: seems dearpygui doesn't support resizing texture...
        # def callback_resize(sender, app_data):
        #     self.W = app_data[0]
        #     self.H = app_data[1]
        #     # how to reload texture ???

        # dpg.set_viewport_resize_callback(callback_resize)

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
                dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        #dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        """Main loop: train (if enabled), render, and present each frame."""
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
|
||||
238
nerf/network.py
Normal file
238
nerf/network.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from activation import trunc_exp
|
||||
from .renderer import NeRFRenderer
|
||||
|
||||
import numpy as np
|
||||
from encoding import get_encoder
|
||||
|
||||
from .utils import safe_normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block: Linear -> LayerNorm -> (+shortcut) -> SiLU.

    When the input and output widths differ, the shortcut is a bias-free
    Linear projection; otherwise the input passes through unchanged.
    """

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        # keep submodule creation order (dense, norm, activation, skip) so
        # parameter initialization draws from the RNG in the same sequence
        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.norm = nn.LayerNorm(self.dim_out)
        self.activation = nn.SiLU(inplace=True)

        self.skip = (
            nn.Linear(self.dim_in, self.dim_out, bias=False)
            if self.dim_in != self.dim_out
            else None
        )

    def forward(self, x):
        # x: [B, C]
        shortcut = x if self.skip is None else self.skip(x)
        h = self.norm(self.dense(x))
        h = h + shortcut
        return self.activation(h)
|
||||
|
||||
class BasicBlock(nn.Module):
    """Plain feed-forward block: Linear -> ReLU."""

    def __init__(self, dim_in, dim_out, bias=True):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

        self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: [B, C]
        return self.activation(self.dense(x))
|
||||
|
||||
class MLP(nn.Module):
    """Simple MLP built from pluggable blocks.

    Layer layout: the first layer is always a BasicBlock, intermediate
    layers use `block` (e.g. ResBlock), and the last layer is a plain
    Linear (no activation).

    Args:
        dim_in: input feature width.
        dim_out: output feature width.
        dim_hidden: hidden layer width.
        num_layers: total number of layers.
        bias: whether Linear layers carry a bias term.
        block: module class used for the intermediate layers.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        net = []
        for l in range(num_layers):
            if l == 0:
                net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias))
            elif l != num_layers - 1:
                net.append(block(self.dim_hidden, self.dim_hidden, bias=bias))
            else:
                net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias))

        self.net = nn.ModuleList(net)

    def reset_parameters(self):
        """Re-initialize all Linear layers: Xavier-uniform weights, zero biases."""
        @torch.no_grad()
        def weight_init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
                # FIX: with bias=False, m.bias is None and nn.init.zeros_ would
                # raise — only reset the bias when it exists.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(weight_init)

    def forward(self, x):
        # x: [B, dim_in] -> [B, dim_out]; layers applied in order
        for l in range(self.num_layers):
            x = self.net[l](x)

        return x
|
||||
|
||||
|
||||
class NeRFNetwork(NeRFRenderer):
    """Pure-PyTorch NeRF: frequency encoding + residual MLP.

    Maps a 3D position to density (sigma) and albedo; shading modes other
    than 'albedo' derive a normal from the density gradient. An optional
    small background network predicts a view-direction-dependent bg color.
    """

    def __init__(self,
                 opt,
                 num_layers=5,  # 5 in paper
                 hidden_dim=64,  # 128 in paper
                 num_layers_bg=2,  # 3 in paper
                 hidden_dim_bg=32,  # 64 in paper
                 encoding='frequency_torch',  # pure pytorch
                 output_dim=4,  # 7 for DMTet (sdf 1 + color 3 + deform 3), 4 for NeRF
                 ):

        super().__init__(opt)

        # fields on opt override the constructor defaults when present
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg

        self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=6)
        # FIX: use the opt-resolved sizes. Previously the local defaults
        # `hidden_dim`/`num_layers` were passed here, silently ignoring
        # opt.hidden_dim / opt.num_layers set two lines above.
        self.sigma_net = MLP(self.in_dim, output_dim, self.hidden_dim, self.num_layers, bias=True, block=ResBlock)

        # number of top encoder levels zeroed during coarse-to-fine training
        self.grid_levels_mask = 0

        # background network (only when a finite background sphere is used)
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode x and run the sigma net: returns (sigma [N], albedo [N, C-1])."""
        # x: [N, 3], in [-bound, bound]
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Feature masking for coarse-to-fine training.
        # NOTE(review): this branch reads self.level_dim, which this class never
        # sets — it is only safe while grid_levels_mask stays 0 here; confirm.
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # channel 0 -> density (with a centered density blob added), rest -> albedo
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def normal(self, x):
        """Normal at x as the negative, normalized density gradient. [N, 3]"""
        with torch.enable_grad():
            x.requires_grad_(True)
            sigma, albedo = self.common_forward(x)
            # query gradient of the density w.r.t. position
            normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]

        normal = safe_normalize(normal)

        return normal

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        if shading == 'albedo':
            # no need to query normal
            sigma, color = self.common_forward(x)
            normal = None

        else:
            # the normal needs the density gradient, so grad is enabled inline
            # (avoids a second forward pass through common_forward)
            with torch.enable_grad():
                x.requires_grad_(True)
                sigma, albedo = self.common_forward(x)
                normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0]  # [N, 3]
            normal = safe_normalize(normal)

            # lambertian shading
            lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]

            if shading == 'textureless':
                color = lambertian.unsqueeze(-1).repeat(1, 3)
            elif shading == 'normal':
                color = (normal + 1) / 2
            else:  # 'lambertian'
                color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Density/albedo query without shading (used for sampling and export)."""
        # x: [N, 3], in [-bound, bound]
        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d: [N, 3] -> RGB in [0, 1]."""
        h = self.encoder_bg(d)  # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups with per-module learning rates for the optimizer."""
        params = [
            # {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
        ]

        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})

        return params
|
||||
216
nerf/network_grid.py
Normal file
216
nerf/network_grid.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from activation import trunc_exp, biased_softplus
|
||||
from .renderer import NeRFRenderer, MLP
|
||||
|
||||
import numpy as np
|
||||
from encoding import get_encoder
|
||||
|
||||
from .utils import safe_normalize
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeRFNetwork(NeRFRenderer):
    """Grid-encoded NeRF: hash/tiled grid encoder + small MLP head.

    Maps a 3D position to density (sigma) and albedo; also provides
    pretraining helpers that fit the network (or the DMTet SDF) to a
    given density/color target.
    """

    def __init__(self,
                 opt,
                 num_layers=3,
                 hidden_dim=64,
                 num_layers_bg=2,
                 hidden_dim_bg=32,
                 level_dim=2
                 ):

        super().__init__(opt)

        # fields on opt override the constructor defaults when present
        self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
        self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
        self.level_dim = opt.level_dim if hasattr(opt, 'level_dim') else level_dim
        num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
        hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg

        if self.opt.grid_type == 'hashgrid':
            self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
        elif self.opt.grid_type == 'tilegrid':
            self.encoder, self.in_dim = get_encoder(
                'tiledgrid',
                input_dim=3,
                level_dim=self.level_dim,
                log2_hashmap_size=16,
                num_levels=16,
                desired_resolution=2048 * self.bound,
            )
        else:
            # FIX: fail fast on unknown grid types; previously self.encoder was
            # simply left undefined and crashed later with an AttributeError.
            raise ValueError(f'unknown grid_type: {self.opt.grid_type}')

        # FIX: use the opt-resolved sizes. Previously the local defaults
        # `hidden_dim`/`num_layers` were passed here, silently ignoring
        # opt.hidden_dim / opt.num_layers resolved above.
        self.sigma_net = MLP(self.in_dim, 4, self.hidden_dim, self.num_layers, bias=True)
        # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)

        # number of top encoder levels zeroed during coarse-to-fine training
        self.grid_levels_mask = 0

        # background network (only when a finite background sphere is used)
        if self.opt.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg

            # use a very simple network to avoid it learning the prompt...
            self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
            self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)

        else:
            self.bg_net = None

    def common_forward(self, x):
        """Encode x and run the sigma net: returns (sigma [N], albedo [N, 3])."""
        # sigma
        h = self.encoder(x, bound=self.bound, max_level=self.max_level)

        # Feature masking for coarse-to-fine training: zero out the features of
        # the `grid_levels_mask` finest grid levels.
        if self.grid_levels_mask > 0:
            h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim  # (self.in_dim)
            h_mask = h_mask.reshape(1, self.in_dim).float()  # (1, self.in_dim)
            h = h * h_mask  # (N, self.in_dim)

        h = self.sigma_net(h)

        # channel 0 -> density (with a centered density blob added), rest -> albedo
        sigma = self.density_activation(h[..., 0] + self.density_blob(x))
        albedo = torch.sigmoid(h[..., 1:])

        return sigma, albedo

    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], view direction, normalized in [-1, 1]
        # l: [3], plane light direction, normalized in [-1, 1]
        # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)

        sigma, albedo = self.common_forward(x)

        if shading == 'albedo':
            normal = None
            color = albedo

        else:  # lambertian shading
            # normal comes from the renderer base class — confirm it matches
            # this network's density field.
            normal = self.normal(x)
            if shading == 'normal':
                color = (normal + 1) / 2
            else:
                lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0)  # [N,]
                if shading == 'textureless':
                    color = lambertian.unsqueeze(-1).repeat(1, 3)
                else:  # 'lambertian'
                    color = albedo * lambertian.unsqueeze(-1)

        return sigma, color, normal

    def density(self, x):
        """Density/albedo query without shading (used for sampling and export)."""
        # x: [N, 3], in [-bound, bound]
        sigma, albedo = self.common_forward(x)

        return {
            'sigma': sigma,
            'albedo': albedo,
        }

    def background(self, d):
        """Background color for view directions d: [N, 3] -> RGB in [0, 1]."""
        h = self.encoder_bg(d)  # [N, C]

        h = self.bg_net(h)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Parameter groups with per-module learning rates for the optimizer."""
        params = [
            {'params': self.encoder.parameters(), 'lr': lr * 10},
            {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
            # {'params': self.normal_net.parameters(), 'lr': lr},
        ]

        if self.opt.bg_radius > 0:
            # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        if self.opt.dmtet:
            params.append({'params': self.dmtet.parameters(), 'lr': lr})

        return params

    def reset_sigmanet(self):
        """Re-initialize the sigma net weights (used before pretraining fits)."""
        self.sigma_net.reset_parameters()

    def init_nerf_from_sdf_color(self, rpst, albedo,
                                 points=None, pretrain_iters=10000, lr=0.001, rpst_type='sdf',
                                 ):
        """Pretrain the network so its density (and optionally albedo) matches a
        target representation sampled at `points`.

        Args:
            rpst: target values in [0, 1] (squeezed and clamped here).
            albedo: optional target colors at `points`, or None.
            points: [N, 3] query positions — assumed always provided; TODO confirm callers.
            pretrain_iters: number of optimization steps.
            lr: Adam learning rate, decayed 10x at 40% and 80% of the run.
            rpst_type: 'sdf' compares predictions shifted by density_thresh.
        """
        self.reset_sigmanet()
        # matching optimization
        self.train()
        self.grid_levels_mask = 0
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)

        milestones = [int(0.4 * pretrain_iters), int(0.8 * pretrain_iters)]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

        rpst = rpst.squeeze().clamp(0, 1)

        pbar = tqdm(range(pretrain_iters), desc="NeRF sigma optimization")
        rgb_loss = torch.tensor(0, device=rpst.device)
        for i in pbar:
            output = self.density(points)
            if rpst_type == 'sdf':
                pred_rpst = output['sigma'] - self.density_thresh
            else:
                pred_rpst = output['sigma']
            sdf_loss = loss_fn(pred_rpst, rpst)

            if albedo is not None:
                pred_albedo = output['albedo']
                rgb_loss = loss_fn(pred_albedo, albedo)
                # density fit is weighted 10x over the color fit
                loss = 10 * sdf_loss + rgb_loss
            else:
                loss = sdf_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.set_postfix(loss=loss.item(), rgb_loss=rgb_loss.item(), sdf_loss=sdf_loss.item())

        # FIX: the accuracy expression used to live inside the string literal
        # and was never evaluated; interpolate it properly.
        accuracy = (pred_rpst[rpst > 0] > 0).sum() / (rpst > 0).sum()
        logger.info(f'lr: {lr} Accuracy: {accuracy}')
        # FIX: removed a stray trailing `pbar = tqdm(...)` that created an
        # unused "NeRF color optimization" progress bar and was never iterated.

    def init_tet_from_sdf_color(self, sdf, colors=None, pretrain_iters=5000, lr=0.01):
        """Initialize the DMTet SDF from `sdf`, then optionally fit albedo to `colors`."""
        self.train()
        self.grid_levels_mask = 0

        self.dmtet.reset_tet(reset_scale=False)
        self.dmtet.init_tet_from_sdf(sdf, pretrain_iters=pretrain_iters, lr=lr)

        if colors is not None:
            self.reset_sigmanet()
            loss_fn = torch.nn.MSELoss()
            # FIX: honor the pretrain_iters/lr arguments; previously this loop
            # hard-coded 5000 iterations and lr=0.01 (identical for default calls).
            optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
            pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")
            for i in pbar:
                pred_albedo = self.density(self.dmtet.verts)['albedo']
                loss = loss_fn(pred_albedo, colors)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=loss.item())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user