diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..aceb1df
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,29 @@
+__pycache__/
+venv*
+workspace/
+out*
+saved_out*
+all_outputs*
+shap_e_model_cache/*
+slurm_logs/
+debug/
+notinclude/
+scripts/snap/yamls
+
+# */validataion
+*csv
+build/
+*.egg-info/
+*.so
+.vscode/
+
+tmp*
+data/
+trial*/
+.vs/
+
+TOKEN
+*.ckpt
+*.pt
+densegridencoder
+tets/256_tets.npz
diff --git a/LICENSE b/LICENSE
index 171269a..261eeb9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,201 @@
-MIT License
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
-Copyright (c) 2023 Gordon Guocheng Qian 钱国成
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+ 1. Definitions.
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/activation.py b/activation.py
new file mode 100644
index 0000000..e6cba6a
--- /dev/null
+++ b/activation.py
@@ -0,0 +1,21 @@
+import torch
+from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+class _trunc_exp(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float)
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return torch.exp(x)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, g):
+ x = ctx.saved_tensors[0]
+ return g * torch.exp(x.clamp(max=15))
+
+trunc_exp = _trunc_exp.apply
+
+def biased_softplus(x, bias=0):
+ return torch.nn.functional.softplus(x - bias)
\ No newline at end of file
diff --git a/all_metrics/metric_utils.py b/all_metrics/metric_utils.py
new file mode 100755
index 0000000..a5f134a
--- /dev/null
+++ b/all_metrics/metric_utils.py
@@ -0,0 +1,459 @@
+# * evaluate use laion/CLIP-ViT-H-14-laion2B-s32B-b79K
+# best open source clip so far: laion/CLIP-ViT-bigG-14-laion2B-39B-b160k
+# code adapted from NeuralLift-360
+
+import torch
+import torch.nn as nn
+import os
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+import matplotlib.pyplot as plt
+# import clip
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
+from torchvision import transforms
+import numpy as np
+import torch.nn.functional as F
+from tqdm import tqdm
+import cv2
+from PIL import Image
+# import torchvision.transforms as transforms
+import glob
+from skimage.metrics import peak_signal_noise_ratio as compare_psnr
+import lpips
+from os.path import join as osp
+import argparse
+import pandas as pd
+
+class CLIP(nn.Module):
+
+ def __init__(self,
+ device,
+ clip_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
+ size=224): #'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'):
+ super().__init__()
+ self.size = size
+ self.device = f"cuda:{device}"
+
+ clip_name = clip_name
+
+ self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
+ clip_name)
+ self.clip_model = CLIPModel.from_pretrained(clip_name).to(self.device)
+ self.tokenizer = CLIPTokenizer.from_pretrained(
+ 'openai/clip-vit-base-patch32')
+
+ self.normalize = transforms.Normalize(
+ mean=self.feature_extractor.image_mean,
+ std=self.feature_extractor.image_std)
+
+ self.resize = transforms.Resize(224)
+ self.to_tensor = transforms.ToTensor()
+
+ # image augmentation
+ self.aug = T.Compose([
+ T.Resize((224, 224)),
+ T.Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ # * recommend to use this function for evaluation
+ @torch.no_grad()
+ def score_gt(self, ref_img_path, novel_views):
+ # assert len(novel_views) == 100
+ clip_scores = []
+ for novel in novel_views:
+ clip_scores.append(self.score_from_path(ref_img_path, [novel]))
+ return np.mean(clip_scores)
+
+ # * recommend to use this function for evaluation
+ # def score_gt(self, ref_paths, novel_paths):
+ # clip_scores = []
+ # for img1_path, img2_path in zip(ref_paths, novel_paths):
+ # clip_scores.append(self.score_from_path(img1_path, img2_path))
+
+ # return np.mean(clip_scores)
+
+ def similarity(self, image1_features: torch.Tensor,
+ image2_features: torch.Tensor) -> float:
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ y = image1_features.T.view(image1_features.T.shape[1],
+ image1_features.T.shape[0])
+ similarity = torch.matmul(y, image2_features.T)
+ # print(similarity)
+ return similarity[0][0].item()
+
+ def get_img_embeds(self, img):
+ if img.shape[0] == 4:
+ img = img[:3, :, :]
+
+ img = self.aug(img).to(self.device)
+ img = img.unsqueeze(0) # b,c,h,w
+
+ # plt.imshow(img.cpu().squeeze(0).permute(1, 2, 0).numpy())
+ # plt.show()
+ # print(img)
+
+ image_z = self.clip_model.get_image_features(img)
+ image_z = image_z / image_z.norm(dim=-1,
+ keepdim=True) # normalize features
+ return image_z
+
+ def score_from_feature(self, img1, img2):
+ img1_feature, img2_feature = self.get_img_embeds(
+ img1), self.get_img_embeds(img2)
+ # for debug
+ return self.similarity(img1_feature, img2_feature)
+
+ def read_img_list(self, img_list):
+ size = self.size
+ images = []
+ # white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
+
+ for img_path in img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ # print(img_path)
+ if img.shape[2] == 4: # Handle BGRA images
+ alpha = img[:, :, 3] # Extract alpha channel
+ img = cv2.cvtColor(img,cv2.COLOR_BGRA2RGB) # Convert BGRA to BGR
+ img[np.where(alpha == 0)] = [
+ 255, 255, 255
+ ] # Set transparent pixels to white
+ else: # Handle other image formats like JPG and PNG
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
+
+ # plt.imshow(img)
+ # plt.show()
+
+ images.append(img)
+
+ images = np.stack(images, axis=0)
+ # images[np.where(images == 0)] = 255 # Set black pixels to white
+ # images = np.where(images == 0, white_background, images) # Set transparent pixels to white
+ # images = images.astype(np.float32)
+
+ return images
+
+ def score_from_path(self, img1_path, img2_path):
+ img1, img2 = self.read_img_list(img1_path), self.read_img_list(img2_path)
+ img1 = np.squeeze(img1)
+ img2 = np.squeeze(img2)
+ # plt.imshow(img1)
+ # plt.show()
+ # plt.imshow(img2)
+ # plt.show()
+
+ img1, img2 = self.to_tensor(img1), self.to_tensor(img2)
+ # print("img1 to tensor ",img1)
+ return self.score_from_feature(img1, img2)
+
+
+def numpy_to_torch(images):
+ images = images * 2.0 - 1.0
+ images = torch.from_numpy(images.transpose((0, 3, 1, 2))).float()
+ return images.cuda()
+
+
+class LPIPSMeter:
+
+ def __init__(self,
+ net='alex',
+ device=None,
+ size=224): # or we can use 'alex', 'vgg' as network
+ self.size = size
+ self.net = net
+ self.results = []
+ self.device = device if device is not None else torch.device(
+ 'cuda' if torch.cuda.is_available() else 'cpu')
+ self.fn = lpips.LPIPS(net=net).eval().to(self.device)
+
+ def measure(self):
+ return np.mean(self.results)
+
+ def report(self):
+ return f'LPIPS ({self.net}) = {self.measure():.6f}'
+
+ def read_img_list(self, img_list):
+ size = self.size
+ images = []
+ white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
+
+ for img_path in img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+
+ if img.shape[2] == 4: # Handle BGRA images
+ alpha = img[:, :, 3] # Extract alpha channel
+ img = cv2.cvtColor(img,
+ cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
+
+ img = cv2.cvtColor(img,
+ cv2.COLOR_BGR2RGB) # Convert BGR to RGB
+ img[np.where(alpha == 0)] = [
+ 255, 255, 255
+ ] # Set transparent pixels to white
+ else: # Handle other image formats like JPG and PNG
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
+ images.append(img)
+
+ images = np.stack(images, axis=0)
+ # images[np.where(images == 0)] = 255 # Set black pixels to white
+ # images = np.where(images == 0, white_background, images) # Set transparent pixels to white
+ images = images.astype(np.float32) / 255.0
+
+ return images
+
+ # * recommend to use this function for evaluation
+ @torch.no_grad()
+ def score_gt(self, ref_paths, novel_paths):
+ self.results = []
+ for path0, path1 in zip(ref_paths, novel_paths):
+ # Load images
+ # img0 = lpips.im2tensor(lpips.load_image(path0)).cuda() # RGB image from [-1,1]
+ # img1 = lpips.im2tensor(lpips.load_image(path1)).cuda()
+ img0, img1 = self.read_img_list([path0]), self.read_img_list(
+ [path1])
+ img0, img1 = numpy_to_torch(img0), numpy_to_torch(img1)
+ # print(img0.shape,img1.shape)
+ img0 = F.interpolate(img0,
+ size=(self.size, self.size),
+ mode='area')
+ img1 = F.interpolate(img1,
+ size=(self.size, self.size),
+ mode='area')
+
+ # for debug vis
+ # plt.imshow(img0.cpu().squeeze(0).permute(1, 2, 0).numpy())
+ # plt.show()
+ # plt.imshow(img1.cpu().squeeze(0).permute(1, 2, 0).numpy())
+ # plt.show()
+ # equivalent to cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA
+
+ # print(img0.shape,img1.shape)
+
+ self.results.append(self.fn.forward(img0, img1).cpu().numpy())
+
+ return self.measure()
+
+
+class PSNRMeter:
+
+ def __init__(self, size=800):
+ self.results = []
+ self.size = size
+
+ def read_img_list(self, img_list):
+ size = self.size
+ images = []
+ white_background = np.ones((size, size, 3), dtype=np.uint8) * 255
+ for img_path in img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+
+ if img.shape[2] == 4: # Handle BGRA images
+ alpha = img[:, :, 3] # Extract alpha channel
+ img = cv2.cvtColor(img,
+ cv2.COLOR_BGRA2BGR) # Convert BGRA to BGR
+
+ img = cv2.cvtColor(img,
+ cv2.COLOR_BGR2RGB) # Convert BGR to RGB
+ img[np.where(alpha == 0)] = [
+ 255, 255, 255
+ ] # Set transparent pixels to white
+ else: # Handle other image formats like JPG and PNG
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
+ images.append(img)
+
+ images = np.stack(images, axis=0)
+ # images[np.where(images == 0)] = 255 # Set black pixels to white
+ # images = np.where(images == 0, white_background, images) # Set transparent pixels to white
+ images = images.astype(np.float32) / 255.0
+ # print(images.shape)
+ return images
+
+ def update(self, preds, truths):
+ # print(preds.shape)
+
+ psnr_values = []
+ # For each pair of images in the batches
+ for img1, img2 in zip(preds, truths):
+ # Compute the PSNR and add it to the list
+ # print(img1.shape,img2.shape)
+
+ # for debug
+ # plt.imshow(img1)
+ # plt.show()
+ # plt.imshow(img2)
+ # plt.show()
+
+ psnr = compare_psnr(
+ img1, img2,
+ data_range=1.0) # assuming your images are scaled to [0,1]
+ # print(f"temp psnr {psnr}")
+ psnr_values.append(psnr)
+
+ # Convert the list of PSNR values to a numpy array
+ self.results = psnr_values
+
+ def measure(self):
+ return np.mean(self.results)
+
+ def report(self):
+ return f'PSNR = {self.measure():.6f}'
+
+ # * recommend to use this function for evaluation
+ def score_gt(self, ref_paths, novel_paths):
+ self.results = []
+ # [B, N, 3] or [B, H, W, 3], range[0, 1]
+ preds = self.read_img_list(ref_paths)
+ truths = self.read_img_list(novel_paths)
+ self.update(preds, truths)
+ return self.measure()
+
+all_inputs = 'data'
+nerf_dataset = os.listdir(osp(all_inputs, 'nerf4'))
+realfusion_dataset = os.listdir(osp(all_inputs, 'realfusion15'))
+meta_examples = {
+ 'nerf4': nerf_dataset,
+ 'realfusion15': realfusion_dataset,
+}
+all_datasets = meta_examples.keys()
+
+# organization 1
+def deprecated_score_from_method_for_dataset(my_scorer,
+ method,
+ dataset,
+ input,
+ output,
+ score_type='clip',
+ ): # psnr, lpips
+ # print("\n\n\n")
+ # print(f"______{method}___{dataset}___{score_type}_________")
+ scores = {}
+ final_res = 0
+ examples = meta_examples[dataset]
+ for i in range(len(examples)):
+
+ # compare entire folder for clip
+ if score_type == 'clip':
+ novel_view = osp(pred_path, examples[i], 'colors')
+ # compare first image for other metrics
+ else:
+ if method == '3d_fuse': method = '3d_fuse_0'
+ novel_view = list(
+ glob.glob(
+ osp(pred_path, examples[i], 'colors',
+ 'step_0000*')))[0]
+
+ score_i = my_scorer.score_gt(
+ [], [novel_view])
+ scores[examples[i]] = score_i
+ final_res += score_i
+ # print(scores, " Avg : ", final_res / len(examples))
+ # print("``````````````````````")
+ return scores
+
+# results organization 2
+def score_from_method_for_dataset(my_scorer,
+ input_path,
+ pred_path,
+ score_type='clip',
+ rgb_name='lambertian',
+ result_folder='results/images',
+ first_str='*0000*'
+ ): # psnr, lpips
+ scores = {}
+ final_res = 0
+ examples = os.listdir(input_path)
+ for i in range(len(examples)):
+ # ref path
+ ref_path = osp(input_path, examples[i], 'rgba.png')
+ # compare entire folder for clip
+ if score_type == 'clip':
+ novel_view = glob.glob(osp(pred_path,'*'+examples[i]+'*', result_folder, f'*{rgb_name}*'))
+ print(f'[INOF] {score_type} loss for example {examples[i]} between 1 GT and {len(novel_view)} predictions')
+ # compare first image for other metrics
+ else:
+ novel_view = glob.glob(osp(pred_path, '*'+examples[i]+'*/', result_folder, f'{first_str}{rgb_name}*'))
+ print(f'[INOF] {score_type} loss for example {examples[i]} between {ref_path} and {novel_view}')
+ # breakpoint()
+ score_i = my_scorer.score_gt([ref_path], novel_view)
+ scores[examples[i]] = score_i
+ final_res += score_i
+ avg_score = final_res / len(examples)
+ scores['average'] = avg_score
+ return scores
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Script to accept three string arguments")
+ parser.add_argument("--input_path",
+ default=all_inputs,
+ help="Specify the input path")
+ parser.add_argument("--pred_pattern",
+ default="out/magic123*",
+ help="Specify the pattern of predition paths")
+ parser.add_argument("--results_folder",
+ default="results/images",
+ help="where are the results under each pred_path")
+ parser.add_argument("--rgb_name",
+ default="lambertian",
+ help="the postfix of the image")
+ parser.add_argument("--first_str",
+ default="*0000*",
+ help="the str to indicate the first view")
+ parser.add_argument("--datasets",
+ default=all_datasets,
+ nargs='*',
+ help="Specify the output path")
+ parser.add_argument("--device",
+ type=int,
+ default=0,
+ help="Specify the GPU device to be used")
+ parser.add_argument("--save_dir", type=str, default='all_metrics/results')
+ args = parser.parse_args()
+
+ clip_scorer = CLIP(args.device)
+ lpips_scorer = LPIPSMeter()
+ psnr_scorer = PSNRMeter()
+
+ os.makedirs(args.save_dir, exist_ok=True)
+
+ for dataset in args.datasets:
+ input_path = osp(args.input_path, dataset)
+
+ # assume the pred_path is organized as: pred_path/methods/dataset
+ pred_pattern = osp(args.pred_pattern, dataset)
+ pred_paths = glob.glob(pred_pattern)
+ print(f"[INFO] Following the pattern {pred_pattern}, find {len(pred_paths)} pred_paths: \n", pred_paths)
+ if len(pred_paths) == 0:
+ raise IOError
+ for pred_path in pred_paths:
+ if not os.path.exists(pred_path):
+ print(f'[WARN] prediction does not exit for {pred_path}')
+ else:
+ print(f'[INFO] evaluate {pred_path}')
+ results_dict = {}
+ results_dict['clip'] = score_from_method_for_dataset(
+ clip_scorer, input_path, pred_path, 'clip',
+ result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str)
+
+ results_dict['psnr'] = score_from_method_for_dataset(
+ psnr_scorer, input_path, pred_path, 'psnr',
+ result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str)
+
+ results_dict['lpips'] = score_from_method_for_dataset(
+ lpips_scorer, input_path, pred_path, 'lpips',
+ result_folder=args.results_folder, rgb_name=args.rgb_name, first_str=args.first_str)
+
+ df = pd.DataFrame(results_dict)
+ method = pred_path.split('/')[-2]
+ print(osp(pred_path, args.results_folder))
+ results_str = '_'.join(args.results_folder.split('/'))
+ print(method+'-'+results_str)
+ print(df)
+ df.to_csv(f"{args.save_dir}/{method}-{results_str}-{dataset}.csv")
\ No newline at end of file
diff --git a/all_metrics/test.sh b/all_metrics/test.sh
new file mode 100755
index 0000000..01dcadf
--- /dev/null
+++ b/all_metrics/test.sh
@@ -0,0 +1,13 @@
+# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/magic123/magic123-2d1-3d30-dmtet" --results_folder "results/images"
+
+# 2d only
+python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/magic123-2d/magic123-2d1*" --results_folder "results/images"
+
+# 3d only
+python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/magic123-3d/zero123-z40*" --results_folder "results/images"
+
+# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/3d_fuse" --results_folder "color" --rgb_name "" --first_str "*_0_*"
+# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/neural_lift" --results_folder "albedo" --rgb_name "albedo" --first_str "*0000*"
+# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/real_fusion" --results_folder "colors" --rgb_name ""
+# python all_metrics/metric_utils.py --datasets nerf4 realfusion15 --pred_pattern "all_outputs/shape" --results_folder "" --rgb_name "" --first_str "0."
+# python all_metrics/metric_utils.py --datasets realfusion15 --pred_pattern "all_outputs/pointe-r100" --results_folder "" --rgb_name "" --first_str "0."
\ No newline at end of file
diff --git a/assets/advanced.md b/assets/advanced.md
new file mode 100644
index 0000000..2f9baca
--- /dev/null
+++ b/assets/advanced.md
@@ -0,0 +1,71 @@
+
+# Code organization & Advanced tips
+
+This is a simple description of the most important implementation details.
+If you are interested in improving this repo, this might be a starting point.
+Any contribution would be greatly appreciated!
+
+* The SDS loss is located at `./guidance/sd_utils.py > StableDiffusion > train_step`:
+```python
+## 1. we need to interpolate the NeRF rendering to 512x512, to feed it to SD's VAE.
+pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
+## 2. image (512x512) --- VAE --> latents (64x64), this is SD's difference from Imagen.
+latents = self.encode_imgs(pred_rgb_512)
+... # timestep sampling, noise adding and UNet noise predicting
+## 3. the SDS loss
+w = (1 - self.alphas[t])
+grad = w * (noise_pred - noise)
+# since UNet part is ignored and cannot simply audodiff, we have two ways to set the grad:
+# 3.1. call backward and set the grad now (need to retain graph since we will call a second backward for the other losses later)
+latents.backward(gradient=grad, retain_graph=True)
+return 0 # dummy loss
+# 3.2. use a custom function to set a hook in backward, so we only call backward once (credits to @elliottzheng)
+class SpecifyGradient(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input_tensor, gt_grad):
+ ctx.save_for_backward(gt_grad)
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_scale):
+ gt_grad, = ctx.saved_tensors
+ gt_grad = gt_grad * grad_scale
+ return gt_grad, None
+
+loss = SpecifyGradient.apply(latents, grad)
+return loss # functional loss
+```
+* Other regularizations are in `./nerf/utils.py > Trainer > train_step`.
+ * The generation seems quite sensitive to regularizations on weights_sum (alphas for each ray). The original opacity loss tends to make NeRF disappear (zero density everywhere), so we use an entropy loss to replace it for now (encourages alpha to be either 0 or 1).
+* NeRF Rendering core function: `./nerf/renderer.py > NeRFRenderer > run & run_cuda`.
+* Shading & normal evaluation: `./nerf/network*.py > NeRFNetwork > forward`.
+ * light direction: current implementation use a plane light source, instead of a point light source.
+* View-dependent prompting: `./nerf/provider.py > get_view_direction`.
+ * use `--angle_overhead, --angle_front` to set the border.
+* Network backbone (`./nerf/network*.py`) can be chosen by the `--backbone` option.
+* Spatial density bias (density blob): `./nerf/network*.py > NeRFNetwork > density_blob`.
+
+
+# Debugging
+
+`debugpy-run` is a convenient way to remotely debug this project. Simply replace a command like this one:
+
+```bash
+python main.py --text "a hamburger" --workspace trial -O --vram_O
+```
+
+... with:
+
+```bash
+debugpy-run main.py -- --text "a hamburger" --workspace trial -O --vram_O
+```
+
+For more details: https://github.com/bulletmark/debugpy-run
+
+# Axes and directions of polar, azimuth, etc. in NeRF and Zero123
+
+
+
diff --git a/assets/update_logs.md b/assets/update_logs.md
new file mode 100644
index 0000000..b1c2e2c
--- /dev/null
+++ b/assets/update_logs.md
@@ -0,0 +1,39 @@
+### 2023.4.19
+* Fix depth supervision, migrate depth estimation model to omnidata.
+* Add normal supervision (also by omnidata).
+
+https://user-images.githubusercontent.com/25863658/232403294-b77409bf-ddc7-4bb8-af32-ee0cc123825a.mp4
+
+### 2023.4.7
+Improvement on mesh quality & DMTet finetuning support.
+
+https://user-images.githubusercontent.com/25863658/230535363-298c960e-bf9c-4906-8b96-cd60edcb24dd.mp4
+
+### 2023.3.30
+* adopt ideas from [Fantasia3D](https://fantasia3d.github.io/) to concatenate normal and mask as the latent code in a warm up stage, which shows faster convergence of shape.
+
+https://user-images.githubusercontent.com/25863658/230535373-6ee28f16-bb21-4ec4-bc86-d46597361a04.mp4
+
+### 2023.1.30
+* Use an MLP to predict the surface normals as in Magic3D to avoid finite difference / second order gradient, generation quality is greatly improved.
+* More efficient two-pass raymarching in training inspired by nerfacc.
+
+https://user-images.githubusercontent.com/25863658/215996308-9fd959f5-b5c7-4a8e-a241-0fe63ec86a4a.mp4
+
+### 2022.12.3
+* Support Stable-diffusion 2.0 base.
+
+### 2022.11.15
+* Add the vanilla backbone that is pure-pytorch.
+
+### 2022.10.9
+* The shading (partially) starts to work, at least it won't make scene empty. For some prompts, it shows better results (less severe Janus problem). The textureless rendering mode is still disabled.
+* Enable shading by default (--latent_iter_ratio 1000).
+
+### 2022.10.5
+* Basic reproduction finished.
+* Non --cuda_ray, --tcnn are not working, need to fix.
+* Shading is not working, disabled in utils.py for now. Surface normals are bad.
+* Use an entropy loss to regularize weights_sum (alpha), the original L2 reg always leads to degenerated geometry...
+
+https://user-images.githubusercontent.com/25863658/194241493-f3e68f78-aefe-479e-a4a8-001424a61b37.mp4
diff --git a/dnnultis/REAMD.me b/dnnultis/REAMD.me
new file mode 100644
index 0000000..b86c794
--- /dev/null
+++ b/dnnultis/REAMD.me
@@ -0,0 +1,16 @@
+# dnnutils
+dnnutils is a simple library for the commonly used utils for DNN research including:
+1. log (log/experiment directory, logging, tensorboard, wandb)
+2.
+
+
+
+dnnutils is designed to be compitabe for timm.
+
+
+# Install
+
+
+
+# TODO
+configs
\ No newline at end of file
diff --git a/dnnultis/__init__.py b/dnnultis/__init__.py
new file mode 100644
index 0000000..98fed2c
--- /dev/null
+++ b/dnnultis/__init__.py
@@ -0,0 +1 @@
+from .log import *
\ No newline at end of file
diff --git a/dnnultis/log/__init__.py b/dnnultis/log/__init__.py
new file mode 100644
index 0000000..298fa1c
--- /dev/null
+++ b/dnnultis/log/__init__.py
@@ -0,0 +1,2 @@
+from .logger import *
+from .wandb import *
\ No newline at end of file
diff --git a/dnnultis/log/logger.py b/dnnultis/log/logger.py
new file mode 100644
index 0000000..34c07c3
--- /dev/null
+++ b/dnnultis/log/logger.py
@@ -0,0 +1,86 @@
+import functools
+import logging
+import os, sys
+from termcolor import colored
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+# so that calling setup_logging.root multiple times won't add many handlers
+@functools.lru_cache()
+def setup_logging(output=None,
+ distributed_rank=0,
+ default_level=logging.INFO,
+ *,
+ color=True,
+ name=__name__):
+ """
+ Initialize the detectron2 logging.root and set its verbosity level to "INFO".
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name (str): the root module name of this logging.root
+ Returns:
+ logging.logging.root: a logging.root
+ """
+ logging.root.setLevel(default_level)
+ logging.root.propagate = False
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
+ datefmt="%y/%m/%d %H:%M:%S")
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(default_level)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%y/%m/%d %H:%M:%S",
+ root_name=name,
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logging.root.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + f".rank{distributed_rank}"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(default_level)
+ fh.setFormatter(plain_formatter)
+ logging.root.addHandler(fh)
+
+
+# cache the opened file object, so that different calls to `setup_logging.root`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return open(filename, "a")
\ No newline at end of file
diff --git a/dnnultis/log/wandb.py b/dnnultis/log/wandb.py
new file mode 100644
index 0000000..c0ee1a0
--- /dev/null
+++ b/dnnultis/log/wandb.py
@@ -0,0 +1,84 @@
+import shutil
+import os
+import subprocess
+import wandb
+
+
+class WandbUrls:
+ def __init__(self, url):
+
+ hash = url.split("/")[-2]
+ project = url.split("/")[-3]
+ entity = url.split("/")[-4]
+
+ self.weight_url = url
+ self.log_url = "https://app.wandb.ai/{}/{}/runs/{}/logs".format(entity, project, hash)
+ self.chart_url = "https://app.wandb.ai/{}/{}/runs/{}".format(entity, project, hash)
+ self.overview_url = "https://app.wandb.ai/{}/{}/runs/{}/overview".format(entity, project, hash)
+ self.config_url = "https://app.wandb.ai/{}/{}/runs/{}/files/hydra-config.yaml".format(
+ entity, project, hash
+ )
+ self.overrides_url = "https://app.wandb.ai/{}/{}/runs/{}/files/overrides.yaml".format(entity, project, hash)
+
+ def __repr__(self):
+ msg = "=================================================== WANDB URLS ===================================================================\n"
+ for k, v in self.__dict__.items():
+ msg += "{}: {}\n".format(k.upper(), v)
+ msg += "=================================================================================================================================\n"
+ return msg
+
+
+class Wandb:
+ IS_ACTIVE = False
+
+ @staticmethod
+ def set_urls_to_model(model, url):
+ wandb_urls = WandbUrls(url)
+ model.wandb = wandb_urls
+
+ @staticmethod
+ def _set_to_wandb_args(wandb_args, cfg, name):
+ var = getattr(cfg.wandb, name, None)
+ if var:
+ wandb_args[name] = var
+
+ @staticmethod
+ def launch(cfg, launch: bool):
+ if launch:
+
+ Wandb.IS_ACTIVE = True
+
+ wandb_args = {}
+ wandb_args["resume"] = "allow"
+ Wandb._set_to_wandb_args(wandb_args, cfg, "tags")
+ Wandb._set_to_wandb_args(wandb_args, cfg, "project")
+ Wandb._set_to_wandb_args(wandb_args, cfg, "name")
+ Wandb._set_to_wandb_args(wandb_args, cfg, "entity")
+ Wandb._set_to_wandb_args(wandb_args, cfg, "notes")
+ Wandb._set_to_wandb_args(wandb_args, cfg, "config")
+ Wandb._set_to_wandb_args(wandb_args, cfg, "id")
+
+ try:
+ commit_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
+ gitdiff = subprocess.check_output(["git", "diff", "--", "':!notebooks'"]).decode()
+ except:
+ commit_sha = "n/a"
+ gitdiff = ""
+
+ config = wandb_args.get("config", {})
+ wandb_args["config"] = {
+ **config,
+ "run_path": os.getcwd(),
+ "commit": commit_sha,
+ "gitdiff": gitdiff
+ }
+ wandb.init(**wandb_args, sync_tensorboard=True)
+ wandb.save(os.path.join(os.getcwd(), cfg.cfg_path))
+
+ @staticmethod
+ def add_file(file_path: str):
+ if not Wandb.IS_ACTIVE:
+ raise RuntimeError("wandb is inactive, please launch first.")
+
+ filename = os.path.basename(file_path)
+ shutil.copyfile(file_path, os.path.join(wandb.run.dir, filename))
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..47fd296
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,53 @@
+FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
+
+# Remove any third-party apt sources to avoid issues with expiring keys.
+RUN rm -f /etc/apt/sources.list.d/*.list
+
+RUN apt-get update
+
+RUN DEBIAN_FRONTEND=noninteractive TZ=Europe/MADRID apt-get install -y tzdata
+
+# Install some basic utilities
+RUN apt-get install -y \
+ curl \
+ ca-certificates \
+ sudo \
+ git \
+ bzip2 \
+ libx11-6 \
+ python3 \
+ python3-pip \
+ libglfw3-dev \
+ libgles2-mesa-dev \
+ libglib2.0-0 \
+ && rm -rf /var/lib/apt/lists/*
+
+
+# Create a working directory
+RUN mkdir /app
+WORKDIR /app
+
+RUN cd /app
+RUN git clone https://github.com/ashawkey/stable-dreamfusion.git
+
+
+RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
+
+WORKDIR /app/stable-dreamfusion
+
+RUN pip3 install -r requirements.txt
+RUN pip3 install git+https://github.com/NVlabs/nvdiffrast/
+
+# Needs nvidia runtime, if you have "No CUDA runtime is found" error: https://stackoverflow.com/questions/59691207/docker-build-with-nvidia-runtime, first answer
+RUN pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
+
+RUN pip3 install git+https://github.com/openai/CLIP.git
+RUN bash scripts/install_ext.sh
+
+
+
+
+
+# Set the default command to python3
+#CMD ["python3"]
+
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000..2fe00e4
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,80 @@
+### Docker installation
+
+## Build image
+To build the docker image on your own machine, which may take 15-30 mins:
+```
+docker build -t stable-dreamfusion:latest .
+```
+
+If you have the error **No CUDA runtime is found** when building the wheels for tiny-cuda-nn you need to setup the nvidia-runtime for docker.
+```
+sudo apt-get install nvidia-container-runtime
+```
+Then edit `/etc/docker/daemon.json` and add the default-runtime:
+```
+{
+ "runtimes": {
+ "nvidia": {
+ "path": "nvidia-container-runtime",
+ "runtimeArgs": []
+ }
+ },
+ "default-runtime": "nvidia"
+}
+```
+And restart docker:
+```
+sudo systemctl restart docker
+```
+Now you can build tiny-cuda-nn inside docker.
+
+## Download image
+To download the image (~6GB) instead:
+```
+docker pull supercabb/stable-dreamfusion:3080_0.0.1
+docker tag supercabb/stable-dreamfusion:3080_0.0.1 stable-dreamfusion
+```
+
+## Use image
+
+You can launch an interactive shell inside the container:
+
+```
+docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash
+```
+From this shell, all the code in the repo should work.
+
+To run any single command `` inside the docker container:
+```
+docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c ""
+```
+To train:
+```
+export TOKEN="#HUGGING FACE ACCESS TOKEN#"
+docker run --gpus all -it --rm -v $(cd ~ && pwd):/mnt stable-dreamfusion /bin/bash -c "echo ${TOKEN} > TOKEN \
+&& python3 main.py --text \"a hamburger\" --workspace trial -O"
+
+```
+Run test without gui:
+```
+export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
+docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
+-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
+main.py --workspace trial -O --test"
+```
+Run test with gui:
+```
+export PATH_TO_WORKSPACE="#PATH_TO_WORKSPACE#"
+xhost +
+docker run --gpus all -it --rm -e DISPLAY=$DISPLAY -v /tmp/.X11-unix:/tmp/.X11-unix:ro -v $(cd ~ && pwd):/mnt \
+-v $(cd ${PATH_TO_WORKSPACE} && pwd):/app/stable-dreamfusion/trial stable-dreamfusion /bin/bash -c "python3 \
+main.py --workspace trial -O --test --gui"
+xhost -
+```
+
+
+
+
+
+
+
diff --git a/docs/static/ironman-val-magic123.gif b/docs/static/ironman-val-magic123.gif
new file mode 100644
index 0000000..ac8de84
Binary files /dev/null and b/docs/static/ironman-val-magic123.gif differ
diff --git a/docs/static/magic123-results.mp4 b/docs/static/magic123-results.mp4
new file mode 100644
index 0000000..1ec7011
Binary files /dev/null and b/docs/static/magic123-results.mp4 differ
diff --git a/docs/static/magic123.gif b/docs/static/magic123.gif
new file mode 100644
index 0000000..5616b1e
Binary files /dev/null and b/docs/static/magic123.gif differ
diff --git a/dpt.py b/dpt.py
new file mode 100644
index 0000000..8cc0479
--- /dev/null
+++ b/dpt.py
@@ -0,0 +1,924 @@
+import math
+import types
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import timm
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
+
+
+def unflatten_with_named_tensor(input, dim, sizes):
+ """Workaround for unflattening with named tensor."""
+ # tracer acts up with unflatten. See https://github.com/pytorch/pytorch/issues/49538
+ new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim+1:]
+ return input.view(*new_shape)
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+
+ unflattened_dim = 2
+ unflattened_size = (
+ int(torch.div(h, pretrained.model.patch_size[1], rounding_mode='floor')),
+ int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')),
+ )
+ unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size))
+
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten_with_named_tensor(layer_3, unflattened_dim, unflattened_size)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(posemb_grid.shape[0]))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, torch.div(h, self.patch_size[1], rounding_mode='floor'), torch.div(w, self.patch_size[0], rounding_mode='floor')
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ True, # Set to true of you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+ x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
\ No newline at end of file
diff --git a/encoding.py b/encoding.py
new file mode 100644
index 0000000..407589e
--- /dev/null
+++ b/encoding.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FreqEncoder_torch(nn.Module):
+ def __init__(self, input_dim, max_freq_log2, N_freqs,
+ log_sampling=True, include_input=True,
+ periodic_fns=(torch.sin, torch.cos)):
+
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.include_input = include_input
+ self.periodic_fns = periodic_fns
+ self.N_freqs = N_freqs
+
+ self.output_dim = 0
+ if self.include_input:
+ self.output_dim += self.input_dim
+
+ self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)
+
+ if log_sampling:
+ self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
+ else:
+ self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)
+
+ self.freq_bands = self.freq_bands.numpy().tolist()
+
+ def forward(self, input, max_level=None, **kwargs):
+
+ if max_level is None:
+ max_level = self.N_freqs
+ else:
+ max_level = int(max_level * self.N_freqs)
+
+ out = []
+ if self.include_input:
+ out.append(input)
+
+ for i in range(max_level):
+ freq = self.freq_bands[i]
+ for p_fn in self.periodic_fns:
+ out.append(p_fn(input * freq))
+
+ # append 0
+ if self.N_freqs - max_level > 0:
+ out.append(torch.zeros(input.shape[0], (self.N_freqs - max_level) * 2 * input.shape[1], device=input.device, dtype=input.dtype))
+
+ out = torch.cat(out, dim=-1)
+
+ return out
+
+def get_encoder(encoding, input_dim=3,
+ multires=6,
+ degree=4,
+ num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, interpolation='linear',
+ **kwargs):
+
+ if encoding == 'None':
+ return lambda x, **kwargs: x, input_dim
+
+ elif encoding == 'frequency_torch':
+ encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)
+
+ elif encoding == 'frequency': # CUDA implementation, faster than torch.
+ from freqencoder import FreqEncoder
+ encoder = FreqEncoder(input_dim=input_dim, degree=multires)
+
+ elif encoding == 'sphere_harmonics':
+ from shencoder import SHEncoder
+ encoder = SHEncoder(input_dim=input_dim, degree=degree)
+
+ elif encoding == 'hashgrid':
+ from gridencoder import GridEncoder
+ encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation)
+
+ elif encoding == 'tiledgrid':
+ from gridencoder import GridEncoder
+ encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation)
+
+ elif encoding == 'hashgrid_taichi':
+ from taichi_modules.hash_encoder import HashEncoderTaichi
+ encoder = HashEncoderTaichi(batch_size=4096) #TODO: hard encoded batch size
+
+ else:
+ raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')
+
+ return encoder, encoder.output_dim
\ No newline at end of file
diff --git a/freqencoder/__init__.py b/freqencoder/__init__.py
new file mode 100644
index 0000000..69ec49c
--- /dev/null
+++ b/freqencoder/__init__.py
@@ -0,0 +1 @@
+from .freq import FreqEncoder
\ No newline at end of file
diff --git a/freqencoder/backend.py b/freqencoder/backend.py
new file mode 100644
index 0000000..fa0e820
--- /dev/null
+++ b/freqencoder/backend.py
@@ -0,0 +1,42 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-use_fast_math'
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+_backend = load(name='_freqencoder',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'freqencoder.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/freqencoder/freq.py b/freqencoder/freq.py
new file mode 100644
index 0000000..5cba1e6
--- /dev/null
+++ b/freqencoder/freq.py
@@ -0,0 +1,77 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _freqencoder as _backend
+except ImportError:
+ from .backend import _backend
+
+
+class _freq_encoder(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
+ def forward(ctx, inputs, degree, output_dim):
+ # inputs: [B, input_dim], float
+ # RETURN: [B, F], float
+
+ if not inputs.is_cuda: inputs = inputs.cuda()
+ inputs = inputs.contiguous()
+
+ B, input_dim = inputs.shape # batch size, coord dim
+
+ outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
+
+ _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+
+ ctx.save_for_backward(inputs, outputs)
+ ctx.dims = [B, input_dim, degree, output_dim]
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+ # grad: [B, C * C]
+
+ grad = grad.contiguous()
+ inputs, outputs = ctx.saved_tensors
+ B, input_dim, degree, output_dim = ctx.dims
+
+ grad_inputs = torch.zeros_like(inputs)
+ _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+
+ return grad_inputs, None, None
+
+
+freq_encode = _freq_encoder.apply
+
+
+class FreqEncoder(nn.Module):
+ def __init__(self, input_dim=3, degree=4):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.degree = degree
+ self.output_dim = input_dim + input_dim * 2 * degree
+
+ def __repr__(self):
+ return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"
+
+ def forward(self, inputs, **kwargs):
+ # inputs: [..., input_dim]
+ # return: [..., ]
+
+ prefix_shape = list(inputs.shape[:-1])
+ inputs = inputs.reshape(-1, self.input_dim)
+
+ outputs = freq_encode(inputs, self.degree, self.output_dim)
+
+ outputs = outputs.reshape(prefix_shape + [self.output_dim])
+
+ return outputs
\ No newline at end of file
diff --git a/freqencoder/setup.py b/freqencoder/setup.py
new file mode 100644
index 0000000..ea64112
--- /dev/null
+++ b/freqencoder/setup.py
@@ -0,0 +1,52 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+ '-use_fast_math'
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+setup(
+ name='freqencoder', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_freqencoder', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'freqencoder.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/freqencoder/src/bindings.cpp b/freqencoder/src/bindings.cpp
new file mode 100644
index 0000000..bb5f285
--- /dev/null
+++ b/freqencoder/src/bindings.cpp
@@ -0,0 +1,8 @@
+#include
+
+#include "freqencoder.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
+ m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
+}
\ No newline at end of file
diff --git a/freqencoder/src/freqencoder.cu b/freqencoder/src/freqencoder.cu
new file mode 100644
index 0000000..072da74
--- /dev/null
+++ b/freqencoder/src/freqencoder.cu
@@ -0,0 +1,129 @@
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+inline constexpr __device__ float PI() { return 3.141592653589793f; }
+
+template
+__host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+// inputs: [B, D]
+// outputs: [B, C], C = D + D * deg * 2
+__global__ void kernel_freq(
+ const float * __restrict__ inputs,
+ uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+ float * outputs
+) {
+ // parallel on per-element
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * C) return;
+
+ // get index
+ const uint32_t b = t / C;
+ const uint32_t c = t - b * C; // t % C;
+
+ // locate
+ inputs += b * D;
+ outputs += t;
+
+ // write self
+ if (c < D) {
+ outputs[0] = inputs[c];
+ // write freq
+ } else {
+ const uint32_t col = c / D - 1;
+ const uint32_t d = c % D;
+ const uint32_t freq = col / 2;
+ const float phase_shift = (col % 2) * (PI() / 2);
+ outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
+ }
+}
+
+// grad: [B, C], C = D + D * deg * 2
+// outputs: [B, C]
+// grad_inputs: [B, D]
+__global__ void kernel_freq_backward(
+ const float * __restrict__ grad,
+ const float * __restrict__ outputs,
+ uint32_t B, uint32_t D, uint32_t deg, uint32_t C,
+ float * grad_inputs
+) {
+ // parallel on per-element
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D; // t % D;
+
+ // locate
+ grad += b * C;
+ outputs += b * C;
+ grad_inputs += t;
+
+ // register
+ float result = grad[d];
+ grad += D;
+ outputs += D;
+
+ for (uint32_t f = 0; f < deg; f++) {
+ result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
+ grad += 2 * D;
+ outputs += 2 * D;
+ }
+
+ // write
+ grad_inputs[0] = result;
+}
+
+
+void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) {
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(outputs);
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(outputs);
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(outputs);
+
+ static constexpr uint32_t N_THREADS = 128;
+
+ kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr());
+}
+
+
+void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) {
+ CHECK_CUDA(grad);
+ CHECK_CUDA(outputs);
+ CHECK_CUDA(grad_inputs);
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(outputs);
+ CHECK_CONTIGUOUS(grad_inputs);
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(outputs);
+ CHECK_IS_FLOATING(grad_inputs);
+
+ static constexpr uint32_t N_THREADS = 128;
+
+ kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr());
+}
\ No newline at end of file
diff --git a/freqencoder/src/freqencoder.h b/freqencoder/src/freqencoder.h
new file mode 100644
index 0000000..34f28c7
--- /dev/null
+++ b/freqencoder/src/freqencoder.h
@@ -0,0 +1,10 @@
+# pragma once
+
+#include
+#include
+
+// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)
+void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs);
+
+// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)
+void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs);
\ No newline at end of file
diff --git a/gradio_app.py b/gradio_app.py
new file mode 100644
index 0000000..9573bdf
--- /dev/null
+++ b/gradio_app.py
@@ -0,0 +1,246 @@
+import torch
+import argparse
+
+from nerf.provider import NeRFDataset
+from nerf.utils import *
+
+import gradio as gr
+import gc
+
+print(f'[INFO] loading options..')
+
+# fake config object, this should not be used in CMD, only allow change from gradio UI.
+parser = argparse.ArgumentParser()
+parser.add_argument('--text', default=None, help="text prompt")
+parser.add_argument('--negative', default='', type=str, help="negative text prompt")
+parser.add_argument('--test', action='store_true', help="test mode")
+parser.add_argument('--eval_interval', type=int, default=10, help="evaluate on the valid set every interval epochs")
+parser.add_argument('--workspace', type=str, default='trial_gradio')
+parser.add_argument('--guidance', type=str, default='stable-diffusion', help='choose from [stable-diffusion, clip]')
+parser.add_argument('--seed', type=int, default=0)
+
+parser.add_argument('--save_mesh', action='store_true', help="export an obj mesh with texture")
+parser.add_argument('--mcubes_resolution', type=int, default=256, help="mcubes resolution for extracting mesh")
+parser.add_argument('--decimate_target', type=int, default=1e5, help="target face number for mesh decimation")
+
+### training options
+parser.add_argument('--iters', type=int, default=10000, help="training iters")
+parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
+parser.add_argument('--ckpt', type=str, default='latest')
+parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
+parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
+parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
+parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
+parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
+parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
+parser.add_argument('--warmup_iters', type=int, default=1000, help="training iters that only use albedo shading")
+parser.add_argument('--uniform_sphere_rate', type=float, default=0.5, help="likelihood of sampling camera location uniformly on the sphere surface area")
+# model options
+parser.add_argument('--bg_radius', type=float, default=1.4, help="if positive, use a background model at sphere(bg_radius)")
+parser.add_argument('--density_activation', type=str, default='softplus', choices=['softplus', 'exp'], help="density activation function")
+parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied")
+parser.add_argument('--blob_density', type=float, default=10, help="max (center) density for the density blob")
+parser.add_argument('--blob_radius', type=float, default=0.3, help="control the radius for the density blob")
+# network backbone
+parser.add_argument('--fp16', action='store_true', help="use float16 for training")
+parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
+parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, vanilla]")
+parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam', 'adamw'], help="optimizer")
+parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
+parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
+# rendering resolution in training, decrease this if CUDA OOM.
+parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
+parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
+parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
+
+### dataset options
+parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
+parser.add_argument('--dt_gamma', type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+parser.add_argument('--min_near', type=float, default=0.1, help="minimum near distance for camera")
+parser.add_argument('--radius_range', type=float, nargs='*', default=[1.0, 1.5], help="training camera radius range")
+parser.add_argument('--fovy_range', type=float, nargs='*', default=[40, 70], help="training camera fovy range")
+parser.add_argument('--angle_overhead', type=float, default=30, help="[0, angle_overhead] is the overhead region")
+parser.add_argument('--angle_front', type=float, default=60, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
+
+parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
+parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
+parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
+parser.add_argument('--lambda_smooth', type=float, default=0, help="loss scale for surface smoothness")
+
+### GUI options
+parser.add_argument('--gui', action='store_true', help="start a GUI")
+parser.add_argument('--W', type=int, default=800, help="GUI width")
+parser.add_argument('--H', type=int, default=800, help="GUI height")
+parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center")
+parser.add_argument('--fovy', type=float, default=60, help="default GUI camera fovy")
+parser.add_argument('--light_theta', type=float, default=60, help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
+parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
+parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
+
+parser.add_argument('--need_share', type=bool, default=False, help="do you want to share gradio app to external network?")
+
+opt = parser.parse_args()
+
+# default to use -O !!!
+opt.fp16 = True
+opt.cuda_ray = True
+opt.vram_O = True
+# opt.lambda_entropy = 1e-4
+# opt.lambda_opacity = 0
+
+if opt.backbone == 'vanilla':
+ from nerf.network import NeRFNetwork
+elif opt.backbone == 'grid':
+ from nerf.network_grid import NeRFNetwork
+else:
+ raise NotImplementedError(f'--backbone {opt.backbone} is not implemented!')
+
+print(opt)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+print(f'[INFO] loading models..')
+
+if opt.guidance == 'stable-diffusion':
+ from guidance.sd_utils import StableDiffusion
+ guidance = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key)
+elif opt.guidance == 'clip':
+ from guidance.clip_utils import CLIP
+ guidance = CLIP(device)
+else:
+ raise NotImplementedError(f'--guidance {opt.guidance} is not implemented.')
+
+train_loader = NeRFDataset(opt, device=device, type='train', H=opt.h, W=opt.w, size=100).dataloader()
+valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader()
+test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader()
+
+print(f'[INFO] everything loaded!')
+
+trainer = None
+model = None
+
+# define UI
+
+with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
+
+ # title
+ gr.Markdown('[Stable-DreamFusion](https://github.com/ashawkey/stable-dreamfusion) Text-to-3D Example')
+
+ # inputs
+ prompt = gr.Textbox(label="Prompt", max_lines=1, value="a DSLR photo of a koi fish")
+ iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=5000, step=100)
+ seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
+ button = gr.Button('Generate')
+
+ # outputs
+ image = gr.Image(label="image", visible=True)
+ video = gr.Video(label="video", visible=False)
+ logs = gr.Textbox(label="logging")
+
+ # gradio main func
+ def submit(text, iters, seed):
+
+ global trainer, model
+
+ # seed
+ opt.seed = seed
+ opt.text = text
+ opt.iters = iters
+
+ seed_everything(seed)
+
+ # clean up
+ if trainer is not None:
+ del model
+ del trainer
+ gc.collect()
+ torch.cuda.empty_cache()
+ print('[INFO] clean up!')
+
+ # simply reload everything...
+ model = NeRFNetwork(opt)
+
+ if opt.optim == 'adan':
+ from optimizer import Adan
+ # Adan usually requires a larger LR
+ optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-15)
+ elif opt.optim == 'adamw':
+ optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+ else: # adam
+ optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+
+ scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed
+
+ trainer = Trainer('df', opt, model, guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True)
+
+ # train (every ep only contain 8 steps, so we can get some vis every ~10s)
+ STEPS = 8
+ max_epochs = np.ceil(opt.iters / STEPS).astype(np.int32)
+
+ # we have to get the explicit training loop out here to yield progressive results...
+ loader = iter(valid_loader)
+
+ start_t = time.time()
+
+ for epoch in range(max_epochs):
+
+ trainer.train_gui(train_loader, step=STEPS)
+
+ # manual test and get intermediate results
+ try:
+ data = next(loader)
+ except StopIteration:
+ loader = iter(valid_loader)
+ data = next(loader)
+
+ trainer.model.eval()
+
+ if trainer.ema is not None:
+ trainer.ema.store()
+ trainer.ema.copy_to()
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=trainer.fp16):
+ preds, preds_depth = trainer.test_step(data, perturb=False)
+
+ if trainer.ema is not None:
+ trainer.ema.restore()
+
+ pred = preds[0].detach().cpu().numpy()
+ # pred_depth = preds_depth[0].detach().cpu().numpy()
+
+ pred = (pred * 255).astype(np.uint8)
+
+ yield {
+ image: gr.update(value=pred, visible=True),
+ video: gr.update(visible=False),
+ logs: f"training iters: {epoch * STEPS} / {iters}, lr: {trainer.optimizer.param_groups[0]['lr']:.6f}",
+ }
+
+
+ # test
+ trainer.test(test_loader)
+
+ results = glob.glob(os.path.join(opt.workspace, 'results', '*rgb*.mp4'))
+ assert results is not None, "cannot retrieve results!"
+ results.sort(key=lambda x: os.path.getmtime(x)) # sort by mtime
+
+ end_t = time.time()
+
+ yield {
+ image: gr.update(visible=False),
+ video: gr.update(value=results[-1], visible=True),
+ logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
+ }
+
+
+ button.click(
+ submit,
+ [prompt, iters, seed],
+ [image, video, logs]
+ )
+
+# concurrency_count: only allow ONE running progress, else GPU will OOM.
+demo.queue(concurrency_count=1)
+
+demo.launch(share=opt.need_share)
diff --git a/gridencoder/__init__.py b/gridencoder/__init__.py
new file mode 100644
index 0000000..f1476ce
--- /dev/null
+++ b/gridencoder/__init__.py
@@ -0,0 +1 @@
+from .grid import GridEncoder
\ No newline at end of file
diff --git a/gridencoder/backend.py b/gridencoder/backend.py
new file mode 100644
index 0000000..b403f34
--- /dev/null
+++ b/gridencoder/backend.py
@@ -0,0 +1,40 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
+ if paths:
+ return paths[0]
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+_backend = load(name='_grid_encoder',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'gridencoder.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/gridencoder/grid.py b/gridencoder/grid.py
new file mode 100644
index 0000000..3f91daf
--- /dev/null
+++ b/gridencoder/grid.py
@@ -0,0 +1,206 @@
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import _gridencoder as _backend
+except ImportError:
+ from .backend import _backend
+
+_gridtype_to_id = {
+ 'hash': 0,
+ 'tiled': 1,
+}
+
+_interp_to_id = {
+ 'linear': 0,
+ 'smoothstep': 1,
+}
+
+class _grid_encode(Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):
+ # inputs: [B, D], float in [0, 1]
+ # embeddings: [sO, C], float
+ # offsets: [L + 1], int
+ # RETURN: [B, F], float
+
+ inputs = inputs.contiguous()
+
+ B, D = inputs.shape # batch size, coord dim
+ L = offsets.shape[0] - 1 # level
+ C = embeddings.shape[1] # embedding dim for each level
+ S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+ H = base_resolution # base resolution
+
+ max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1)
+
+ # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
+ # if C % 2 != 0, force float, since half for atomicAdd is very slow.
+ if torch.is_autocast_enabled() and C % 2 == 0:
+ embeddings = embeddings.to(torch.half)
+
+ # L first, optimize cache for cuda kernel, but needs an extra permute later
+ outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+ # zero init if we only calculate partial levels
+ if max_level < L: outputs.zero_()
+
+ if calc_grad_inputs:
+ dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
+ if max_level < L: dy_dx.zero_()
+ else:
+ dy_dx = None
+
+ _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)
+
+ # permute back to [B, L * C]
+ outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+
+ ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+ ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level]
+ ctx.align_corners = align_corners
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+
+ inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+ B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims
+ align_corners = ctx.align_corners
+
+ # grad: [B, L * C] --> [L, B, C]
+ grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+
+ grad_embeddings = torch.zeros_like(embeddings)
+
+ if dy_dx is not None:
+ grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+ else:
+ grad_inputs = None
+
+ _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)
+
+ if dy_dx is not None:
+ grad_inputs = grad_inputs.to(inputs.dtype)
+
+ return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None
+
+
+
+grid_encode = _grid_encode.apply
+
+
+class GridEncoder(nn.Module):
+ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
+ super().__init__()
+
+ # the finest resolution desired at the last level, if provided, overridee per_level_scale
+ if desired_resolution is not None:
+ per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
+
+ self.input_dim = input_dim # coord dims, 2 or 3
+ self.num_levels = num_levels # num levels, each level multiply resolution by 2
+ self.level_dim = level_dim # encode channels per level
+ self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
+ self.log2_hashmap_size = log2_hashmap_size
+ self.base_resolution = base_resolution
+ self.output_dim = num_levels * level_dim
+ self.gridtype = gridtype
+ self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
+ self.interpolation = interpolation
+ self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep"
+ self.align_corners = align_corners
+
+ # allocate parameters
+ offsets = []
+ offset = 0
+ self.max_params = 2 ** log2_hashmap_size
+ for i in range(num_levels):
+ resolution = int(np.ceil(base_resolution * per_level_scale ** i))
+ params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
+ params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
+ offsets.append(offset)
+ offset += params_in_level
+ offsets.append(offset)
+ offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+ self.register_buffer('offsets', offsets)
+
+ self.n_params = offsets[-1] * level_dim
+
+ # parameters
+ self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ std = 1e-4
+ self.embeddings.data.uniform_(-std, std)
+
+ def __repr__(self):
+ return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"
+
+ def forward(self, inputs, bound=1, max_level=None):
+ # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
+ # max_level: only calculate first max_level levels (None will use all levels)
+ # return: [..., num_levels * level_dim]
+
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+
+ #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
+
+ prefix_shape = list(inputs.shape[:-1])
+ inputs = inputs.view(-1, self.input_dim)
+
+ outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)
+ outputs = outputs.view(prefix_shape + [self.output_dim])
+
+ #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
+
+ return outputs
+
+ # always run in float precision!
+ @torch.cuda.amp.autocast(enabled=False)
+ def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
+ # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.
+
+ D = self.input_dim
+ C = self.embeddings.shape[1] # embedding dim for each level
+ L = self.offsets.shape[0] - 1 # level
+ S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+ H = self.base_resolution # base resolution
+
+ if inputs is None:
+ # randomized in [0, 1]
+ inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
+ else:
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+ inputs = inputs.view(-1, self.input_dim)
+ B = inputs.shape[0]
+
+ if self.embeddings.grad is None:
+ raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
+
+ _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)
+
+ @torch.cuda.amp.autocast(enabled=False)
+ def grad_weight_decay(self, weight=0.1):
+ # level-wise meaned weight decay (ref: zip-nerf)
+
+ B = self.embeddings.shape[0] # size of embedding
+ C = self.embeddings.shape[1] # embedding dim for each level
+ L = self.offsets.shape[0] - 1 # level
+
+ if self.embeddings.grad is None:
+ raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
+
+ _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
\ No newline at end of file
diff --git a/gridencoder/setup.py b/gridencoder/setup.py
new file mode 100644
index 0000000..a91b0c1
--- /dev/null
+++ b/gridencoder/setup.py
@@ -0,0 +1,51 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+setup(
+ name='gridencoder', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_gridencoder', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'gridencoder.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/gridencoder/src/bindings.cpp b/gridencoder/src/bindings.cpp
new file mode 100644
index 0000000..fc3dd5e
--- /dev/null
+++ b/gridencoder/src/bindings.cpp
@@ -0,0 +1,10 @@
+#include
+
+#include "gridencoder.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
+ m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
+ m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
+ m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
+}
\ No newline at end of file
diff --git a/gridencoder/src/gridencoder.cu b/gridencoder/src/gridencoder.cu
new file mode 100644
index 0000000..93f5b80
--- /dev/null
+++ b/gridencoder/src/gridencoder.cu
@@ -0,0 +1,713 @@
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
+ __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
+ // requires CUDA >= 10 and ARCH >= 70
+ // this is very slow compared to float or __half2, never use it.
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
+}
+
+
+template
+__host__ __device__ inline T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+template
+__device__ inline T smoothstep(T val) {
+ return val*val*(3.0f - 2.0f * val);
+}
+
+template
+__device__ inline T smoothstep_derivative(T val) {
+ return 6*val*(1.0f - val);
+}
+
+
+template
+__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
+
+ // coherent type of hashing
+ constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };
+
+ uint32_t result = 0;
+ #pragma unroll
+ for (uint32_t i = 0; i < D; ++i) {
+ result ^= pos_grid[i] * primes[i];
+ }
+
+ return result;
+}
+
+
+template
+__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
+ uint32_t stride = 1;
+ uint32_t index = 0;
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
+ index += pos_grid[d] * stride;
+ stride *= resolution;
+ }
+
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
+ // gridtype: 0 == hash, 1 == tiled
+ if (gridtype == 0 && stride > hashmap_size) {
+ index = fast_hash(pos_grid);
+ }
+
+ return (index % hashmap_size) * C + ch;
+}
+
+
+template
+__global__ void kernel_grid(
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ outputs,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ scalar_t * __restrict__ dy_dx,
+ const uint32_t gridtype,
+ const bool align_corners,
+ const uint32_t interp
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ // locate
+ grid += (uint32_t)offsets[level] * C;
+ inputs += b * D;
+ outputs += level * B * C + b * C;
+
+ // check input range (should be in [0, 1])
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+ // if input out of bound, just set output to 0
+ if (flag_oob) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = 0;
+ }
+ if (dy_dx) {
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[d * C + ch] = 0;
+ }
+ }
+ }
+ return;
+ }
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
+
+ // calculate coordinate (always use float for precision!)
+ float pos[D];
+ float pos_deriv[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+
+ // align_corners
+ if (align_corners) {
+ pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
+ pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
+ } else {
+ pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
+ pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
+ }
+ pos[d] -= (float)pos_grid[d];
+
+ // smoothstep instead of linear
+ if (interp == 1) {
+ pos_deriv[d] = smoothstep_derivative(pos[d]);
+ pos[d] = smoothstep(pos[d]);
+ } else {
+ pos_deriv[d] = 1.0f;
+ }
+ }
+
+ // verification of alignment
+ // if (level == L - 1 && b < 4) {
+ // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
+ // }
+
+ // interpolate
+ scalar_t results[C] = {0}; // temp results in register
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
+ }
+ }
+
+ uint32_t index = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local);
+
+ // writing to register (fast)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results[ch] += w * grid[index + ch];
+ }
+
+ //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
+ }
+
+ // writing to global memory (slow)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = results[ch];
+ }
+
+ // prepare dy_dx
+ // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
+ if (dy_dx) {
+
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+
+ #pragma unroll
+ for (uint32_t gd = 0; gd < D; gd++) {
+
+ scalar_t results_grad[C] = {0};
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
+ float w = (float)(align_corners ? resolution - 1 : resolution);
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
+
+ if ((idx & (1 << nd)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
+ }
+ }
+
+ pos_grid_local[gd] = pos_grid[gd];
+ uint32_t index_left = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local);
+ pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1);
+ uint32_t index_right = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid_local);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
+ }
+ }
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[gd * C + ch] = results_grad[ch];
+ }
+ }
+ }
+}
+
+
+template
+__global__ void kernel_grid_backward(
+ const scalar_t * __restrict__ grad,
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ grad_grid,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners,
+ const uint32_t interp
+) {
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
+
+ // locate
+ grad_grid += offsets[level] * C;
+ inputs += b * D;
+ grad += level * B * C + b * C + ch; // L, B, C
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
+
+ // check input range (should be in [0, 1])
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ return; // grad is init as 0, so we simply return.
+ }
+ }
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ // align_corners
+ if (align_corners) {
+ pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
+ pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
+ } else {
+ pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
+ pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
+ }
+ pos[d] -= (float)pos_grid[d];
+ // smoothstep instead of linear
+ if (interp == 1) {
+ pos[d] = smoothstep(pos[d]);
+ }
+ }
+
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ grad_cur[c] = grad[c];
+ }
+
+ // interpolate
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
+ }
+ }
+
+ uint32_t index = get_grid_index(gridtype, ch, hashmap_size, resolution, pos_grid_local);
+
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
+ // TODO: use float which is better than __half, if N_C % 2 != 0
+ if (std::is_same::value && N_C % 2 == 0) {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c += 2) {
+ // process two __half at once (by interpreting as a __half2)
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
+ atomicAdd((__half2*)&grad_grid[index + c], v);
+ }
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
+ } else {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
+ }
+ }
+ }
+}
+
+
+template
+__global__ void kernel_input_backward(
+ const scalar_t * __restrict__ grad,
+ const scalar_t * __restrict__ dy_dx,
+ scalar_t * __restrict__ grad_inputs,
+ uint32_t B, uint32_t L
+) {
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D;
+
+ dy_dx += b * L * D * C;
+
+ scalar_t result = 0;
+
+ # pragma unroll
+ for (int l = 0; l < L; l++) {
+ # pragma unroll
+ for (int ch = 0; ch < C; ch++) {
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
+ }
+ }
+
+ grad_inputs[t] = result;
+}
+
+
+template
+void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 };
+ switch (C) {
+ case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 16: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 32: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
+ }
+}
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
+// H: base resolution
+// dy_dx: [B, L * D * C]
+template
+void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ switch (D) {
+ case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
+ case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
+ }
+}
+
+template
+void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ static constexpr uint32_t N_THREAD = 256;
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), max_level, 1 };
+ switch (C) {
+ case 1:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 2:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 4:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 8:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 16:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ case 32:
+ kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+ if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
+ }
+}
+
+
+// grad: [L, B, C], float
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// grad_embeddings: [sO, C]
+// H: base resolution
+template
+void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ switch (D) {
+ case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
+ }
+}
+
+
+
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(outputs);
+ // CHECK_CUDA(dy_dx);
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(outputs);
+ // CHECK_CONTIGUOUS(dy_dx);
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(outputs);
+ // CHECK_IS_FLOATING(dy_dx);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
+ grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp);
+ }));
+}
+
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+ CHECK_CUDA(grad);
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(grad_embeddings);
+ // CHECK_CUDA(dy_dx);
+ // CHECK_CUDA(grad_inputs);
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(grad_embeddings);
+ // CHECK_CONTIGUOUS(dy_dx);
+ // CHECK_CONTIGUOUS(grad_inputs);
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(grad_embeddings);
+ // CHECK_IS_FLOATING(dy_dx);
+ // CHECK_IS_FLOATING(grad_inputs);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "grid_encode_backward", ([&] {
+ grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp);
+ }));
+
+}
+
+
+template
+__global__ void kernel_grad_tv(
+ const scalar_t * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ scalar_t * __restrict__ grad,
+ const int * __restrict__ offsets,
+ const float weight,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ // locate
+ inputs += b * D;
+ grid += (uint32_t)offsets[level] * C;
+ grad += (uint32_t)offsets[level] * C;
+
+ // check input range (should be in [0, 1])
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+
+ // if input out of bound, do nothing
+ if (flag_oob) return;
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D]; // [0, resolution]
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ // align_corners
+ if (align_corners) {
+ pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
+ pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
+ } else {
+ pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
+ pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
+ }
+ }
+
+ //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
+
+ // total variation on pos_grid
+ scalar_t results[C] = {0}; // temp results in register
+ scalar_t idelta[C] = {0};
+
+ uint32_t index = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid);
+
+ scalar_t w = weight / (2 * D);
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+
+ uint32_t cur_d = pos_grid[d];
+ scalar_t grad_val;
+
+ // right side
+ if (cur_d < resolution) {
+ pos_grid[d] = cur_d + 1;
+ uint32_t index_right = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ grad_val = (grid[index + ch] - grid[index_right + ch]);
+ results[ch] += grad_val;
+ idelta[ch] += grad_val * grad_val;
+ }
+ }
+
+ // left side
+ if (cur_d > 0) {
+ pos_grid[d] = cur_d - 1;
+ uint32_t index_left = get_grid_index(gridtype, 0, hashmap_size, resolution, pos_grid);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ grad_val = (grid[index + ch] - grid[index_left + ch]);
+ results[ch] += grad_val;
+ idelta[ch] += grad_val * grad_val;
+ }
+ }
+
+ // reset
+ pos_grid[d] = cur_d;
+ }
+
+ // writing to global memory (slow)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ // index may collide, so use atomic!
+ atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
+ }
+
+}
+
+
+template
+void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
+ switch (C) {
+ case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 16: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ case 32: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
+ }
+}
+
+
+template
+void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+ case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
+ }
+}
+
+
+void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grad_total_variation", ([&] {
+ grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners);
+ }));
+}
+
+template
+__global__ void kernel_grad_wd(
+ const scalar_t * __restrict__ grid,
+ scalar_t * __restrict__ grad,
+ const int * __restrict__ offsets,
+ const float weight,
+ const uint32_t B, const uint32_t L, const uint32_t C
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (b >= B * C) return;
+
+ // locate
+ grid += b;
+ grad += b;
+
+ // decide in which level is this thread...
+ uint32_t level = 0;
+ const uint32_t n = b / C;
+ // binary search b in offsets
+ uint32_t l = 0, r = L;
+ while (l < r) {
+ uint32_t m = (l + r) / 2;
+ if (offsets[m] <= n) {
+ level = m;
+ l = m + 1;
+ } else {
+ r = m;
+ }
+ }
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ grad[0] += 2 * weight * grid[0] / hashmap_size;
+}
+
+void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grad_weight_decay", ([&] {
+ static constexpr uint32_t N_THREAD = 1024;
+ const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };
+ kernel_grad_wd<<>>(embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, L, C);
+ }));
+}
\ No newline at end of file
diff --git a/gridencoder/src/gridencoder.h b/gridencoder/src/gridencoder.h
new file mode 100644
index 0000000..3df2e08
--- /dev/null
+++ b/gridencoder/src/gridencoder.h
@@ -0,0 +1,18 @@
+#ifndef _HASH_ENCODE_H
+#define _HASH_ENCODE_H
+
+#include
+#include
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [B, L * C], float
+// H: base resolution
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
+
+void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
+void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
+
+#endif
\ No newline at end of file
diff --git a/guidance/clip_utils.py b/guidance/clip_utils.py
new file mode 100644
index 0000000..ddc0500
--- /dev/null
+++ b/guidance/clip_utils.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+import clip
+
+class CLIP(nn.Module):
+ def __init__(self, device, **kwargs):
+ super().__init__()
+
+ self.device = device
+ self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=self.device, jit=False)
+
+ self.aug = T.Compose([
+ T.Resize((224, 224)),
+ T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ def get_text_embeds(self, prompt, **kwargs):
+
+ text = clip.tokenize(prompt).to(self.device)
+ text_z = self.clip_model.encode_text(text)
+ text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+
+ return text_z
+
+ def get_img_embeds(self, image, **kwargs):
+
+ image_z = self.clip_model.encode_image(self.aug(image))
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True)
+
+ return image_z
+
+
+ def train_step(self, clip_z, pred_rgb, grad_scale=10, **kwargs):
+
+ image_z = self.clip_model.encode_image(self.aug(pred_rgb))
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+
+ loss = 0
+ if 'image' in clip_z:
+ loss = loss - (image_z * clip_z['image']).sum(-1).mean()
+
+ if 'text' in clip_z:
+ loss = loss - (image_z * clip_z['text']).sum(-1).mean()
+
+ loss = loss * grad_scale
+
+ return loss
+
diff --git a/guidance/if_utils.py b/guidance/if_utils.py
new file mode 100644
index 0000000..0dcce22
--- /dev/null
+++ b/guidance/if_utils.py
@@ -0,0 +1,207 @@
+from transformers import logging
+from diffusers import IFPipeline, DDPMScheduler
+
+# suppress partial model loading warning
+logging.set_verbosity_error()
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+class SpecifyGradient(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input_tensor, gt_grad):
+ ctx.save_for_backward(gt_grad)
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_scale):
+ gt_grad, = ctx.saved_tensors
+ gt_grad = gt_grad * grad_scale
+ return gt_grad, None
+
+def seed_everything(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ #torch.backends.cudnn.deterministic = True
+ #torch.backends.cudnn.benchmark = True
+
+
+class IF(nn.Module):
+ def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
+ super().__init__()
+
+ self.device = device
+
+ print(f'[INFO] loading DeepFloyd IF-I-XL...')
+
+ model_key = "DeepFloyd/IF-I-XL-v1.0"
+
+ is_torch2 = torch.__version__[0] == '2'
+
+ # Create model
+ pipe = IFPipeline.from_pretrained(model_key, variant="fp16", torch_dtype=torch.float16)
+ if not is_torch2:
+ pipe.enable_xformers_memory_efficient_attention()
+
+ if vram_O:
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.enable_attention_slicing(1)
+ pipe.enable_model_cpu_offload()
+ else:
+ pipe.to(device)
+
+ self.unet = pipe.unet
+ self.tokenizer = pipe.tokenizer
+ self.text_encoder = pipe.text_encoder
+ self.unet = pipe.unet
+ self.scheduler = pipe.scheduler
+
+ self.pipe = pipe
+
+ self.num_train_timesteps = self.scheduler.config.num_train_timesteps
+ self.min_step = int(self.num_train_timesteps * t_range[0])
+ self.max_step = int(self.num_train_timesteps * t_range[1])
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
+
+ print(f'[INFO] loaded DeepFloyd IF-I-XL!')
+
+ @torch.no_grad()
+ def get_text_embeds(self, prompt):
+ # prompt: [str]
+
+ # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28
+ prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
+ inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')
+ embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
+
+ return embeddings
+
+
+ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):
+
+ # [0, 1] to [-1, 1] and make sure shape is [64, 64]
+ images = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
+
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
+ t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)
+
+ # predict the noise residual with unet, NO grad!
+ with torch.no_grad():
+ # add noise
+ noise = torch.randn_like(images)
+ images_noisy = self.scheduler.add_noise(images, noise, t)
+
+ # pred noise
+ model_input = torch.cat([images_noisy] * 2)
+ model_input = self.scheduler.scale_model_input(model_input, t)
+ tt = torch.cat([t] * 2)
+ noise_pred = self.unet(model_input, tt, encoder_hidden_states=text_embeddings).sample
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # TODO: how to use the variance here?
+ # noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ # w(t), sigma_t^2
+ w = (1 - self.alphas[t])
+ grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
+ grad = torch.nan_to_num(grad)
+
+ # since we omitted an item in grad, we need to use the custom function to specify the gradient
+ loss = SpecifyGradient.apply(images, grad)
+
+ return loss
+
+ @torch.no_grad()
+ def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):
+
+ images = torch.randn((1, 3, height, width), device=text_embeddings.device, dtype=text_embeddings.dtype)
+ images = images * self.scheduler.init_noise_sigma
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for i, t in enumerate(self.scheduler.timesteps):
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+ model_input = torch.cat([images] * 2)
+ model_input = self.scheduler.scale_model_input(model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ images = self.scheduler.step(noise_pred, t, images).prev_sample
+
+ images = (images + 1) / 2
+
+ return images
+
+
+ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
+
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ if isinstance(negative_prompts, str):
+ negative_prompts = [negative_prompts]
+
+ # Prompts -> text embeds
+ pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]
+ neg_embeds = self.get_text_embeds(negative_prompts)
+ text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
+
+ # Text embeds -> img
+ imgs = self.produce_imgs(text_embeds, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
+
+ # Img to Numpy
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
+ imgs = (imgs * 255).round().astype('uint8')
+
+ return imgs
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import matplotlib.pyplot as plt
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('prompt', type=str)
+ parser.add_argument('--negative', default='', type=str)
+ parser.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
+ parser.add_argument('-H', type=int, default=64)
+ parser.add_argument('-W', type=int, default=64)
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--steps', type=int, default=50)
+ opt = parser.parse_args()
+
+ seed_everything(opt.seed)
+
+ device = torch.device('cuda')
+
+ sd = IF(device, opt.vram_O)
+
+ imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
+
+ # visualize image
+ plt.imshow(imgs[0])
+ plt.show()
+
+
+
+
diff --git a/guidance/sd_utils.py b/guidance/sd_utils.py
new file mode 100644
index 0000000..4bef064
--- /dev/null
+++ b/guidance/sd_utils.py
@@ -0,0 +1,707 @@
+from typing import List, Optional, Sequence, Tuple, Union, Mapping
+import os
+
+from dataclasses import dataclass
+from torch.cuda.amp import custom_bwd, custom_fwd
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+from diffusers.utils.import_utils import is_xformers_available
+from os.path import isfile
+from pathlib import Path
+import numpy as np
+from PIL import Image
+from torchvision.io import read_image
+from torchvision import transforms
+from torchvision.transforms import functional as TVF
+from torchvision.utils import make_grid, save_image
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
+import logging
+logger = logging.getLogger(__name__)
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def seed_everything(seed=None):
+ if seed:
+ seed = int(seed)
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ # torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+
+def save_tensor2image(x: torch.Tensor, path, channel_last=True, quality=75, **kwargs):
+ # assume the input x is channel last
+ if x.ndim == 4 and channel_last:
+ x = x.permute(0, 3, 1, 2)
+ TVF.to_pil_image(make_grid(x, value_range=(0, 1), **kwargs)).save(path, quality=quality)
+
+
+def to_pil(x: torch.Tensor, **kwargs) -> Image.Image:
+ return TVF.to_pil_image(make_grid(x, value_range=(0, 1), **kwargs))
+
+
+def to_np_img(x: torch.Tensor) -> np.ndarray:
+ return (x.detach().permute(0, 2, 3, 1).cpu().numpy() * 255).round().astype(np.uint8)
+
+
+class SpecifyGradient(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input_tensor, gt_grad):
+ ctx.save_for_backward(gt_grad)
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_scale):
+ gt_grad, = ctx.saved_tensors
+ gt_grad = gt_grad * grad_scale
+ return gt_grad, None
+
+
+def token_replace(prompt, negative, learned_embeds_path):
+ # Set up automatic token replacement for prompt
+ if '' in prompt or '' in negative:
+ if learned_embeds_path is None:
+ raise ValueError(
+ '--learned_embeds_path must be specified when using ')
+ import torch
+ tmp = list(torch.load(learned_embeds_path, map_location='cpu').keys())
+ if len(tmp) != 1:
+ raise ValueError(
+ 'Something is wrong with the dict passed in for --learned_embeds_path')
+ token = tmp[0]
+ prompt = prompt.replace('', token)
+ negative = negative.replace('', token)
+ logger.info(f'Prompt after replacing : {prompt}')
+ logger.info(f'Negative prompt after replacing : {negative}')
+ return prompt, negative
+
+
+@dataclass
+class UNet2DConditionOutput:
+ # Not sure how to check what unet_traced.pt contains, and user wants. HalfTensor or FloatTensor
+ sample: torch.HalfTensor
+
+
+def enable_vram(pipe):
+ pipe.enable_sequential_cpu_offload()
+ pipe.enable_vae_slicing()
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.enable_attention_slicing(1)
+ # pipe.enable_model_cpu_offload()
+
+
+def get_model_path(sd_version='2.1', clip_version='large', hf_key=None):
+ if hf_key is not None:
+ logger.info(f'[INFO] using hugging face custom model key: {hf_key}')
+ sd_path = hf_key
+ elif sd_version == '2.1':
+ sd_path = "stabilityai/stable-diffusion-2-1-base"
+ elif sd_version == '2.0':
+ sd_path = "stabilityai/stable-diffusion-2-base"
+ elif sd_version == '1.5':
+ sd_path = "runwayml/stable-diffusion-v1-5"
+ else:
+ raise ValueError(
+ f'Stable-diffusion version {sd_version} not supported.')
+ if clip_version == 'base':
+ clip_path = "openai/clip-vit-base-patch32"
+ else:
+ clip_path = "openai/clip-vit-large-patch14"
+ return sd_path, clip_path
+
+
+class StableDiffusion(nn.Module):
+ def __init__(self, device, fp16, vram_O,
+ sd_version='2.1', hf_key=None,
+ t_range=[0.02, 0.98],
+ use_clip=False,
+ clip_version='base',
+ clip_iterative=True,
+ clip_t=0.4,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.device = device
+ self.sd_version = sd_version
+ self.vram_O = vram_O
+ self.fp16 = fp16
+
+ logger.info(f'[INFO] loading stable diffusion...')
+
+ sd_path, clip_path = get_model_path(sd_version, clip_version, hf_key)
+ self.precision_t = torch.float16 if fp16 else torch.float32
+
+ # Create model
+ pipe = StableDiffusionPipeline.from_pretrained(
+ sd_path, torch_dtype=self.precision_t, local_files_only=False)
+
+ if isfile('./unet_traced.pt'):
+ # use jitted unet
+ unet_traced = torch.jit.load('./unet_traced.pt')
+
+ class TracedUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.in_channels = pipe.unet.in_channels
+ self.device = pipe.unet.device
+
+ def forward(self, latent_model_input, t, encoder_hidden_states):
+ sample = unet_traced(
+ latent_model_input, t, encoder_hidden_states)[0]
+ return UNet2DConditionOutput(sample=sample)
+ pipe.unet = TracedUNet()
+
+ self.vae = pipe.vae
+ self.tokenizer = pipe.tokenizer
+ self.text_encoder = pipe.text_encoder
+ self.unet = pipe.unet
+
+ if kwargs.get('learned_embeds_path', None) is not None:
+ learned_embeds_path = kwargs['learned_embeds_path']
+ if os.path.exists(learned_embeds_path):
+ logger.info(
+ f'[INFO] loading learned embeddings from {kwargs["learned_embeds_path"]}')
+ self.add_tokens_to_model_from_path(learned_embeds_path, kwargs.get('overrride_token', None))
+ else:
+ logger.warning(f'learned_embeds_path {learned_embeds_path} does not exist!')
+
+ if vram_O:
+ # this will change device from gpu to other types (meta)
+ enable_vram(pipe)
+ else:
+ if is_xformers_available():
+ pipe.enable_xformers_memory_efficient_attention()
+ pipe.to(device)
+
+ self.scheduler = DDIMScheduler.from_pretrained(
+ sd_path, subfolder="scheduler", torch_dtype=self.precision_t, local_files_only=False)
+
+ self.num_train_timesteps = self.scheduler.config.num_train_timesteps
+ self.min_step = int(self.num_train_timesteps * t_range[0])
+ self.max_step = int(self.num_train_timesteps * t_range[1])
+ self.alphas = self.scheduler.alphas_cumprod.to(
+ self.device) # for convenience
+
+ logger.info(f'[INFO] loaded stable diffusion!')
+
+ # for CLIP
+ self.use_clip = use_clip
+ if self.use_clip:
+ #breakpoint()
+ self.clip_model = CLIPModel.from_pretrained(clip_path).to(device)
+ image_processor = CLIPProcessor.from_pretrained(clip_path).image_processor
+ self.image_processor = transforms.Compose([
+ transforms.Resize((image_processor.crop_size['height'], image_processor.crop_size['width'])),
+ transforms.Normalize(image_processor.image_mean, image_processor.image_std),
+ ])
+ for p in self.clip_model.parameters():
+ p.requires_grad = False
+
+ self.clip_iterative = clip_iterative
+ self.clip_t = int(self.num_train_timesteps * clip_t)
+
+ @torch.no_grad()
+ def get_text_embeds(self, prompt):
+ # Tokenize text and get embeddings
+ text_input = self.tokenizer(
+ prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
+
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ return text_embeddings
+
+ @torch.no_grad()
+ def get_all_text_embeds(self, prompt):
+ # Tokenize text and get embeddings
+ text_input = self.tokenizer(
+ prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
+
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))
+ # text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+
+ # return all text embeddings and class embeddings
+ return torch.cat([text_embeddings[0], text_embeddings[1].unsqueeze(1)], dim=1)
+
+ # @torch.no_grad()
+ def get_clip_img_embeds(self, img):
+ img = self.image_processor(img)
+ image_z = self.clip_model.get_image_features(img)
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+ return image_z
+
+ def clip_loss(self, ref_z, pred_rgb):
+ image_z = self.get_clip_img_embeds(pred_rgb)
+ loss = spherical_dist_loss(image_z, ref_z)
+ return loss
+
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_clip=None, grad_scale=1.0,
+ image_ref_clip=None, text_ref_clip=None, clip_guidance=100, clip_image_loss=False,
+ density=None,
+ save_guidance_path=None
+ ):
+ enable_clip = self.use_clip and clip_guidance > 0 and not as_latent
+ enable_sds = True
+ #breakpoint()
+ if as_latent:
+ latents = F.interpolate(
+ pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
+ else:
+ # interp to 512x512 to be fed into vae.
+ pred_rgb_512 = F.interpolate(
+ pred_rgb, (512, 512), mode='bilinear', align_corners=False)
+ # encode image into latents with vae, requires grad!
+ latents = self.encode_imgs(pred_rgb_512)
+
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
+ # Since during the optimzation, the 3D is getting better.
+ # mn = max(self.min_step, int(self.max_step - (self.max_step - self.min_step) / (self.opt.max_epoch // 3) * self.epoch + 0.5))
+ t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
+ if enable_clip and self.clip_iterative:
+ if t > self.clip_t:
+ enable_clip = False
+ else:
+ enable_sds = False
+
+ # predict the noise residual with unet, NO grad!
+ with torch.no_grad():
+ # add noise
+ noise = torch.randn_like(latents)
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
+
+ # pred noise
+ latent_model_input = torch.cat([latents_noisy] * 2)
+ # Save input tensors for UNet
+ # torch.save(latent_model_input, "train_latent_model_input.pt")
+ # torch.save(t, "train_t.pt")
+ # torch.save(text_embeddings, "train_text_embeddings.pt")
+ tt = torch.cat([t]*2)
+ noise_pred = self.unet(latent_model_input, tt,
+ encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance (high scale from paper!)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_text + guidance_scale * \
+ (noise_pred_text - noise_pred_uncond)
+
+ if enable_clip:
+ pred_original_sample = (latents_noisy - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5)
+ sample = pred_original_sample
+ sample = sample.detach().requires_grad_()
+
+ sample = 1 / self.vae.config.scaling_factor * sample
+ out_image = self.vae.decode(sample).sample
+ out_image = (out_image / 2 + 0.5)#.clamp(0, 1)
+ image_embeddings_clip = self.get_clip_img_embeds(out_image)
+ ref_clip = image_ref_clip if clip_image_loss else text_ref_clip
+ loss_clip = spherical_dist_loss(image_embeddings_clip, ref_clip).mean() * clip_guidance * 50 # 100
+ grad_clipd = - torch.autograd.grad(loss_clip, sample, retain_graph=True)[0]
+ else:
+ grad_clipd = 0
+
+ # import kiui
+ # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)
+ # latents_tmp = latents_tmp.detach()
+ # kiui.lo(latents_tmp)
+ # self.scheduler.set_timesteps(30)
+ # for i, t in enumerate(self.scheduler.timesteps):
+ # latent_model_input = torch.cat([latents_tmp] * 3)
+ # noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
+ # noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
+ # noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)
+ # latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']
+ # imgs = self.decode_latents(latents_tmp)
+ # kiui.vis.plot_image(imgs)
+
+ if density is not None:
+ with torch.no_grad():
+ density = F.interpolate(density.detach(), (64, 64), mode='bilinear', align_corners=False)
+ ids = torch.nonzero(density.squeeze())
+ spatial_weight = torch.ones_like(density, device=density.device)
+ try:
+ up = ids[:, 0].min()
+ down = ids[:, 0].max() + 1
+ ll = ids[:, 1].min()
+ rr = ids[:, 1].max() + 1
+ spatial_weight[:, :, up:down, ll:rr] += 1
+ except:
+ pass
+ # breakpoint()
+ # w(t), sigma_t^2
+ w = (1 - self.alphas[t])[:, None, None, None]
+ # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
+
+ if enable_sds:
+ grad_sds = grad_scale * w * (noise_pred - noise)
+ loss_sds = grad_sds.abs().mean().detach()
+ else:
+ grad_sds = 0.
+ loss_sds = 0.
+
+ if enable_clip:
+ grad_clipd = w * grad_clipd.detach()
+ loss_clipd = grad_clipd.abs().mean().detach()
+ else:
+ grad_clipd = 0.
+ loss_clipd = 0.
+
+ grad = grad_clipd + grad_sds
+
+ if grad_clip is not None:
+ grad = grad.clamp(-grad_clip, grad_clip)
+
+ if density is not None:
+ grad = grad * spatial_weight / 2
+
+ grad = torch.nan_to_num(grad)
+
+ # since we omitted an item in grad, we need to use the custom function to specify the gradient
+ # loss = SpecifyGradient.apply(latents, grad)
+ # loss = loss.abs().mean().detach()
+ latents.backward(gradient=grad, retain_graph=True)
+ loss = grad.abs().mean().detach()
+
+ if not enable_clip:
+ loss_sds = loss
+
+ if save_guidance_path:
+ with torch.no_grad():
+ # save original input
+ images = []
+ os.makedirs(os.path.dirname(save_guidance_path), exist_ok=True)
+ timesteps = torch.arange(-1, 1000, 100, dtype=torch.long, device=self.device)
+ timesteps[0] *= 0
+ for t in timesteps:
+ if as_latent:
+ pred_rgb_512 = self.decode_latents(latents)
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
+
+ # pred noise
+ latent_model_input = torch.cat([latents_noisy] * 2)
+
+ noise_pred = self.unet(latent_model_input, t,
+ encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance (high scale from paper!)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_text + guidance_scale * \
+ (noise_pred_text - noise_pred_uncond)
+
+ pred_original_sample = self.decode_latents((latents_noisy - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5))
+
+ # visualize predicted denoised image
+ # claforte: discuss this with Vikram!!
+ result_hopefully_less_noisy_image = self.decode_latents(latents - w*(noise_pred - noise))
+
+ # visualize noisier image
+ result_noisier_image = self.decode_latents(latents_noisy)
+
+ # add in the last col, w/o rendered view contraint, using random noise as latent.
+ latent_model_input = torch.cat([noise] * 2)
+ noise_pred = self.unet(latent_model_input, t,
+ encoder_hidden_states=text_embeddings).sample
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_text + guidance_scale * \
+ (noise_pred_text - noise_pred_uncond)
+ noise_diffusion_out = self.decode_latents((noise - (1 - self.alphas[t]) ** (0.5) * noise_pred) / self.alphas[t] ** (0.5))
+ # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
+ image = torch.cat([pred_rgb_512, pred_original_sample, result_noisier_image, result_hopefully_less_noisy_image, noise_diffusion_out],dim=0)
+ images.append(image)
+ viz_images = torch.cat(images, dim=0)
+ save_image(viz_images, save_guidance_path, nrow=5)
+
+ return loss, {'loss_sds': loss_sds, 'loss_clipd': loss_clipd}
+
+ @torch.no_grad()
+ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
+
+ if latents is None:
+ latents = torch.randn(
+ (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ with torch.autocast('cuda'):
+ for i, t in enumerate(self.scheduler.timesteps):
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+ latent_model_input = torch.cat([latents] * 2)
+ # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+ # Save input tensors for UNet
+ # torch.save(latent_model_input, "produce_latents_latent_model_input.pt")
+ # torch.save(t, "produce_latents_t.pt")
+ # torch.save(text_embeddings, "produce_latents_text_embeddings.pt")
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
+
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_text + guidance_scale * \
+ (noise_pred_text - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, t, latents)[
+ 'prev_sample']
+
+ return latents
+
+ def decode_latents(self, latents):
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ # with torch.no_grad():
+ imgs = self.vae.decode(latents).sample
+
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
+
+ return imgs
+
+ def encode_imgs(self, imgs):
+ # imgs: [B, 3, H, W]
+
+ imgs = 2 * imgs - 1
+
+ posterior = self.vae.encode(imgs).latent_dist
+ latents = posterior.sample() * self.vae.config.scaling_factor
+
+ return latents
+
+ def encode_imgs_mean(self, imgs):
+ # imgs: [B, 3, H, W]
+
+ imgs = 2 * imgs - 1
+
+ latents = self.vae.encode(imgs).latent_dist.mean
+ latents = latents * self.vae.config.scaling_factor
+
+ return latents
+
+ @torch.no_grad()
+ def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, to_numpy=True):
+
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ if isinstance(negative_prompts, str):
+ negative_prompts = [negative_prompts] * len(prompts)
+
+ prompts = tuple(prompts)
+ negative_prompts = tuple(negative_prompts)
+ # Prompts -> text embeds
+ pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]
+ neg_embeds = self.get_text_embeds(negative_prompts)
+ text_embeds = torch.cat(
+ [neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
+
+ # Text embeds -> img latents
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
+
+ # Img latents -> imgs
+ imgs = self.decode_latents(latents.to(
+ text_embeds.dtype)) # [1, 3, 512, 512]
+
+ # Img to Numpy
+ if to_numpy:
+ imgs = to_np_img(imgs)
+ return imgs
+
+ @torch.no_grad()
+ def img_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, img=None, to_numpy=True, t=50):
+ """
+ Known issues:
+ 1. Not able to reconstruct images even with no noise.
+ """
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ if isinstance(negative_prompts, str):
+ negative_prompts = [negative_prompts]
+
+ # Prompts -> text embeds
+ pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]
+ neg_embeds = self.get_text_embeds(negative_prompts)
+ text_embeds = torch.cat(
+ [neg_embeds, pos_embeds], dim=0) # [2, 77, 768]
+
+ # image to latent
+ # interp to 512x512 to be fed into vae.
+ if isinstance(img, str):
+ img = TVF.to_tensor(Image.open(img))[None, :3].cuda()
+
+ img_512 = F.interpolate(
+ img.to(text_embeds.dtype), (512, 512), mode='bilinear', align_corners=False)
+ # logger.info(img_512.shape, img_512, '\n', img_512.min(), img_512.max(), img_512.mean())
+
+ # encode image into latents with vae, requires grad!
+ latents = self.encode_imgs(img_512).repeat(
+ text_embeds.shape[0] // 2, 1, 1, 1)
+ # logger.info(latents.shape, latents, '\n', latents.min(), latents.max(), latents.mean())
+
+ noise = torch.randn_like(latents)
+ if t > 0:
+ latents_noise = self.scheduler.add_noise(
+ latents, noise, torch.tensor(t).to(torch.int32))
+ else:
+ latents_noise = latents
+
+ # Text embeds -> img latents
+ latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents_noise,
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
+
+ # Img latents -> imgs
+ imgs = self.decode_latents(latents.to(
+ text_embeds.dtype)) # [1, 3, 512, 512]
+
+ # Img to Numpy
+ if to_numpy:
+ imgs = to_np_img(imgs)
+ return imgs
+
+ def add_tokens_to_model(self, learned_embeds: Mapping[str, Tensor], override_token: Optional[Union[str, dict]] = None) -> None:
+ r"""Adds tokens to the tokenizer and text encoder of a model."""
+
+ # Loop over learned embeddings
+ new_tokens = []
+ for token, embedding in learned_embeds.items():
+ embedding = embedding.to(
+ self.text_encoder.get_input_embeddings().weight.dtype)
+ if override_token is not None:
+ token = override_token if isinstance(
+ override_token, str) else override_token[token]
+
+ # Add the token to the tokenizer
+ num_added_tokens = self.tokenizer.add_tokens(token)
+ if num_added_tokens == 0:
+ raise ValueError((f"The tokenizer already contains the token {token}. Please pass a "
+ "different `token` that is not already in the tokenizer."))
+
+ # Resize the token embeddings
+ self.text_encoder._resize_token_embeddings(len(self.tokenizer))
+
+ # Get the id for the token and assign the embeds
+ token_id = self.tokenizer.convert_tokens_to_ids(token)
+ self.text_encoder.get_input_embeddings(
+ ).weight.data[token_id] = embedding
+ new_tokens.append(token)
+
+ logger.info(
+ f'Added {len(new_tokens)} tokens to tokenizer and text embedding: {new_tokens}')
+
+ def add_tokens_to_model_from_path(self, learned_embeds_path: str, override_token: Optional[Union[str, dict]] = None) -> None:
+ r"""Loads tokens from a file and adds them to the tokenizer and text encoder of a model."""
+ learned_embeds: Mapping[str, Tensor] = torch.load(
+ learned_embeds_path, map_location='cpu')
+ self.add_tokens_to_model(learned_embeds, override_token)
+
+ def check_prompt(self, opt):
+ texts = ['', ', front view', ', side view', ', back view']
+ for view_text in texts:
+ text = opt.text + view_text
+ logger.info(f'Checking stable diffusion model with prompt: {text}')
+ # Generate
+ image_check = self.prompt_to_img(
+ prompts=[text] * opt.get('prompt_check_nums', 5), guidance_scale=7.5, to_numpy=False,
+ num_inference_steps=opt.get('num_inference_steps', 50))
+ # Save
+ output_dir_check = Path(opt.workspace) / 'prompt_check'
+ output_dir_check.mkdir(exist_ok=True, parents=True)
+ to_pil(image_check).save(output_dir_check / f'generations_{view_text}.png')
+ (output_dir_check / 'prompt.txt').write_text(text)
+
+
+
+if __name__ == '__main__':
+
+ import argparse
+ import matplotlib.pyplot as plt
+ from easydict import EasyDict as edict
+ import glob
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--text', type=str)
+ parser.add_argument('--negative', default='', type=str)
+ parser.add_argument('--workspace', default='out/sd', type=str)
+ parser.add_argument('--image_path', default=None, type=str)
+ parser.add_argument('--learned_embeds_path', type=str,
+ default=None, help="path to learned embeds"
+ )
+ parser.add_argument('--sd_version', type=str, default='1.5',
+ choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
+ parser.add_argument('--hf_key', type=str, default=None,
+ help="hugging face Stable diffusion model key")
+ parser.add_argument('--fp16', action='store_true',
+ help="use float16 for training")
+ parser.add_argument('--vram_O', action='store_true',
+ help="optimization for low VRAM usage")
+ parser.add_argument('--gudiance_scale', type=float, default=100)
+ parser.add_argument('-H', type=int, default=512)
+ parser.add_argument('-W', type=int, default=512)
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--num_inference_steps', type=int, default=50)
+ parser.add_argument('--noise_t', type=int, default=50)
+ parser.add_argument('--prompt_check_nums', type=int, default=5)
+ opt, unknown = parser.parse_known_args()
+
+ # seed_everything(opt.seed)
+ device = torch.device('cuda')
+ opt = edict(vars(opt))
+ workspace = opt.workspace
+
+ opt.original_text = opt.text
+ opt.original_negative = opt.negative
+ if opt.learned_embeds_path is not None:
+ # cml:
+ # python guidance/sd_utils.py --text "A high-resolution DSLR image of " --learned_embeds_path out/learned_embeds/ --workspace out/teddy_bear
+ # check prompt
+ if os.path.isdir(opt.learned_embeds_path):
+ learned_embeds_paths = glob.glob(os.path.join(opt.learned_embeds_path, 'learned_embeds*bin'))
+ else:
+ learned_embeds_paths = [opt.learned_embeds_path]
+
+ for learned_embeds_path in learned_embeds_paths:
+ embed_name = os.path.basename(learned_embeds_path).split('.')[0]
+ opt.workspace = os.path.join(workspace, embed_name)
+ sd = StableDiffusion(device, opt.fp16, opt.vram_O,
+ opt.sd_version, opt.hf_key,
+ learned_embeds_path=learned_embeds_path
+ )
+ # Add tokenizer
+ if learned_embeds_path is not None: # add textual inversion tokens to model
+ opt.text, opt.negative = token_replace(
+ opt.original_text, opt.original_negative, learned_embeds_path)
+ logger.info(opt.text, opt.negative)
+ sd.check_prompt(opt)
+ else:
+ #breakpoint()
+ if opt.image_path is not None:
+ save_promt = '_'.join(opt.text.split(' ')) + '_' + opt.image_path.split(
+ '/')[-1].split('.')[0] + '_' + str(opt.noise_t) + '_' + str(opt.num_inference_steps)
+ imgs = sd.img_to_img([opt.text]*opt.prompt_check_nums, [opt.negative]*opt.prompt_check_nums, opt.H, opt.W, opt.num_inference_steps,
+ to_numpy=False, img=opt.image_path, t=opt.noise_t, guidance_scale=opt.gudiance_scale)
+ else:
+ save_promt = '_'.join(opt.text.split(' '))
+ imgs = sd.prompt_to_img([opt.text]*opt.prompt_check_nums, [opt.negative]
+ * opt.prompt_check_nums, opt.H, opt.W, opt.num_inference_steps, to_numpy=False)
+ # visualize image
+ output_dir_check = Path(opt.workspace)
+ output_dir_check.mkdir(exist_ok=True, parents=True)
+
+ to_pil(imgs).save(output_dir_check / f'{save_promt}.png')
diff --git a/guidance/shape_utils.py b/guidance/shape_utils.py
new file mode 100644
index 0000000..33e5b1c
--- /dev/null
+++ b/guidance/shape_utils.py
@@ -0,0 +1,81 @@
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from shap_e.models.transmitter.base import Transmitter
+from shap_e.models.query import Query
+from shap_e.models.nerstf.renderer import NeRSTFRenderer
+from shap_e.util.collections import AttrDict
+from shap_e.diffusion.sample import sample_latents
+from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
+from shap_e.models.download import load_model, load_config
+from shap_e.util.image_util import load_image
+from shap_e.models.nn.meta import subdict
+import torch
+import gc
+
+
+camera_to_shapes = [
+ torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.float32),
+ torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32), # to bird view
+ torch.tensor([[0, 1, 0], [0, 0, 1], [-1, 0, 0]], dtype=torch.float32), # to rotaed bird view
+ torch.tensor([[0, -1, 0], [0, 0, 1], [-1, 0, 0]], dtype=torch.float32), # to rotaed bird view
+ torch.tensor([[-1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32), # to bird view
+ ]
+
+
+def get_density(
+ render,
+ query: Query,
+ params: Dict[str, torch.Tensor],
+ options: AttrDict[str, Any],
+) -> torch.Tensor:
+ assert render.nerstf is not None
+ return render.nerstf(query, params=subdict(params, "nerstf"), options=options).density
+
+
+@torch.no_grad()
+def get_shape_from_image(image_path, pos,
+ rpst_type='sdf', # or 'density'
+ get_color=True,
+ shape_guidance=3, device='cuda'):
+ xm = load_model('transmitter', device=device)
+ model = load_model('image300M', device=device)
+ diffusion = diffusion_from_config(load_config('diffusion'))
+ latent = sample_latents(
+ batch_size=1,
+ model=model,
+ diffusion=diffusion,
+ guidance_scale=shape_guidance,
+ model_kwargs=dict(images=[load_image(image_path)]),
+ progress=True,
+ clip_denoised=True,
+ use_fp16=True,
+ use_karras=True,
+ karras_steps=64,
+ sigma_min=1e-3,
+ sigma_max=160,
+ s_churn=0,
+ )[0]
+
+ params = (xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
+ latent[None]
+ )
+
+ rpsts, colors = [], []
+ for camera_to_shape in camera_to_shapes:
+ query = Query(
+ position=pos @ camera_to_shape.to(pos.device),
+ direction=None,
+ )
+
+ if rpst_type == 'sdf':
+ rpst = xm.renderer.get_signed_distance(query, params, AttrDict())
+ else:
+ rpst = get_density(xm.renderer, query, params, AttrDict())
+ rpsts.append(rpst.squeeze())
+
+ if get_color:
+ color = xm.renderer.get_texture(query, params, AttrDict())
+ else:
+ color = None
+ colors.append(color)
+
+ return rpsts, colors
\ No newline at end of file
diff --git a/guidance/zero123_utils.py b/guidance/zero123_utils.py
new file mode 100644
index 0000000..28c7513
--- /dev/null
+++ b/guidance/zero123_utils.py
@@ -0,0 +1,332 @@
+import math
+import numpy as np
+from omegaconf import OmegaConf
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torchvision.utils import save_image
+
+from diffusers import DDIMScheduler
+
+import sys
+from os import path
+sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
+
+from ldm.util import instantiate_from_config
+
+class SpecifyGradient(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, input_tensor, gt_grad):
+ ctx.save_for_backward(gt_grad)
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_scale):
+ gt_grad, = ctx.saved_tensors
+ gt_grad = gt_grad * grad_scale
+ return gt_grad, None
+
+# load model
+def load_model_from_config(config, ckpt, device, vram_O=False, verbose=False):
+
+ pl_sd = torch.load(ckpt, map_location='cpu')
+
+ if 'global_step' in pl_sd and verbose:
+ print(f'[INFO] Global Step: {pl_sd["global_step"]}')
+
+ sd = pl_sd['state_dict']
+
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print('[INFO] missing keys: \n', m)
+ if len(u) > 0 and verbose:
+ print('[INFO] unexpected keys: \n', u)
+
+ # manually load ema and delete it to save GPU memory
+ if model.use_ema:
+ if verbose:
+ print('[INFO] loading EMA...')
+ model.model_ema.copy_to(model.model)
+ del model.model_ema
+
+ if vram_O:
+ # we don't need decoder
+ del model.first_stage_model.decoder
+
+ torch.cuda.empty_cache()
+
+ model.eval().to(device)
+
+ return model
+
+class Zero123(nn.Module):
+ def __init__(self, device, fp16,
+ config='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml',
+ ckpt='./pretrained/zero123/105000.ckpt', vram_O=False, t_range=[0.02, 0.98], opt=None):
+ super().__init__()
+
+ self.device = device
+ self.fp16 = fp16
+ self.vram_O = vram_O
+ self.t_range = t_range
+ self.opt = opt
+
+ self.config = OmegaConf.load(config)
+ self.model = load_model_from_config(self.config, ckpt, device=self.device, vram_O=vram_O)
+
+ # timesteps: use diffuser for convenience... hope it's alright.
+ self.num_train_timesteps = self.config.model.params.timesteps
+
+ self.scheduler = DDIMScheduler(
+ self.num_train_timesteps,
+ self.config.model.params.linear_start,
+ self.config.model.params.linear_end,
+ beta_schedule='scaled_linear',
+ clip_sample=False,
+ set_alpha_to_one=False,
+ steps_offset=1,
+ )
+
+ self.min_step = int(self.num_train_timesteps * t_range[0])
+ self.max_step = int(self.num_train_timesteps * t_range[1])
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
+
+ @torch.no_grad()
+ def get_img_embeds(self, x):
+ # x: image tensor [B, 3, 256, 256] in [0, 1]
+ x = x * 2 - 1
+ c = [self.model.get_learned_conditioning(xx.unsqueeze(0)) for xx in x] #.tile(n_samples, 1, 1)
+ v = [self.model.encode_first_stage(xx.unsqueeze(0)).mode() for xx in x]
+ return c, v
+
+ def angle_between(self, sph_v1, sph_v2):
+ def sph2cart(sv):
+ r, theta, phi = sv[0], sv[1], sv[2]
+ return torch.tensor([r * torch.sin(theta) * torch.cos(phi), r * torch.sin(theta) * torch.sin(phi), r * torch.cos(theta)])
+ def unit_vector(v):
+ return v / torch.linalg.norm(v)
+ def angle_between_2_sph(sv1, sv2):
+ v1, v2 = sph2cart(sv1), sph2cart(sv2)
+ v1_u, v2_u = unit_vector(v1), unit_vector(v2)
+ return torch.arccos(torch.clip(torch.dot(v1_u, v2_u), -1.0, 1.0))
+ angles = torch.empty(len(sph_v1), len(sph_v2))
+ for i, sv1 in enumerate(sph_v1):
+ for j, sv2 in enumerate(sph_v2):
+ angles[i][j] = angle_between_2_sph(sv1, sv2)
+ return angles
+
+ def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scale=3, as_latent=False, grad_scale=1, save_guidance_path:Path=None):
+ # pred_rgb: tensor [1, 3, H, W] in [0, 1]
+
+ # adjust SDS scale based on how far the novel view is from the known view
+ ref_radii = embeddings['ref_radii']
+ ref_polars = embeddings['ref_polars']
+ ref_azimuths = embeddings['ref_azimuths']
+ v1 = torch.stack([radius + ref_radii[0], torch.deg2rad(polar + ref_polars[0]), torch.deg2rad(azimuth + ref_azimuths[0])], dim=-1) # polar,azimuth,radius are all actually delta wrt default
+ v2 = torch.stack([torch.tensor(ref_radii), torch.deg2rad(torch.tensor(ref_polars)), torch.deg2rad(torch.tensor(ref_azimuths))], dim=-1)
+ angles = torch.rad2deg(self.angle_between(v1, v2)).to(self.device)
+ if self.opt.zero123_grad_scale == 'angle':
+ grad_scale = (angles.min(dim=1)[0] / (180/len(ref_azimuths))) * grad_scale # rethink 180/len(ref_azimuths) # claforte: try inverting grad_scale or just fixing it to 1.0
+ elif self.opt.zero123_grad_scale == 'None':
+ grad_scale = 1.0 # claforte: I think this might converge faster...?
+ else:
+ assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}'
+
+ if as_latent:
+ latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
+ else:
+ pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
+ latents = self.encode_imgs(pred_rgb_256)
+
+ t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)
+
+ # Set weights acc to closeness in angle
+ if len(ref_azimuths) > 1:
+ inv_angles = 1/angles
+ inv_angles[inv_angles > 100] = 100
+ inv_angles /= inv_angles.max(dim=-1, keepdim=True)[0]
+ inv_angles[inv_angles < 0.1] = 0
+ else:
+ inv_angles = torch.tensor([1.]).to(self.device)
+
+ # Multiply closeness-weight by user-given weights
+ zero123_ws = torch.tensor(embeddings['zero123_ws'])[None, :].to(self.device) * inv_angles
+ zero123_ws /= zero123_ws.max(dim=-1, keepdim=True)[0]
+ zero123_ws[zero123_ws < 0.1] = 0
+
+ with torch.no_grad():
+ noise = torch.randn_like(latents)
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
+
+ x_in = torch.cat([latents_noisy] * 2)
+ t_in = torch.cat([t] * 2)
+
+ noise_preds = []
+ # Loop through each ref image
+ for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T,
+ embeddings['c_crossattn'], embeddings['c_concat'],
+ ref_polars, ref_azimuths, ref_radii):
+ # polar,azimuth,radius are all actually delta wrt default
+ p = polar + ref_polars[0] - ref_polar
+ a = azimuth + ref_azimuths[0] - ref_azimuth
+ a[a > 180] -= 360 # range in [-180, 180]
+ r = radius + ref_radii[0] - ref_radius
+ # T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r])
+ # T = T[None, None, :].to(self.device)
+ T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :]
+ cond = {}
+ clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1))
+ cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
+ cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)]
+ noise_pred = self.model.apply_model(x_in, t_in, cond)
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ noise_preds.append(zero123_w[:, None, None, None] * noise_pred)
+
+ noise_pred = torch.stack(noise_preds).sum(dim=0) / zero123_ws.sum(dim=-1)[:, None, None, None]
+
+ w = (1 - self.alphas[t])
+ grad = (grad_scale * w)[:, None, None, None] * (noise_pred - noise)
+ grad = torch.nan_to_num(grad)
+
+ # import kiui
+ # if not as_latent:
+ # kiui.vis.plot_image(pred_rgb_256)
+ # kiui.vis.plot_matrix(latents)
+ # kiui.vis.plot_matrix(grad)
+
+ # import kiui
+ # latents = torch.randn((1, 4, 32, 32), device=self.device)
+ # kiui.lo(latents)
+ # self.scheduler.set_timesteps(30)
+ # with torch.no_grad():
+ # for i, t in enumerate(self.scheduler.timesteps):
+ # x_in = torch.cat([latents] * 2)
+ # t_in = torch.cat([t.view(1)] * 2).to(self.device)
+
+ # noise_pred = self.model.apply_model(x_in, t_in, cond)
+ # noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ # noise_pred = noise_pred_uncond + 3 * (noise_pred_cond - noise_pred_uncond)
+
+ # latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
+ # imgs = self.decode_latents(latents)
+ # print(polar, azimuth, radius)
+ # kiui.vis.plot_image(pred_rgb_256, imgs)
+
+ if save_guidance_path:
+ with torch.no_grad():
+ if as_latent:
+ pred_rgb_256 = self.decode_latents(latents) # claforte: test!
+
+ # visualize predicted denoised image
+ result_hopefully_less_noisy_image = self.decode_latents(self.model.predict_start_from_noise(latents_noisy, t, noise_pred))
+
+ # visualize noisier image
+ result_noisier_image = self.decode_latents(latents_noisy)
+
+ # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
+ viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1)
+ save_image(viz_images, save_guidance_path)
+
+ # since we omitted an item in grad, we need to use the custom function to specify the gradient
+ # loss = SpecifyGradient.apply(latents, grad)
+ latents.backward(gradient=grad, retain_graph=True)
+ loss = grad.abs().mean().detach()
+ return loss
+
+ # verification
+ @torch.no_grad()
+ def __call__(self,
+ image, # image tensor [1, 3, H, W] in [0, 1]
+ polar=0, azimuth=0, radius=0, # new view params
+ scale=3, ddim_steps=50, ddim_eta=1, h=256, w=256, # diffusion params
+ c_crossattn=None, c_concat=None, post_process=True,
+ ):
+
+ if c_crossattn is None:
+ embeddings = self.get_img_embeds(image)
+
+ T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius])
+ T = T[None, None, :].to(self.device)
+
+ cond = {}
+ clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1))
+ cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)]
+ cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)]
+
+ # produce latents loop
+ latents = torch.randn((1, 4, h // 8, w // 8), device=self.device)
+ self.scheduler.set_timesteps(ddim_steps)
+
+ for i, t in enumerate(self.scheduler.timesteps):
+ x_in = torch.cat([latents] * 2)
+ t_in = torch.cat([t.view(1)] * 2).to(self.device)
+
+ noise_pred = self.model.apply_model(x_in, t_in, cond)
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)
+
+ latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)['prev_sample']
+
+ imgs = self.decode_latents(latents)
+ imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs
+
+ return imgs
+
+ def decode_latents(self, latents):
+ # zs: [B, 4, 32, 32] Latent space image
+ # with self.model.ema_scope():
+ imgs = self.model.decode_first_stage(latents)
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
+
+ return imgs # [B, 3, 256, 256] RGB space image
+
+ def encode_imgs(self, imgs):
+ # imgs: [B, 3, 256, 256] RGB space image
+ # with self.model.ema_scope():
+ imgs = imgs * 2 - 1
+ latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0)
+ return latents # [B, 4, 32, 32] Latent space image
+
+
+if __name__ == '__main__':
+ import cv2
+ import argparse
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('input', type=str)
+ parser.add_argument('--fp16', action='store_true', help="use float16 for training") # no use now, can only run in fp32
+
+ parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
+ parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
+ parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')
+
+ opt = parser.parse_args()
+
+ device = torch.device('cuda')
+
+ print(f'[INFO] loading image from {opt.input} ...')
+ image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
+ image = image.astype(np.float32) / 255.0
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
+
+ print(f'[INFO] loading model ...')
+ zero123 = Zero123(device, opt.fp16, opt=opt)
+
+ print(f'[INFO] running model ...')
+ outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius)
+ plt.imshow(outputs[0])
+ plt.show()
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000..d3caa1f
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,24 @@
+
+# for KAUST cluster
+module load cuda/11.7.0
+module load gcc/7.5.0
+module load eigen
+
+# for aws ubuntu. install eigen
+#sudo apt update && sudo apt upgrade
+#sudo apt install libeigen3-dev
+
+ # a100: 8.0; v100: 7.0; 2080ti: 7.5; titan xp: 6.1
+export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0"
+
+# use python venv
+python -m venv venv_magic123
+source venv_magic123/bin/activate
+
+# use conda
+# conda create -n magic123 python=3.10 ipython -y
+# conda activate magic123
+
+pip3 install torch torchvision
+pip3 install -r requirements.txt
+bash scripts/install_ext.sh
\ No newline at end of file
diff --git a/ldm/extras.py b/ldm/extras.py
new file mode 100755
index 0000000..62e654b
--- /dev/null
+++ b/ldm/extras.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+from omegaconf import OmegaConf
+import torch
+from ldm.util import instantiate_from_config
+import logging
+from contextlib import contextmanager
+
+from contextlib import contextmanager
+import logging
+
+@contextmanager
+def all_logging_disabled(highest_level=logging.CRITICAL):
+ """
+ A context manager that will prevent any logging messages
+ triggered during the body from being processed.
+
+ :param highest_level: the maximum logging level in use.
+ This would only need to be changed if a custom level greater than CRITICAL
+ is defined.
+
+ https://gist.github.com/simon-weber/7853144
+ """
+ # two kind-of hacks here:
+ # * can't get the highest logging level in effect => delegate to the user
+ # * can't get the current module-level override => use an undocumented
+ # (but non-private!) interface
+
+ previous_level = logging.root.manager.disable
+
+ logging.disable(highest_level)
+
+ try:
+ yield
+ finally:
+ logging.disable(previous_level)
+
+def load_training_dir(train_dir, device, epoch="last"):
+ """Load a checkpoint and config from training directory"""
+ train_dir = Path(train_dir)
+ ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
+ assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
+ config = list(train_dir.rglob(f"*-project.yaml"))
+ assert len(ckpt) > 0, f"didn't find any config in {train_dir}"
+ if len(config) > 1:
+ print(f"found {len(config)} matching config files")
+ config = sorted(config)[-1]
+ print(f"selecting {config}")
+ else:
+ config = config[0]
+
+
+ config = OmegaConf.load(config)
+ return load_model_from_config(config, ckpt[0], device)
+
+def load_model_from_config(config, ckpt, device="cpu", verbose=False):
+ """Loads a model from config and a ckpt
+ if config is a path will use omegaconf to load
+ """
+ if isinstance(config, (str, Path)):
+ config = OmegaConf.load(config)
+
+ with all_logging_disabled():
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ global_step = pl_sd["global_step"]
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ model.to(device)
+ model.eval()
+ model.cond_stage_model.device = device
+ return model
\ No newline at end of file
diff --git a/ldm/guidance.py b/ldm/guidance.py
new file mode 100755
index 0000000..53d1a2a
--- /dev/null
+++ b/ldm/guidance.py
@@ -0,0 +1,96 @@
+from typing import List, Tuple
+from scipy import interpolate
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+from IPython.display import clear_output
+import abc
+
+
+class GuideModel(torch.nn.Module, abc.ABC):
+ def __init__(self) -> None:
+ super().__init__()
+
+ @abc.abstractmethod
+ def preprocess(self, x_img):
+ pass
+
+ @abc.abstractmethod
+ def compute_loss(self, inp):
+ pass
+
+
+class Guider(torch.nn.Module):
+ def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
+ """Apply classifier guidance
+
+ Specify a guidance scale as either a scalar
+ Or a schedule as a list of tuples t = 0->1 and scale, e.g.
+ [(0, 10), (0.5, 20), (1, 50)]
+ """
+ super().__init__()
+ self.sampler = sampler
+ self.index = 0
+ self.show = verbose
+ self.guide_model = guide_model
+ self.history = []
+
+ if isinstance(scale, (Tuple, List)):
+ times = np.array([x[0] for x in scale])
+ values = np.array([x[1] for x in scale])
+ self.scale_schedule = {"times": times, "values": values}
+ else:
+ self.scale_schedule = float(scale)
+
+ self.ddim_timesteps = sampler.ddim_timesteps
+ self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
+
+
+ def get_scales(self):
+ if isinstance(self.scale_schedule, float):
+ return len(self.ddim_timesteps)*[self.scale_schedule]
+
+ interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
+ fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
+ return interpolater(fractional_steps)
+
+ def modify_score(self, model, e_t, x, t, c):
+
+ # TODO look up index by t
+ scale = self.get_scales()[self.index]
+
+ if (scale == 0):
+ return e_t
+
+ sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
+ x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
+
+ inp = self.guide_model.preprocess(x_img)
+ loss = self.guide_model.compute_loss(inp)
+ grads = torch.autograd.grad(loss.sum(), x_in)[0]
+ correction = grads * scale
+
+ if self.show:
+ clear_output(wait=True)
+ print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
+ self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
+ plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
+ plt.axis('off')
+ plt.show()
+ plt.imshow(correction[0][0].detach().cpu())
+ plt.axis('off')
+ plt.show()
+
+
+ e_t_mod = e_t - sqrt_1ma*correction
+ if self.show:
+ fig, axs = plt.subplots(1, 3)
+ axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
+ axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
+ axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
+ plt.show()
+ self.index += 1
+ return e_t_mod
\ No newline at end of file
diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py
new file mode 100755
index 0000000..be39da9
--- /dev/null
+++ b/ldm/lr_scheduler.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n,**kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
+
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100755
index 0000000..6a9c4f4
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,443 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ batch_resize_range=None,
+ scheduler_config=None,
+ lr_g_factor=1.0,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ use_ema=False
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.n_embed = n_embed
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap,
+ sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ self.batch_resize_range = batch_resize_range
+ if self.batch_resize_range is not None:
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.scheduler_config = scheduler_config
+ self.lr_g_factor = lr_g_factor
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ print(f"Unexpected Keys: {unexpected}")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input, return_pred_indices=False):
+ quant, diff, (_,_,ind) = self.encode(input)
+ dec = self.decode(quant)
+ if return_pred_indices:
+ return dec, diff, ind
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ if self.batch_resize_range is not None:
+ lower_size = self.batch_resize_range[0]
+ upper_size = self.batch_resize_range[1]
+ if self.global_step <= 4:
+ # do the first few batches with max size to avoid later oom
+ new_resize = upper_size
+ else:
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+ if new_resize != x.shape[2]:
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
+ x = x.detach()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ # https://github.com/pytorch/pytorch/issues/37142
+ # try not to fool the heuristics
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train",
+ predicted_indices=ind)
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, suffix=""):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log(f"val{suffix}/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log(f"val{suffix}/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
+ del log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr_d = self.learning_rate
+ lr_g = self.lr_g_factor*self.learning_rate
+ print("lr_d", lr_d)
+ print("lr_g", lr_g)
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr_g, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr_d, betas=(0.5, 0.9))
+
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ {
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ ]
+ return [opt_ae, opt_disc], scheduler
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if only_inputs:
+ log["inputs"] = x
+ return log
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ if plot_ema:
+ with self.ema_scope():
+ xrec_ema, _ = self(x)
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQModelInterface(VQModel):
+ def __init__(self, embed_dim, *args, **kwargs):
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ self.embed_dim = embed_dim
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, h, force_not_quantize=False):
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py
new file mode 100755
index 0000000..67e98b9
--- /dev/null
+++ b/ldm/models/diffusion/classifier.py
@@ -0,0 +1,267 @@
+import os
+import torch
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LambdaLR
+from copy import deepcopy
+from einops import rearrange
+from glob import glob
+from natsort import natsorted
+
+from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
+from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+
+__models__ = {
+ 'class_label': EncoderUNetModel,
+ 'segmentation': UNetModel
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class NoisyLatentImageClassifier(pl.LightningModule):
+
+ def __init__(self,
+ diffusion_path,
+ num_classes,
+ ckpt_path=None,
+ pool='attention',
+ label_key=None,
+ diffusion_ckpt_path=None,
+ scheduler_config=None,
+ weight_decay=1.e-2,
+ log_steps=10,
+ monitor='val/loss',
+ *args,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.num_classes = num_classes
+ # get latest config of diffusion model
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
+ self.load_diffusion()
+
+ self.monitor = monitor
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
+ self.log_steps = log_steps
+
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
+ else self.diffusion_model.cond_stage_key
+
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
+
+ if self.label_key not in __models__:
+ raise NotImplementedError()
+
+ self.load_classifier(ckpt_path, pool)
+
+ self.scheduler_config = scheduler_config
+ self.use_scheduler = self.scheduler_config is not None
+ self.weight_decay = weight_decay
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def load_diffusion(self):
+ model = instantiate_from_config(self.diffusion_config)
+ self.diffusion_model = model.eval()
+ self.diffusion_model.train = disabled_train
+ for param in self.diffusion_model.parameters():
+ param.requires_grad = False
+
+ def load_classifier(self, ckpt_path, pool):
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
+ model_config.out_channels = self.num_classes
+ if self.label_key == 'class_label':
+ model_config.pool = pool
+
+ self.model = __models__[self.label_key](**model_config)
+ if ckpt_path is not None:
+ print('#####################################################################')
+ print(f'load from ckpt "{ckpt_path}"')
+ print('#####################################################################')
+ self.init_from_ckpt(ckpt_path)
+
+ @torch.no_grad()
+ def get_x_noisy(self, x, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x))
+ continuous_sqrt_alpha_cumprod = None
+ if self.diffusion_model.use_continuous_noise:
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
+ # todo: make sure t+1 is correct here
+
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
+
+ def forward(self, x_noisy, t, *args, **kwargs):
+ return self.model(x_noisy, t)
+
+ @torch.no_grad()
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ @torch.no_grad()
+ def get_conditioning(self, batch, k=None):
+ if k is None:
+ k = self.label_key
+ assert k is not None, 'Needs to provide label key'
+
+ targets = batch[k].to(self.device)
+
+ if self.label_key == 'segmentation':
+ targets = rearrange(targets, 'b h w c -> b c h w')
+ for down in range(self.numd):
+ h, w = targets.shape[-2:]
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
+
+ # targets = rearrange(targets,'b c h w -> b h w c')
+
+ return targets
+
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
+ _, top_ks = torch.topk(logits, k, dim=1)
+ if reduction == "mean":
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
+ elif reduction == "none":
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
+
+ def on_train_epoch_start(self):
+ # save some memory
+ self.diffusion_model.model.to('cpu')
+
+ @torch.no_grad()
+ def write_logs(self, loss, logits, targets):
+ log_prefix = 'train' if self.training else 'val'
+ log = {}
+ log[f"{log_prefix}/loss"] = loss.mean()
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
+ logits, targets, k=1, reduction="mean"
+ )
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
+ logits, targets, k=5, reduction="mean"
+ )
+
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
+
+ def shared_step(self, batch, t=None):
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
+ targets = self.get_conditioning(batch)
+ if targets.dim() == 4:
+ targets = targets.argmax(dim=1)
+ if t is None:
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
+ else:
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
+ x_noisy = self.get_x_noisy(x, t)
+ logits = self(x_noisy, t)
+
+ loss = F.cross_entropy(logits, targets, reduction='none')
+
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
+
+ loss = loss.mean()
+ return loss, logits, x_noisy, targets
+
+ def training_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+ return loss
+
+ def reset_noise_accs(self):
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
+
+ def on_validation_start(self):
+ self.reset_noise_accs()
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+
+ for t in self.noisy_acc:
+ _, logits, _, targets = self.shared_step(batch, t)
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
+
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
+
+ if self.use_scheduler:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [optimizer], scheduler
+
+ return optimizer
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, *args, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
+ log['inputs'] = x
+
+ y = self.get_conditioning(batch)
+
+ if self.label_key == 'class_label':
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['labels'] = y
+
+ if ismap(y):
+ log['labels'] = self.diffusion_model.to_rgb(y)
+
+ for step in range(self.log_steps):
+ current_time = step * self.log_time_interval
+
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
+
+ log[f'inputs@t{current_time}'] = x_noisy
+
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
+ pred = rearrange(pred, 'b h w c -> b c h w')
+
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
+
+ for key in log:
+ log[key] = log[key][:N]
+
+ return log
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100755
index 0000000..0683d16
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,328 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+from einops import rearrange
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def to(self, device):
+ """Same as to in torch module
+ Don't really underestand why this isn't a module in the first place"""
+ for k, v in self.__dict__.items():
+ if isinstance(v, torch.Tensor):
+ new_v = getattr(self, k).to(device)
+ setattr(self, k, new_v)
+
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ t_start=-1):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ timesteps = timesteps[:t_start]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback:
+ img = callback(i, img, pred_x0)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+
+ print(t, sqrt_one_minus_at, a_t)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None):
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ return x_dec
\ No newline at end of file
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100755
index 0000000..3fcb7ad
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1994 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.modules.attention import CrossAttention
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ make_it_fit=False,
+ ucg_training=None,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+
+ if self.make_it_fit:
+ n_params = len([name for name, _ in
+ itertools.chain(self.named_parameters(),
+ self.named_buffers())])
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(),
+ self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params
+ ):
+ if not name in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape)==len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ else:
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ for k in self.ucg_training:
+ p = self.ucg_training[k]["p"]
+ val = self.ucg_training[k]["val"]
+ if val is None:
+ val = ""
+ for i in range(len(batch[k])):
+ if self.ucg_prng.choice(2, p=[1-p, p]):
+ batch[k][i] = val
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ unet_trainable=True,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.unet_trainable = unet_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+
+ # construct linear projection layer for concatenating image CLIP embedding and RT
+ self.cc_projection = nn.Linear(772, 768)
+ nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
+ nn.init.zeros_(list(self.cc_projection.parameters())[1])
+ self.cc_projection.requires_grad_(True)
+
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ wtith min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, uncond=0.05):
+ x = super().get_input(batch, k)
+ T = batch['T'].to(memory_format=torch.contiguous_format).float()
+
+ if bs is not None:
+ x = x[:bs]
+ T = T[:bs].to(self.device)
+
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ cond_key = cond_key or self.cond_stage_key
+ xc = super().get_input(batch, cond_key).to(self.device)
+ if bs is not None:
+ xc = xc[:bs]
+ cond = {}
+
+ # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.
+ random = torch.rand(x.size(0), device=x.device)
+ prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1")
+ input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1")
+ null_prompt = self.get_learned_conditioning([""])
+
+ # z.shape: [8, 4, 64, 64]; c.shape: [8, 1, 768]
+ # print('=========== xc shape ===========', xc.shape)
+ with torch.enable_grad():
+ clip_emb = self.get_learned_conditioning(xc).detach()
+ null_prompt = self.get_learned_conditioning([""]).detach()
+ cond["c_crossattn"] = [self.cc_projection(torch.cat([torch.where(prompt_mask, null_prompt, clip_emb), T[:, None, :]], dim=-1))]
+ cond["c_concat"] = [input_mask * self.encode_first_stage((xc.to(self.device))).mode().detach()]
+ out = [z, cond]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ # @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ # @torch.no_grad() # wasted two hours to find this bug... why no grad here!
+ def encode_first_stage(self, x):
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ df = self.split_input_params["vqf"]
+ self.split_input_params['original_image_size'] = x.shape[-2:]
+ bs, nc, h, w = x.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ z = unfold(x) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization
+ return decoded
+
+ else:
+ return self.first_stage_model.encode(x)
+ else:
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ # if self.cond_stage_trainable:
+ # c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if isinstance(cond, dict):
+ # hybrid case, cond is exptected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
+ # need to rescale the tl patch coordinates to be in between (0,1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+ assert not isinstance(output_list[0],
+ tuple) # todo cant deal with multiple model outputs check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None,**kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True, **kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None, image_size=512):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ # todo: get null label from cond_stage_model
+ raise NotImplementedError()
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+ cond = {}
+ cond["c_crossattn"] = [c]
+ cond["c_concat"] = [torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to(self.device)]
+ return cond
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label, image_size=x.shape[-1])
+ # uc = torch.zeros_like(c)
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = []
+ if self.unet_trainable == "attn":
+ print("Training only unet attention layers")
+ for n, m in self.model.named_modules():
+ if isinstance(m, CrossAttention) and n.endswith('attn2'):
+ params.extend(m.parameters())
+ if self.unet_trainable == "conv_in":
+ print("Training only unet input conv layers")
+ params = list(self.model.diffusion_model.input_blocks[0][0].parameters())
+ elif self.unet_trainable is True or self.unet_trainable == "all":
+ print("Training the full unet")
+ params = list(self.model.parameters())
+ else:
+ raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}")
+
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+
+ if self.cc_projection is not None:
+ params = params + list(self.cc_projection.parameters())
+ print('========== optimizing for cc projection weight ==========')
+
+ opt = torch.optim.AdamW([{"params": self.model.parameters(), "lr": lr},
+ {"params": self.cc_projection.parameters(), "lr": 10. * lr}], lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ # c_crossattn dimension: torch.Size([8, 1, 768]) 1
+ # cc dimension: torch.Size([8, 1, 768]
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ #import pudb; pu.db
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ interpretability = False
+ if interpretability:
+ zx = zx[:, :, ::2, ::2]
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ return log
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+ """
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+ e.g. mask as concat and text via cross-attn.
+ To disable finetuning mode, set finetune_keys to None
+ """
+ def __init__(self,
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight"
+ ),
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args, **kwargs
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", list())
+ super().__init__(*args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+ self.finetune_keys = finetune_keys
+ self.concat_keys = concat_keys
+ self.keep_dims = keep_finetune_dims
+ self.c_concat_log_start = c_concat_log_start
+ self.c_concat_log_end = c_concat_log_end
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+ if exists(ckpt_path):
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ # make it explicit, finetune by including extra input channels
+ if exists(self.finetune_keys) and k in self.finetune_keys:
+ new_entry = None
+ for name, param in self.named_parameters():
+ if name in self.finetune_keys:
+ print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+ new_entry = torch.zeros_like(param) # zero init
+ assert exists(new_entry), 'did not find matching parameter to modify'
+ new_entry[:, :self.keep_dims, ...] = sd[k]
+ sd[k] = new_entry
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ bchw = z.shape
+ if ck != self.masked_image_key:
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:,self.c_concat_log_start:self.c_concat_log_end])
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc_cat = c_cat
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ log["masked_image"] = rearrange(batch["masked_image"],
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ return log
+
+
+class Layout2ImgDiffusion(LatentDiffusion):
+ # TODO: move all layout-specific hacks to this class
+ def __init__(self, cond_stage_key, *args, **kwargs):
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+ def log_images(self, batch, N=8, *args, **kwargs):
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+ key = 'train' if self.training else 'validation'
+ dset = self.trainer.datamodule.datasets[key]
+ mapper = dset.conditional_builders[self.cond_stage_key]
+
+ bbox_imgs = []
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+ bbox_imgs.append(bboximg)
+
+ cond_img = torch.stack(bbox_imgs, dim=0)
+ logs['bbox_image'] = cond_img
+ return logs
+
+
+class SimpleUpscaleDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_key="LR", **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.low_scale_key = low_scale_key
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+
+ encoder_posterior = self.encode_first_stage(x_low)
+ zx = self.get_first_stage_encoding(encoder_posterior).detach()
+ all_conds = {"c_concat": [zx], "c_crossattn": [c]}
+
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ interpretability = False
+ if interpretability:
+ zx = zx[:, :, ::2, ::2]
+ return z, all_conds, x, xrec, xc, x_low
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ return log
+
+class MultiCatFrameDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_key="LR", **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.low_scale_key = low_scale_key
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ n = 2
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ cat_conds = batch[self.low_scale_key][:bs]
+ cats = []
+ for i in range(n):
+ x_low = cat_conds[:,:,:,3*i:3*(i+1)]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ encoder_posterior = self.encode_first_stage(x_low)
+ zx = self.get_first_stage_encoding(encoder_posterior).detach()
+ cats.append(zx)
+
+ all_conds = {"c_concat": [torch.cat(cats, dim=1)], "c_crossattn": [c]}
+
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ interpretability = False
+ if interpretability:
+ zx = zx[:, :, ::2, ::2]
+ return z, all_conds, x, xrec, xc, x_low
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ return log
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100755
index 0000000..080edee
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,259 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100755
index 0000000..a0ae00f
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,50 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def renorm_thresholding(x0, value):
+ # renorm
+ pred_max = x0.max()
+ pred_min = x0.min()
+ pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
+ pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
+
+ s = torch.quantile(
+ rearrange(pred_x0, 'b ... -> b (...)').abs(),
+ value,
+ dim=-1
+ )
+ s.clamp_(min=1.0)
+ s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
+
+ # clip by threshold
+ # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
+
+ # temporary hack: numpy on cpu
+ pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
+ pred_x0 = torch.tensor(pred_x0).to(self.model.device)
+
+ # re.renorm
+ pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
+ pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
+ return pred_x0
+
+
+def norm_thresholding(x0, value):
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+ return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+ # b c h w
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+ return x0 * (value / s)
\ No newline at end of file
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100755
index 0000000..124effb
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,266 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
+ disable_self_attn=disable_self_attn)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100755
index 0000000..533e589
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,835 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from ldm.util import instantiate_from_config
+from ldm.modules.attention import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100755
index 0000000..09f0ae1
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,996 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ #self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.") # todo: convert to warning
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100755
index 0000000..a952e6c
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,267 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100755
index 0000000..f2b8ef9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100755
index 0000000..c8c75af
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,76 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
+ else torch.tensor(-1,dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ #remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.','')
+ self.m_name2s_name.update({name:s_name})
+ self.register_buffer(s_name,p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self,model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100755
index 0000000..b1afccf
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,550 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+import kornia
+
+from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+from ldm.util import default
+import clip
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+class FaceClipEncoder(AbstractEncoder):
+ def __init__(self, augment=True, retreival_key=None):
+ super().__init__()
+ self.encoder = FrozenCLIPImageEmbedder()
+ self.augment = augment
+ self.retreival_key = retreival_key
+
+ def forward(self, img):
+ encodings = []
+ with torch.no_grad():
+ x_offset = 125
+ if self.retreival_key:
+ # Assumes retrieved image are packed into the second half of channels
+ face = img[:,3:,190:440,x_offset:(512-x_offset)]
+ other = img[:,:3,...].clone()
+ else:
+ face = img[:,:,190:440,x_offset:(512-x_offset)]
+ other = img.clone()
+
+ if self.augment:
+ face = K.RandomHorizontalFlip()(face)
+
+ other[:,:,190:440,x_offset:(512-x_offset)] *= 0
+ encodings = [
+ self.encoder.encode(face),
+ self.encoder.encode(other),
+ ]
+
+ return torch.cat(encodings, dim=1)
+
+ def encode(self, img):
+ if isinstance(img, list):
+ # Uncondition
+ return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
+
+ return self(img)
+
+class FaceIdClipEncoder(AbstractEncoder):
+ def __init__(self):
+ super().__init__()
+ self.encoder = FrozenCLIPImageEmbedder()
+ for p in self.encoder.parameters():
+ p.requires_grad = False
+ self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)
+
+ def forward(self, img):
+ encodings = []
+ with torch.no_grad():
+ face = kornia.geometry.resize(img, (256, 256),
+ interpolation='bilinear', align_corners=True)
+
+ other = img.clone()
+ other[:,:,184:452,122:396] *= 0
+ encodings = [
+ self.id.encode(face),
+ self.encoder.encode(other),
+ ]
+
+ return torch.cat(encodings, dim=1)
+
+ def encode(self, img):
+ if isinstance(img, list):
+ # Uncondition
+ return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
+
+ return self(img)
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class TransformerEmbedder(AbstractEncoder):
+ """Some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+ super().__init__()
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
+
+ def forward(self, tokens):
+ tokens = tokens.to(self.device) # meh
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, x):
+ return self(x)
+
+
+class BERTTokenizer(AbstractEncoder):
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
+ super().__init__()
+ from transformers import BertTokenizerFast # TODO: add to reuquirements
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ self.device = device
+ self.vq_interface = vq_interface
+ self.max_length = max_length
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ return tokens
+
+ @torch.no_grad()
+ def encode(self, text):
+ tokens = self(text)
+ if not self.vq_interface:
+ return tokens
+ return None, None, [None, None, tokens]
+
+ def decode(self, text):
+ return text
+
+
+class BERTEmbedder(AbstractEncoder):
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ super().__init__()
+ self.use_tknz_fn = use_tokenizer
+ if self.use_tknz_fn:
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout)
+
+ def forward(self, text):
+ if self.use_tknz_fn:
+ tokens = self.tknz_fn(text)#.to(self.device)
+ else:
+ tokens = text
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, text):
+ # output of length 77
+ return self(text)
+
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+from ldm.thirdp.psp.id_loss import IDFeatures
+import kornia.augmentation as K
+
+class FrozenFaceEncoder(AbstractEncoder):
+ def __init__(self, model_path, augment=False):
+ super().__init__()
+ self.loss_fn = IDFeatures(model_path)
+ # face encoder is frozen
+ for p in self.loss_fn.parameters():
+ p.requires_grad = False
+ # Mapper is trainable
+ self.mapper = torch.nn.Linear(512, 768)
+ p = 0.25
+ if augment:
+ self.augment = K.AugmentationSequential(
+ K.RandomHorizontalFlip(p=0.5),
+ K.RandomEqualize(p=p),
+ # K.RandomPlanckianJitter(p=p),
+ # K.RandomPlasmaBrightness(p=p),
+ # K.RandomPlasmaContrast(p=p),
+ # K.ColorJiggle(0.02, 0.2, 0.2, p=p),
+ )
+ else:
+ self.augment = False
+
+ def forward(self, img):
+ if isinstance(img, list):
+ # Uncondition
+ return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
+
+ if self.augment is not None:
+ # Transforms require 0-1
+ img = self.augment((img + 1)/2)
+ img = 2*img - 1
+
+ feat = self.loss_fn(img, crop=True)
+ feat = self.mapper(feat.unsqueeze(1))
+ return feat
+
+ def encode(self, img):
+ return self(img)
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+import torch.nn.functional as F
+from transformers import CLIPVisionModel
+class ClipImageProjector(AbstractEncoder):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
+ super().__init__()
+ self.model = CLIPVisionModel.from_pretrained(version)
+ self.model.train()
+ self.max_length = max_length # TODO: typical value?
+ self.antialias = True
+ self.mapper = torch.nn.Linear(1024, 768)
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+ null_cond = self.get_null_cond(version, max_length)
+ self.register_buffer('null_cond', null_cond)
+
+ @torch.no_grad()
+ def get_null_cond(self, version, max_length):
+ device = self.mean.device
+ embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
+ null_cond = embedder([""])
+ return null_cond
+
+ def preprocess(self, x):
+ # Expects inputs in the range -1, 1
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ if isinstance(x, list):
+ return self.null_cond
+ # x is assumed to be in range [-1,1]
+ x = self.preprocess(x)
+ outputs = self.model(pixel_values=x)
+ last_hidden_state = outputs.last_hidden_state
+ last_hidden_state = self.mapper(last_hidden_state)
+ return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])
+
+ def encode(self, im):
+ return self(im)
+
+class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
+ super().__init__()
+ self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
+ self.projection = torch.nn.Linear(768, 768)
+
+ def forward(self, text):
+ z = self.embedder(text)
+ return self.projection(z)
+
+ def encode(self, text):
+ return self(text)
+
+class FrozenCLIPImageEmbedder(AbstractEncoder):
+ """
+ Uses the CLIP image encoder.
+ Not actually frozen... If you want that set cond_stage_trainable=False in cfg
+ """
+ def __init__(
+ self,
+ model='ViT-L/14',
+ jit=False,
+ device='cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+ # We don't use the text part so delete it
+ del self.model.transformer
+ self.antialias = antialias
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # Expects inputs in the range -1, 1
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ if isinstance(x, list):
+ # [""] denotes condition dropout for ucg
+ device = self.model.visual.conv1.weight.device
+ return torch.zeros(1, 768, device=device)
+ return self.model.encode_image(self.preprocess(x)).float()
+
+ def encode(self, im):
+ return self(im).unsqueeze(1)
+
+from torchvision import transforms
+import random
+
+class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
+ """
+ Uses the CLIP image encoder.
+ Not actually frozen... If you want that set cond_stage_trainable=False in cfg
+ """
+ def __init__(
+ self,
+ model='ViT-L/14',
+ jit=False,
+ device='cpu',
+ antialias=True,
+ max_crops=5,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+ # We don't use the text part so delete it
+ del self.model.transformer
+ self.antialias = antialias
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+ self.max_crops = max_crops
+
+ def preprocess(self, x):
+
+ # Expects inputs in the range -1, 1
+ randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
+ max_crops = self.max_crops
+ patches = []
+ crops = [randcrop(x) for _ in range(max_crops)]
+ patches.extend(crops)
+ x = torch.cat(patches, dim=0)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ if isinstance(x, list):
+ # [""] denotes condition dropout for ucg
+ device = self.model.visual.conv1.weight.device
+ return torch.zeros(1, self.max_crops, 768, device=device)
+ batch_tokens = []
+ for im in x:
+ patches = self.preprocess(im.unsqueeze(0))
+ tokens = self.model.encode_image(patches).float()
+ for t in tokens:
+ if random.random() < 0.1:
+ t *= 0
+ batch_tokens.append(tokens.unsqueeze(0))
+
+ return torch.cat(batch_tokens, dim=0)
+
+ def encode(self, im):
+ return self(im)
+
+class SpatialRescaler(nn.Module):
+ def __init__(self,
+ n_stages=1,
+ method='bilinear',
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None
+ if self.remap_output:
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+
+ def forward(self,x):
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+
+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+
+
+class LowScaleEncoder(nn.Module):
+ def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
+ scale_factor=1.0):
+ super().__init__()
+ self.max_noise_level = max_noise_level
+ self.model = instantiate_from_config(model_config)
+ self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
+ linear_end=linear_end)
+ self.out_size = output_size
+ self.scale_factor = scale_factor
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def forward(self, x):
+ z = self.model.encode(x).sample()
+ z = z * self.scale_factor
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ z = self.q_sample(z, noise_level)
+ if self.out_size is not None:
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
+ # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+ return z, noise_level
+
+ def decode(self, z):
+ z = z / self.scale_factor
+ return self.model.decode(z)
+
+
+if __name__ == "__main__":
+ from ldm.util import count_params
+ sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
+ model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
+ count_params(model, True)
+ z = model(sentences)
+ print(z.shape)
+
+ model = FrozenCLIPEmbedder().cuda()
+ count_params(model, True)
+ z = model(sentences)
+ print(z.shape)
+
+ print("done.")
diff --git a/ldm/modules/evaluate/adm_evaluator.py b/ldm/modules/evaluate/adm_evaluator.py
new file mode 100755
index 0000000..508cddf
--- /dev/null
+++ b/ldm/modules/evaluate/adm_evaluator.py
@@ -0,0 +1,676 @@
+import argparse
+import io
+import os
+import random
+import warnings
+import zipfile
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from functools import partial
+from multiprocessing import cpu_count
+from multiprocessing.pool import ThreadPool
+from typing import Iterable, Optional, Tuple
+import yaml
+
+import numpy as np
+import requests
+import tensorflow.compat.v1 as tf
+from scipy import linalg
+from tqdm.auto import tqdm
+
+INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
+INCEPTION_V3_PATH = "classify_image_graph_def.pb"
+
+FID_POOL_NAME = "pool_3:0"
+FID_SPATIAL_NAME = "mixed_6/conv:0"
+
+REQUIREMENTS = f"This script has the following requirements: \n" \
+ 'tensorflow-gpu>=2.0' + "\n" + 'scipy' + "\n" + "requests" + "\n" + "tqdm"
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ref_batch", help="path to reference batch npz file")
+ parser.add_argument("--sample_batch", help="path to sample batch npz file")
+ args = parser.parse_args()
+
+ config = tf.ConfigProto(
+ allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
+ )
+ config.gpu_options.allow_growth = True
+ evaluator = Evaluator(tf.Session(config=config))
+
+ print("warming up TensorFlow...")
+ # This will cause TF to print a bunch of verbose stuff now rather
+ # than after the next print(), to help prevent confusion.
+ evaluator.warmup()
+
+ print("computing reference batch activations...")
+ ref_acts = evaluator.read_activations(args.ref_batch)
+ print("computing/reading reference batch statistics...")
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
+
+ print("computing sample batch activations...")
+ sample_acts = evaluator.read_activations(args.sample_batch)
+ print("computing/reading sample batch statistics...")
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
+
+ print("Computing evaluations...")
+ is_ = evaluator.compute_inception_score(sample_acts[0])
+ print("Inception Score:", is_)
+ fid = sample_stats.frechet_distance(ref_stats)
+ print("FID:", fid)
+ sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
+ print("sFID:", sfid)
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
+ print("Precision:", prec)
+ print("Recall:", recall)
+
+ savepath = '/'.join(args.sample_batch.split('/')[:-1])
+ results_file = os.path.join(savepath,'evaluation_metrics.yaml')
+ print(f'Saving evaluation results to "{results_file}"')
+
+ results = {
+ 'IS': is_,
+ 'FID': fid,
+ 'sFID': sfid,
+ 'Precision:':prec,
+ 'Recall': recall
+ }
+
+ with open(results_file, 'w') as f:
+ yaml.dump(results, f, default_flow_style=False)
+
+class InvalidFIDException(Exception):
+ pass
+
+
+class FIDStatistics:
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
+ self.mu = mu
+ self.sigma = sigma
+
+ def frechet_distance(self, other, eps=1e-6):
+ """
+ Compute the Frechet distance between two sets of statistics.
+ """
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
+ mu1, sigma1 = self.mu, self.sigma
+ mu2, sigma2 = other.mu, other.sigma
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert (
+ mu1.shape == mu2.shape
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
+ assert (
+ sigma1.shape == sigma2.shape
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
+
+ diff = mu1 - mu2
+
+ # product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = (
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
+ % eps
+ )
+ warnings.warn(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError("Imaginary component {}".format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+class Evaluator:
+ def __init__(
+ self,
+ session,
+ batch_size=64,
+ softmax_batch_size=512,
+ ):
+ self.sess = session
+ self.batch_size = batch_size
+ self.softmax_batch_size = softmax_batch_size
+ self.manifold_estimator = ManifoldEstimator(session)
+ with self.sess.graph.as_default():
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
+ self.softmax = _create_softmax_graph(self.softmax_input)
+
+ def warmup(self):
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
+
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
+ with open_npz_array(npz_path, "arr_0") as reader:
+ return self.compute_activations(reader.read_batches(self.batch_size))
+
+ def compute_activations(self, batches: Iterable[np.ndarray],silent=False) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Compute image features for downstream evals.
+
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
+ dimension. The tuple is (pool_3, spatial).
+ """
+ preds = []
+ spatial_preds = []
+ it = batches if silent else tqdm(batches)
+ for batch in it:
+ batch = batch.astype(np.float32)
+ pred, spatial_pred = self.sess.run(
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
+ )
+ preds.append(pred.reshape([pred.shape[0], -1]))
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
+ return (
+ np.concatenate(preds, axis=0),
+ np.concatenate(spatial_preds, axis=0),
+ )
+
+ def read_statistics(
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
+ obj = np.load(npz_path)
+ if "mu" in list(obj.keys()):
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
+ obj["mu_s"], obj["sigma_s"]
+ )
+ return tuple(self.compute_statistics(x) for x in activations)
+
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
+ mu = np.mean(activations, axis=0)
+ sigma = np.cov(activations, rowvar=False)
+ return FIDStatistics(mu, sigma)
+
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
+ softmax_out = []
+ for i in range(0, len(activations), self.softmax_batch_size):
+ acts = activations[i : i + self.softmax_batch_size]
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
+ preds = np.concatenate(softmax_out, axis=0)
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
+ scores = []
+ for i in range(0, len(preds), split_size):
+ part = preds[i : i + split_size]
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+ kl = np.mean(np.sum(kl, 1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores))
+
+ def compute_prec_recall(
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
+ ) -> Tuple[float, float]:
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
+ pr = self.manifold_estimator.evaluate_pr(
+ activations_ref, radii_1, activations_sample, radii_2
+ )
+ return (float(pr[0][0]), float(pr[1][0]))
+
+
+class ManifoldEstimator:
+ """
+ A helper for comparing manifolds of feature vectors.
+
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
+ """
+
+ def __init__(
+ self,
+ session,
+ row_batch_size=10000,
+ col_batch_size=10000,
+ nhood_sizes=(3,),
+ clamp_to_percentile=None,
+ eps=1e-5,
+ ):
+ """
+ Estimate the manifold of given feature vectors.
+
+ :param session: the TensorFlow session.
+ :param row_batch_size: row batch size to compute pairwise distances
+ (parameter to trade-off between memory usage and performance).
+ :param col_batch_size: column batch size to compute pairwise distances.
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
+ the given percentile.
+ :param eps: small number for numerical stability.
+ """
+ self.distance_block = DistanceBlock(session)
+ self.row_batch_size = row_batch_size
+ self.col_batch_size = col_batch_size
+ self.nhood_sizes = nhood_sizes
+ self.num_nhoods = len(nhood_sizes)
+ self.clamp_to_percentile = clamp_to_percentile
+ self.eps = eps
+
+ def warmup(self):
+ feats, radii = (
+ np.zeros([1, 2048], dtype=np.float32),
+ np.zeros([1, 1], dtype=np.float32),
+ )
+ self.evaluate_pr(feats, radii, feats, radii)
+
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
+ num_images = len(features)
+
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
+
+ for begin1 in range(0, num_images, self.row_batch_size):
+ end1 = min(begin1 + self.row_batch_size, num_images)
+ row_batch = features[begin1:end1]
+
+ for begin2 in range(0, num_images, self.col_batch_size):
+ end2 = min(begin2 + self.col_batch_size, num_images)
+ col_batch = features[begin2:end2]
+
+ # Compute distances between batches.
+ distance_batch[
+ 0 : end1 - begin1, begin2:end2
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
+
+ # Find the k-nearest neighbor from the current batch.
+ radii[begin1:end1, :] = np.concatenate(
+ [
+ x[:, self.nhood_sizes]
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
+ ],
+ axis=0,
+ )
+
+ if self.clamp_to_percentile is not None:
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
+ radii[radii > max_distances] = 0
+ return radii
+
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
+ """
+ Evaluate if new feature vectors are at the manifold.
+ """
+ num_eval_images = eval_features.shape[0]
+ num_ref_images = radii.shape[0]
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
+
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
+ feature_batch = eval_features[begin1:end1]
+
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
+ ref_batch = features[begin2:end2]
+
+ distance_batch[
+ 0 : end1 - begin1, begin2:end2
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
+
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
+ # If a feature vector is inside a hypersphere of some reference sample, then
+ # the new sample lies at the estimated manifold.
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
+
+ max_realism_score[begin1:end1] = np.max(
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
+ )
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
+
+ return {
+ "fraction": float(np.mean(batch_predictions)),
+ "batch_predictions": batch_predictions,
+ "max_realisim_score": max_realism_score,
+ "nearest_indices": nearest_indices,
+ }
+
+ def evaluate_pr(
+ self,
+ features_1: np.ndarray,
+ radii_1: np.ndarray,
+ features_2: np.ndarray,
+ radii_2: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Evaluate precision and recall efficiently.
+
+ :param features_1: [N1 x D] feature vectors for reference batch.
+ :param radii_1: [N1 x K1] radii for reference vectors.
+ :param features_2: [N2 x D] feature vectors for the other batch.
+ :param radii_2: [N x K2] radii for other vectors.
+ :return: a tuple of arrays for (precision, recall):
+ - precision: an np.ndarray of length K1
+ - recall: an np.ndarray of length K2
+ """
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
+ end_1 = begin_1 + self.row_batch_size
+ batch_1 = features_1[begin_1:end_1]
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
+ end_2 = begin_2 + self.col_batch_size
+ batch_2 = features_2[begin_2:end_2]
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
+ )
+ features_1_status[begin_1:end_1] |= batch_1_in
+ features_2_status[begin_2:end_2] |= batch_2_in
+ return (
+ np.mean(features_2_status.astype(np.float64), axis=0),
+ np.mean(features_1_status.astype(np.float64), axis=0),
+ )
+
+
+class DistanceBlock:
+ """
+ Calculate pairwise distances between vectors.
+
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
+ """
+
+ def __init__(self, session):
+ self.session = session
+
+ # Initialize TF graph to calculate pairwise distances.
+ with session.graph.as_default():
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
+ distance_block_16 = _batch_pairwise_distances(
+ tf.cast(self._features_batch1, tf.float16),
+ tf.cast(self._features_batch2, tf.float16),
+ )
+ self.distance_block = tf.cond(
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
+ lambda: tf.cast(distance_block_16, tf.float32),
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
+ )
+
+ # Extra logic for less thans.
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
+
+ def pairwise_distances(self, U, V):
+ """
+ Evaluate pairwise distances between two batches of feature vectors.
+ """
+ return self.session.run(
+ self.distance_block,
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
+ )
+
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
+ return self.session.run(
+ [self._batch_1_in, self._batch_2_in],
+ feed_dict={
+ self._features_batch1: batch_1,
+ self._features_batch2: batch_2,
+ self._radii1: radii_1,
+ self._radii2: radii_2,
+ },
+ )
+
+
+def _batch_pairwise_distances(U, V):
+ """
+ Compute pairwise distances between two batches of feature vectors.
+ """
+ with tf.variable_scope("pairwise_dist_block"):
+ # Squared norms of each row in U and V.
+ norm_u = tf.reduce_sum(tf.square(U), 1)
+ norm_v = tf.reduce_sum(tf.square(V), 1)
+
+ # norm_u as a column and norm_v as a row vectors.
+ norm_u = tf.reshape(norm_u, [-1, 1])
+ norm_v = tf.reshape(norm_v, [1, -1])
+
+ # Pairwise squared Euclidean distances.
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
+
+ return D
+
+
+class NpzArrayReader(ABC):
+ @abstractmethod
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+ pass
+
+ @abstractmethod
+ def remaining(self) -> int:
+ pass
+
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
+ def gen_fn():
+ while True:
+ batch = self.read_batch(batch_size)
+ if batch is None:
+ break
+ yield batch
+
+ rem = self.remaining()
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
+ return BatchIterator(gen_fn, num_batches)
+
+
+class BatchIterator:
+ def __init__(self, gen_fn, length):
+ self.gen_fn = gen_fn
+ self.length = length
+
+ def __len__(self):
+ return self.length
+
+ def __iter__(self):
+ return self.gen_fn()
+
+
+class StreamingNpzArrayReader(NpzArrayReader):
+ def __init__(self, arr_f, shape, dtype):
+ self.arr_f = arr_f
+ self.shape = shape
+ self.dtype = dtype
+ self.idx = 0
+
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+ if self.idx >= self.shape[0]:
+ return None
+
+ bs = min(batch_size, self.shape[0] - self.idx)
+ self.idx += bs
+
+ if self.dtype.itemsize == 0:
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
+
+ read_count = bs * np.prod(self.shape[1:])
+ read_size = int(read_count * self.dtype.itemsize)
+ data = _read_bytes(self.arr_f, read_size, "array data")
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
+
+ def remaining(self) -> int:
+ return max(0, self.shape[0] - self.idx)
+
+
+class MemoryNpzArrayReader(NpzArrayReader):
+ def __init__(self, arr):
+ self.arr = arr
+ self.idx = 0
+
+ @classmethod
+ def load(cls, path: str, arr_name: str):
+ with open(path, "rb") as f:
+ arr = np.load(f)[arr_name]
+ return cls(arr)
+
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+ if self.idx >= self.arr.shape[0]:
+ return None
+
+ res = self.arr[self.idx : self.idx + batch_size]
+ self.idx += batch_size
+ return res
+
+ def remaining(self) -> int:
+ return max(0, self.arr.shape[0] - self.idx)
+
+
+@contextmanager
+def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
+ with _open_npy_file(path, arr_name) as arr_f:
+ version = np.lib.format.read_magic(arr_f)
+ if version == (1, 0):
+ header = np.lib.format.read_array_header_1_0(arr_f)
+ elif version == (2, 0):
+ header = np.lib.format.read_array_header_2_0(arr_f)
+ else:
+ yield MemoryNpzArrayReader.load(path, arr_name)
+ return
+ shape, fortran, dtype = header
+ if fortran or dtype.hasobject:
+ yield MemoryNpzArrayReader.load(path, arr_name)
+ else:
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
+
+
+def _read_bytes(fp, size, error_template="ran out of data"):
+ """
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
+
+ Read from file-like object until size bytes are read.
+ Raises ValueError if not EOF is encountered before size bytes are read.
+ Non-blocking objects only supported if they derive from io objects.
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
+ requested.
+ """
+ data = bytes()
+ while True:
+ # io files (default in python3) return None or raise on
+ # would-block, python2 file will truncate, probably nothing can be
+ # done about that. note that regular files can't be non-blocking
+ try:
+ r = fp.read(size - len(data))
+ data += r
+ if len(r) == 0 or len(data) == size:
+ break
+ except io.BlockingIOError:
+ pass
+ if len(data) != size:
+ msg = "EOF: reading %s, expected %d bytes got %d"
+ raise ValueError(msg % (error_template, size, len(data)))
+ else:
+ return data
+
+
+@contextmanager
+def _open_npy_file(path: str, arr_name: str):
+ with open(path, "rb") as f:
+ with zipfile.ZipFile(f, "r") as zip_f:
+ if f"{arr_name}.npy" not in zip_f.namelist():
+ raise ValueError(f"missing {arr_name} in npz file")
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
+ yield arr_f
+
+
+def _download_inception_model():
+ if os.path.exists(INCEPTION_V3_PATH):
+ return
+ print("downloading InceptionV3 model...")
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
+ r.raise_for_status()
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
+ with open(tmp_path, "wb") as f:
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
+ f.write(chunk)
+ os.rename(tmp_path, INCEPTION_V3_PATH)
+
+
+def _create_feature_graph(input_batch):
+ _download_inception_model()
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
+ with open(INCEPTION_V3_PATH, "rb") as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ pool3, spatial = tf.import_graph_def(
+ graph_def,
+ input_map={f"ExpandDims:0": input_batch},
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
+ name=prefix,
+ )
+ _update_shapes(pool3)
+ spatial = spatial[..., :7]
+ return pool3, spatial
+
+
+def _create_softmax_graph(input_batch):
+ _download_inception_model()
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
+ with open(INCEPTION_V3_PATH, "rb") as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ (matmul,) = tf.import_graph_def(
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
+ )
+ w = matmul.inputs[1]
+ logits = tf.matmul(input_batch, w)
+ return tf.nn.softmax(logits)
+
+
+def _update_shapes(pool3):
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
+ ops = pool3.graph.get_operations()
+ for op in ops:
+ for o in op.outputs:
+ shape = o.get_shape()
+ if shape._dims is not None: # pylint: disable=protected-access
+ # shape = [s.value for s in shape] TF 1.x
+ shape = [s for s in shape] # TF 2.x
+ new_shape = []
+ for j, s in enumerate(shape):
+ if s == 1 and j == 0:
+ new_shape.append(None)
+ else:
+ new_shape.append(s)
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
+ return pool3
+
+
+def _numpy_partition(arr, kth, **kwargs):
+ num_workers = min(cpu_count(), len(arr))
+ chunk_size = len(arr) // num_workers
+ extra = len(arr) % num_workers
+
+ start_idx = 0
+ batches = []
+ for i in range(num_workers):
+ size = chunk_size + (1 if i < extra else 0)
+ batches.append(arr[start_idx : start_idx + size])
+ start_idx += size
+
+ with ThreadPool(num_workers) as pool:
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
+
+
+if __name__ == "__main__":
+ print(REQUIREMENTS)
+ main()
diff --git a/ldm/modules/evaluate/evaluate_perceptualsim.py b/ldm/modules/evaluate/evaluate_perceptualsim.py
new file mode 100755
index 0000000..c85fef9
--- /dev/null
+++ b/ldm/modules/evaluate/evaluate_perceptualsim.py
@@ -0,0 +1,630 @@
+import argparse
+import glob
+import os
+from tqdm import tqdm
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from torchvision import models
+from PIL import Image
+
+from ldm.modules.evaluate.ssim import ssim
+
+
+transform = transforms.Compose([transforms.ToTensor()])
+
+def normalize_tensor(in_feat, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1)).view(
+ in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
+ )
+ return in_feat / (norm_factor.expand_as(in_feat) + eps)
+
+
+def cos_sim(in0, in1):
+ in0_norm = normalize_tensor(in0)
+ in1_norm = normalize_tensor(in1)
+ N = in0.size()[0]
+ X = in0.size()[2]
+ Y = in0.size()[3]
+
+ return torch.mean(
+ torch.mean(
+ torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2
+ ).view(N, 1, 1, Y),
+ dim=3,
+ ).view(N)
+
+
+class squeezenet(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(squeezenet, self).__init__()
+ pretrained_features = models.squeezenet1_1(
+ pretrained=pretrained
+ ).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.slice6 = torch.nn.Sequential()
+ self.slice7 = torch.nn.Sequential()
+ self.N_slices = 7
+ for x in range(2):
+ self.slice1.add_module(str(x), pretrained_features[x])
+ for x in range(2, 5):
+ self.slice2.add_module(str(x), pretrained_features[x])
+ for x in range(5, 8):
+ self.slice3.add_module(str(x), pretrained_features[x])
+ for x in range(8, 10):
+ self.slice4.add_module(str(x), pretrained_features[x])
+ for x in range(10, 11):
+ self.slice5.add_module(str(x), pretrained_features[x])
+ for x in range(11, 12):
+ self.slice6.add_module(str(x), pretrained_features[x])
+ for x in range(12, 13):
+ self.slice7.add_module(str(x), pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1 = h
+ h = self.slice2(h)
+ h_relu2 = h
+ h = self.slice3(h)
+ h_relu3 = h
+ h = self.slice4(h)
+ h_relu4 = h
+ h = self.slice5(h)
+ h_relu5 = h
+ h = self.slice6(h)
+ h_relu6 = h
+ h = self.slice7(h)
+ h_relu7 = h
+ vgg_outputs = namedtuple(
+ "SqueezeOutputs",
+ ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
+ )
+ out = vgg_outputs(
+ h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7
+ )
+
+ return out
+
+
+class alexnet(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(alexnet, self).__init__()
+ alexnet_pretrained_features = models.alexnet(
+ pretrained=pretrained
+ ).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(2):
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(2, 5):
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(5, 8):
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(8, 10):
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
+ for x in range(10, 12):
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1 = h
+ h = self.slice2(h)
+ h_relu2 = h
+ h = self.slice3(h)
+ h_relu3 = h
+ h = self.slice4(h)
+ h_relu4 = h
+ h = self.slice5(h)
+ h_relu5 = h
+ alexnet_outputs = namedtuple(
+ "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
+ )
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
+
+ return out
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple(
+ "VggOutputs",
+ ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
+ )
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+
+ return out
+
+
+class resnet(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
+ super(resnet, self).__init__()
+ if num == 18:
+ self.net = models.resnet18(pretrained=pretrained)
+ elif num == 34:
+ self.net = models.resnet34(pretrained=pretrained)
+ elif num == 50:
+ self.net = models.resnet50(pretrained=pretrained)
+ elif num == 101:
+ self.net = models.resnet101(pretrained=pretrained)
+ elif num == 152:
+ self.net = models.resnet152(pretrained=pretrained)
+ self.N_slices = 5
+
+ self.conv1 = self.net.conv1
+ self.bn1 = self.net.bn1
+ self.relu = self.net.relu
+ self.maxpool = self.net.maxpool
+ self.layer1 = self.net.layer1
+ self.layer2 = self.net.layer2
+ self.layer3 = self.net.layer3
+ self.layer4 = self.net.layer4
+
+ def forward(self, X):
+ h = self.conv1(X)
+ h = self.bn1(h)
+ h = self.relu(h)
+ h_relu1 = h
+ h = self.maxpool(h)
+ h = self.layer1(h)
+ h_conv2 = h
+ h = self.layer2(h)
+ h_conv3 = h
+ h = self.layer3(h)
+ h_conv4 = h
+ h = self.layer4(h)
+ h_conv5 = h
+
+ outputs = namedtuple(
+ "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]
+ )
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
+
+ return out
+
+# Off-the-shelf deep network
+class PNet(torch.nn.Module):
+ """Pre-trained network with all channels equally weighted by default"""
+
+ def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
+ super(PNet, self).__init__()
+
+ self.use_gpu = use_gpu
+
+ self.pnet_type = pnet_type
+ self.pnet_rand = pnet_rand
+
+ self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
+ self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
+
+ if self.pnet_type in ["vgg", "vgg16"]:
+ self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
+ elif self.pnet_type == "alex":
+ self.net = alexnet(
+ pretrained=not self.pnet_rand, requires_grad=False
+ )
+ elif self.pnet_type[:-2] == "resnet":
+ self.net = resnet(
+ pretrained=not self.pnet_rand,
+ requires_grad=False,
+ num=int(self.pnet_type[-2:]),
+ )
+ elif self.pnet_type == "squeeze":
+ self.net = squeezenet(
+ pretrained=not self.pnet_rand, requires_grad=False
+ )
+
+ self.L = self.net.N_slices
+
+ if use_gpu:
+ self.net.cuda()
+ self.shift = self.shift.cuda()
+ self.scale = self.scale.cuda()
+
+ def forward(self, in0, in1, retPerLayer=False):
+ in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
+ in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
+
+ outs0 = self.net.forward(in0_sc)
+ outs1 = self.net.forward(in1_sc)
+
+ if retPerLayer:
+ all_scores = []
+ for (kk, out0) in enumerate(outs0):
+ cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
+ if kk == 0:
+ val = 1.0 * cur_score
+ else:
+ val = val + cur_score
+ if retPerLayer:
+ all_scores += [cur_score]
+
+ if retPerLayer:
+ return (val, all_scores)
+ else:
+ return val
+
+
+
+
+# The SSIM metric
+def ssim_metric(img1, img2, mask=None):
+ return ssim(img1, img2, mask=mask, size_average=False)
+
+
+# The PSNR metric
+def psnr(img1, img2, mask=None,reshape=False):
+ b = img1.size(0)
+ if not (mask is None):
+ b = img1.size(0)
+ mse_err = (img1 - img2).pow(2) * mask
+ if reshape:
+ mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
+ 3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
+ )
+ else:
+ mse_err = mse_err.view(b, -1).sum(dim=1) / (
+ 3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
+ )
+ else:
+ if reshape:
+ mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
+ else:
+ mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)
+
+ psnr = 10 * (1 / mse_err).log10()
+ return psnr
+
+
+# The perceptual similarity metric
+def perceptual_sim(img1, img2, vgg16):
+ # First extract features
+ dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)
+
+ return dist
+
+def load_img(img_name, size=None):
+ try:
+ img = Image.open(img_name)
+
+ if type(size) == int:
+ img = img.resize((size, size))
+ elif size is not None:
+ img = img.resize((size[1], size[0]))
+
+ img = transform(img).cuda()
+ img = img.unsqueeze(0)
+ except Exception as e:
+ print("Failed at loading %s " % img_name)
+ print(e)
+ img = torch.zeros(1, 3, 256, 256).cuda()
+ raise
+ return img
+
+
+def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
+
+ # Load VGG16 for feature similarity
+ vgg16 = PNet().to("cuda")
+ vgg16.eval()
+ vgg16.cuda()
+
+ values_percsim = []
+ values_ssim = []
+ values_psnr = []
+ folders = os.listdir(folder)
+ for i, f in tqdm(enumerate(sorted(folders))):
+ pred_imgs = glob.glob(folder + f + "/" + pred_img)
+ tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
+ assert len(tgt_imgs) == 1
+
+ perc_sim = 10000
+ ssim_sim = -10
+ psnr_sim = -10
+ for p_img in pred_imgs:
+ t_img = load_img(tgt_imgs[0])
+ p_img = load_img(p_img, size=t_img.shape[2:])
+ t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
+ perc_sim = min(perc_sim, t_perc_sim)
+
+ ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
+ psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
+
+ values_percsim += [perc_sim]
+ values_ssim += [ssim_sim]
+ values_psnr += [psnr_sim]
+
+ if take_every_other:
+ n_valuespercsim = []
+ n_valuesssim = []
+ n_valuespsnr = []
+ for i in range(0, len(values_percsim) // 2):
+ n_valuespercsim += [
+ min(values_percsim[2 * i], values_percsim[2 * i + 1])
+ ]
+ n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
+ n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
+
+ values_percsim = n_valuespercsim
+ values_ssim = n_valuesssim
+ values_psnr = n_valuespsnr
+
+ avg_percsim = np.mean(np.array(values_percsim))
+ std_percsim = np.std(np.array(values_percsim))
+
+ avg_psnr = np.mean(np.array(values_psnr))
+ std_psnr = np.std(np.array(values_psnr))
+
+ avg_ssim = np.mean(np.array(values_ssim))
+ std_ssim = np.std(np.array(values_ssim))
+
+ return {
+ "Perceptual similarity": (avg_percsim, std_percsim),
+ "PSNR": (avg_psnr, std_psnr),
+ "SSIM": (avg_ssim, std_ssim),
+ }
+
+
+def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list,
+ take_every_other,
+ simple_format=True):
+
+ # Load VGG16 for feature similarity
+ vgg16 = PNet().to("cuda")
+ vgg16.eval()
+ vgg16.cuda()
+
+ values_percsim = []
+ values_ssim = []
+ values_psnr = []
+ equal_count = 0
+ ambig_count = 0
+ for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
+ pred_imgs = pred_imgs_list[i]
+ tgt_imgs = [tgt_img]
+ assert len(tgt_imgs) == 1
+
+ if type(pred_imgs) != list:
+ pred_imgs = [pred_imgs]
+
+ perc_sim = 10000
+ ssim_sim = -10
+ psnr_sim = -10
+ assert len(pred_imgs)>0
+ for p_img in pred_imgs:
+ t_img = load_img(tgt_imgs[0])
+ p_img = load_img(p_img, size=t_img.shape[2:])
+ t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
+ perc_sim = min(perc_sim, t_perc_sim)
+
+ ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
+ psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
+
+ values_percsim += [perc_sim]
+ values_ssim += [ssim_sim]
+ if psnr_sim != np.float("inf"):
+ values_psnr += [psnr_sim]
+ else:
+ if torch.allclose(p_img, t_img):
+ equal_count += 1
+ print("{} equal src and wrp images.".format(equal_count))
+ else:
+ ambig_count += 1
+ print("{} ambiguous src and wrp images.".format(ambig_count))
+
+ if take_every_other:
+ n_valuespercsim = []
+ n_valuesssim = []
+ n_valuespsnr = []
+ for i in range(0, len(values_percsim) // 2):
+ n_valuespercsim += [
+ min(values_percsim[2 * i], values_percsim[2 * i + 1])
+ ]
+ n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
+ n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
+
+ values_percsim = n_valuespercsim
+ values_ssim = n_valuesssim
+ values_psnr = n_valuespsnr
+
+ avg_percsim = np.mean(np.array(values_percsim))
+ std_percsim = np.std(np.array(values_percsim))
+
+ avg_psnr = np.mean(np.array(values_psnr))
+ std_psnr = np.std(np.array(values_psnr))
+
+ avg_ssim = np.mean(np.array(values_ssim))
+ std_ssim = np.std(np.array(values_ssim))
+
+ if simple_format:
+ # just to make yaml formatting readable
+ return {
+ "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
+ "PSNR": [float(avg_psnr), float(std_psnr)],
+ "SSIM": [float(avg_ssim), float(std_ssim)],
+ }
+ else:
+ return {
+ "Perceptual similarity": (avg_percsim, std_percsim),
+ "PSNR": (avg_psnr, std_psnr),
+ "SSIM": (avg_ssim, std_ssim),
+ }
+
+
+def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list,
+ take_every_other, resize=False):
+
+ # Load VGG16 for feature similarity
+ vgg16 = PNet().to("cuda")
+ vgg16.eval()
+ vgg16.cuda()
+
+ values_percsim = []
+ values_ssim = []
+ values_psnr = []
+ individual_percsim = []
+ individual_ssim = []
+ individual_psnr = []
+ for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
+ pred_imgs = pred_imgs_list[i]
+ tgt_imgs = [tgt_img]
+ assert len(tgt_imgs) == 1
+
+ if type(pred_imgs) != list:
+ assert False
+ pred_imgs = [pred_imgs]
+
+ perc_sim = 10000
+ ssim_sim = -10
+ psnr_sim = -10
+ sample_percsim = list()
+ sample_ssim = list()
+ sample_psnr = list()
+ for p_img in pred_imgs:
+ if resize:
+ t_img = load_img(tgt_imgs[0], size=(256,256))
+ else:
+ t_img = load_img(tgt_imgs[0])
+ p_img = load_img(p_img, size=t_img.shape[2:])
+
+ t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
+ sample_percsim.append(t_perc_sim)
+ perc_sim = min(perc_sim, t_perc_sim)
+
+ t_ssim = ssim_metric(p_img, t_img).item()
+ sample_ssim.append(t_ssim)
+ ssim_sim = max(ssim_sim, t_ssim)
+
+ t_psnr = psnr(p_img, t_img).item()
+ sample_psnr.append(t_psnr)
+ psnr_sim = max(psnr_sim, t_psnr)
+
+ values_percsim += [perc_sim]
+ values_ssim += [ssim_sim]
+ values_psnr += [psnr_sim]
+ individual_percsim.append(sample_percsim)
+ individual_ssim.append(sample_ssim)
+ individual_psnr.append(sample_psnr)
+
+ if take_every_other:
+ assert False, "Do this later, after specifying topk to get proper results"
+ n_valuespercsim = []
+ n_valuesssim = []
+ n_valuespsnr = []
+ for i in range(0, len(values_percsim) // 2):
+ n_valuespercsim += [
+ min(values_percsim[2 * i], values_percsim[2 * i + 1])
+ ]
+ n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
+ n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
+
+ values_percsim = n_valuespercsim
+ values_ssim = n_valuesssim
+ values_psnr = n_valuespsnr
+
+ avg_percsim = np.mean(np.array(values_percsim))
+ std_percsim = np.std(np.array(values_percsim))
+
+ avg_psnr = np.mean(np.array(values_psnr))
+ std_psnr = np.std(np.array(values_psnr))
+
+ avg_ssim = np.mean(np.array(values_ssim))
+ std_ssim = np.std(np.array(values_ssim))
+
+ individual_percsim = np.array(individual_percsim)
+ individual_psnr = np.array(individual_psnr)
+ individual_ssim = np.array(individual_ssim)
+
+ return {
+ "avg_of_best": {
+ "Perceptual similarity": [float(avg_percsim), float(std_percsim)],
+ "PSNR": [float(avg_psnr), float(std_psnr)],
+ "SSIM": [float(avg_ssim), float(std_ssim)],
+ },
+ "individual": {
+ "PSIM": individual_percsim,
+ "PSNR": individual_psnr,
+ "SSIM": individual_ssim,
+ }
+ }
+
+
+if __name__ == "__main__":
+ args = argparse.ArgumentParser()
+ args.add_argument("--folder", type=str, default="")
+ args.add_argument("--pred_image", type=str, default="")
+ args.add_argument("--target_image", type=str, default="")
+ args.add_argument("--take_every_other", action="store_true", default=False)
+ args.add_argument("--output_file", type=str, default="")
+
+ opts = args.parse_args()
+
+ folder = opts.folder
+ pred_img = opts.pred_image
+ tgt_img = opts.target_image
+
+ results = compute_perceptual_similarity(
+ folder, pred_img, tgt_img, opts.take_every_other
+ )
+
+ f = open(opts.output_file, 'w')
+ for key in results:
+ print("%s for %s: \n" % (key, opts.folder))
+ print(
+ "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])
+ )
+
+ f.write("%s for %s: \n" % (key, opts.folder))
+ f.write(
+ "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1])
+ )
+
+ f.close()
diff --git a/ldm/modules/evaluate/frechet_video_distance.py b/ldm/modules/evaluate/frechet_video_distance.py
new file mode 100755
index 0000000..d9e13c4
--- /dev/null
+++ b/ldm/modules/evaluate/frechet_video_distance.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python2, python3
+"""Minimal Reference implementation for the Frechet Video Distance (FVD).
+
+FVD is a metric for the quality of video generation models. It is inspired by
+the FID (Frechet Inception Distance) used for images, but uses a different
+embedding to be better suitable for videos.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import six
+import tensorflow.compat.v1 as tf
+import tensorflow_gan as tfgan
+import tensorflow_hub as hub
+
+
+def preprocess(videos, target_resolution):
+ """Runs some preprocessing on the videos for I3D model.
+
+ Args:
+ videos: [batch_size, num_frames, height, width, depth] The videos to be
+ preprocessed. We don't care about the specific dtype of the videos, it can
+ be anything that tf.image.resize_bilinear accepts. Values are expected to
+ be in the range 0-255.
+ target_resolution: (width, height): target video resolution
+
+ Returns:
+ videos: [batch_size, num_frames, height, width, depth]
+ """
+ videos_shape = list(videos.shape)
+ all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
+ resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
+ target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
+ output_videos = tf.reshape(resized_videos, target_shape)
+ scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1
+ return scaled_videos
+
+
+def _is_in_graph(tensor_name):
+ """Checks whether a given tensor does exists in the graph."""
+ try:
+ tf.get_default_graph().get_tensor_by_name(tensor_name)
+ except KeyError:
+ return False
+ return True
+
+
+def create_id3_embedding(videos,warmup=False,batch_size=16):
+ """Embeds the given videos using the Inflated 3D Convolution ne twork.
+
+ Downloads the graph of the I3D from tf.hub and adds it to the graph on the
+ first call.
+
+ Args:
+ videos: [batch_size, num_frames, height=224, width=224, depth=3].
+ Expected range is [-1, 1].
+
+ Returns:
+ embedding: [batch_size, embedding_size]. embedding_size depends
+ on the model used.
+
+ Raises:
+ ValueError: when a provided embedding_layer is not supported.
+ """
+
+ # batch_size = 16
+ module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
+
+
+ # Making sure that we import the graph separately for
+ # each different input video tensor.
+ module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
+ videos.name).replace(":", "_")
+
+
+
+ assert_ops = [
+ tf.Assert(
+ tf.reduce_max(videos) <= 1.001,
+ ["max value in frame is > 1", videos]),
+ tf.Assert(
+ tf.reduce_min(videos) >= -1.001,
+ ["min value in frame is < -1", videos]),
+ tf.assert_equal(
+ tf.shape(videos)[0],
+ batch_size, ["invalid frame batch size: ",
+ tf.shape(videos)],
+ summarize=6),
+ ]
+ with tf.control_dependencies(assert_ops):
+ videos = tf.identity(videos)
+
+ module_scope = "%s_apply_default/" % module_name
+
+ # To check whether the module has already been loaded into the graph, we look
+ # for a given tensor name. If this tensor name exists, we assume the function
+ # has been called before and the graph was imported. Otherwise we import it.
+ # Note: in theory, the tensor could exist, but have wrong shapes.
+ # This will happen if create_id3_embedding is called with a frames_placehoder
+ # of wrong size/batch size, because even though that will throw a tf.Assert
+ # on graph-execution time, it will insert the tensor (with wrong shape) into
+ # the graph. This is why we need the following assert.
+ if warmup:
+ video_batch_size = int(videos.shape[0])
+ assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}"
+ tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
+ if not _is_in_graph(tensor_name):
+ i3d_model = hub.Module(module_spec, name=module_name)
+ i3d_model(videos)
+
+ # gets the kinetics-i3d-400-logits layer
+ tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
+ tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
+ return tensor
+
+
+def calculate_fvd(real_activations,
+ generated_activations):
+ """Returns a list of ops that compute metrics as funcs of activations.
+
+ Args:
+ real_activations: [num_samples, embedding_size]
+ generated_activations: [num_samples, embedding_size]
+
+ Returns:
+ A scalar that contains the requested FVD.
+ """
+ return tfgan.eval.frechet_classifier_distance_from_activations(
+ real_activations, generated_activations)
diff --git a/ldm/modules/evaluate/ssim.py b/ldm/modules/evaluate/ssim.py
new file mode 100755
index 0000000..4e8883c
--- /dev/null
+++ b/ldm/modules/evaluate/ssim.py
@@ -0,0 +1,124 @@
+# MIT Licence
+
+# Methods to predict the SSIM, taken from
+# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
+
+from math import exp
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor(
+ [
+ exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2))
+ for x in range(window_size)
+ ]
+ )
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(
+ _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+ )
+ return window
+
+
+def _ssim(
+ img1, img2, window, window_size, channel, mask=None, size_average=True
+):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = (
+ F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
+ - mu1_sq
+ )
+ sigma2_sq = (
+ F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
+ - mu2_sq
+ )
+ sigma12 = (
+ F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
+ - mu1_mu2
+ )
+
+ C1 = (0.01) ** 2
+ C2 = (0.03) ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
+ (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
+ )
+
+ if not (mask is None):
+ b = mask.size(0)
+ ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
+ ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(
+ dim=1
+ ).clamp(min=1)
+ return ssim_map
+
+ import pdb
+
+ pdb.set_trace
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = 1
+ self.window = create_window(window_size, self.channel)
+
+ def forward(self, img1, img2, mask=None):
+ (_, channel, _, _) = img1.size()
+
+ if (
+ channel == self.channel
+ and self.window.data.type() == img1.data.type()
+ ):
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ self.window = window
+ self.channel = channel
+
+ return _ssim(
+ img1,
+ img2,
+ window,
+ self.window_size,
+ channel,
+ mask,
+ self.size_average,
+ )
+
+
+def ssim(img1, img2, window_size=11, mask=None, size_average=True):
+ (_, channel, _, _) = img1.size()
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, mask, size_average)
diff --git a/ldm/modules/evaluate/torch_frechet_video_distance.py b/ldm/modules/evaluate/torch_frechet_video_distance.py
new file mode 100755
index 0000000..04856b8
--- /dev/null
+++ b/ldm/modules/evaluate/torch_frechet_video_distance.py
@@ -0,0 +1,294 @@
+# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
+import os
+import numpy as np
+import io
+import re
+import requests
+import html
+import hashlib
+import urllib
+import urllib.request
+import scipy.linalg
+import multiprocessing as mp
+import glob
+
+
+from tqdm import tqdm
+from typing import Any, List, Tuple, Union, Dict, Callable
+
+from torchvision.io import read_video
+import torch; torch.set_grad_enabled(False)
+from einops import rearrange
+
+from nitro.util import isvideo
+
+def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float:
+ print('Calculate frechet distance...')
+ m = np.square(mu_sample - mu_ref).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))
+
+ return float(fid)
+
+
+def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ mu = feats.mean(axis=0) # [d]
+ sigma = np.cov(feats, rowvar=False) # [d, d]
+
+ return mu, sigma
+
+
+def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+ # Some internet resources suggest using urllib.request.url2pathname() but
+ # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
+
+def load_video(ip):
+ vid, *_ = read_video(ip)
+ vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8)
+ return vid
+
+def get_data_from_str(input_str,nprc = None):
+ assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory'
+ vid_filelist = glob.glob(os.path.join(input_str,'*.mp4'))
+ print(f'Found {len(vid_filelist)} videos in dir {input_str}')
+
+ if nprc is None:
+ try:
+ nprc = mp.cpu_count()
+ except NotImplementedError:
+ print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading')
+ nprc = 1
+
+ pool = mp.Pool(processes=nprc)
+
+ vids = []
+ for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'):
+ vids.append(v)
+
+
+ vids = torch.stack(vids,dim=0).float()
+
+ return vids
+
+def get_stats(stats):
+ assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}'
+
+ print(f'Using precomputed statistics under {stats}')
+ stats = np.load(stats)
+ stats = {key: stats[key] for key in stats.files}
+
+ return stats
+
+
+
+
+@torch.no_grad()
+def compute_fvd(ref_input, sample_input, bs=32,
+ ref_stats=None,
+ sample_stats=None,
+ nprc_load=None):
+
+
+
+ calc_stats = ref_stats is None or sample_stats is None
+
+ if calc_stats:
+
+ only_ref = sample_stats is not None
+ only_sample = ref_stats is not None
+
+
+ if isinstance(ref_input,str) and not only_sample:
+ ref_input = get_data_from_str(ref_input,nprc_load)
+
+ if isinstance(sample_input, str) and not only_ref:
+ sample_input = get_data_from_str(sample_input, nprc_load)
+
+ stats = compute_statistics(sample_input,ref_input,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ bs=bs,
+ only_ref=only_ref,
+ only_sample=only_sample)
+
+ if only_ref:
+ stats.update(get_stats(sample_stats))
+ elif only_sample:
+ stats.update(get_stats(ref_stats))
+
+
+
+ else:
+ stats = get_stats(sample_stats)
+ stats.update(get_stats(ref_stats))
+
+ fvd = compute_frechet_distance(**stats)
+
+ return {'FVD' : fvd,}
+
+
+@torch.no_grad()
+def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict:
+ detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
+ detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
+
+ with open_url(detector_url, verbose=False) as f:
+ detector = torch.jit.load(f).eval().to(device)
+
+
+
+ assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive'
+
+ ref_embed, sample_embed = [], []
+
+ info = f'Computing I3D activations for FVD score with batch size {bs}'
+
+ if only_ref:
+
+ if not isvideo(videos_real):
+ # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
+ videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
+ print(videos_real.shape)
+
+ if videos_real.shape[0] % bs == 0:
+ n_secs = videos_real.shape[0] // bs
+ else:
+ n_secs = videos_real.shape[0] // bs + 1
+
+ videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
+
+ for ref_v in tqdm(videos_real, total=len(videos_real),desc=info):
+
+ feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
+ ref_embed.append(feats_ref)
+
+ elif only_sample:
+
+ if not isvideo(videos_fake):
+ # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
+ videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
+ print(videos_fake.shape)
+
+ if videos_fake.shape[0] % bs == 0:
+ n_secs = videos_fake.shape[0] // bs
+ else:
+ n_secs = videos_fake.shape[0] // bs + 1
+
+ videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
+
+ for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info):
+ feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
+ sample_embed.append(feats_sample)
+
+
+ else:
+
+ if not isvideo(videos_real):
+ # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
+ videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
+
+ if not isvideo(videos_fake):
+ videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
+
+ if videos_fake.shape[0] % bs == 0:
+ n_secs = videos_fake.shape[0] // bs
+ else:
+ n_secs = videos_fake.shape[0] // bs + 1
+
+ videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
+ videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
+
+ for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info):
+ # print(ref_v.shape)
+ # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
+ # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
+
+
+ feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
+ feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
+ sample_embed.append(feats_sample)
+ ref_embed.append(feats_ref)
+
+ out = dict()
+ if len(sample_embed) > 0:
+ sample_embed = np.concatenate(sample_embed,axis=0)
+ mu_sample, sigma_sample = compute_stats(sample_embed)
+ out.update({'mu_sample': mu_sample,
+ 'sigma_sample': sigma_sample})
+
+ if len(ref_embed) > 0:
+ ref_embed = np.concatenate(ref_embed,axis=0)
+ mu_ref, sigma_ref = compute_stats(ref_embed)
+ out.update({'mu_ref': mu_ref,
+ 'sigma_ref': sigma_ref})
+
+
+ return out
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100755
index 0000000..7836cad
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100755
index 0000000..32ef561
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
+
+# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ use_shuffle: the degradation shuffle
+ use_sharp: sharpening the img
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ img = util.uint2single(img)
+ print(img)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_lq = deg_fn(img)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100755
index 0000000..dfa7606
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,650 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ # elif i == 1:
+ # image = add_blur(image, sf=sf)
+
+ if i == 0:
+ pass
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100755
index 0000000..4249b43
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100755
index 0000000..0175f15
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image pathes
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
+ and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
+ will be splitted.
+ Args:
+ original_dataroot:
+ taget_dataroot:
+ p_size: size of small images
+ p_overlap: patch size in training is a good choice
+ p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channles (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channles (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(unit)
+# numpy(single) <---> tensor
+# numpy(unit) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(unit)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(unit) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+ # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flipe and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augmet_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py
new file mode 100755
index 0000000..876d7c5
--- /dev/null
+++ b/ldm/modules/losses/__init__.py
@@ -0,0 +1 @@
+from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py
new file mode 100755
index 0000000..672c1e3
--- /dev/null
+++ b/ldm/modules/losses/contperceptual.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+
+from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge"):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
+
diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py
new file mode 100755
index 0000000..f699817
--- /dev/null
+++ b/ldm/modules/losses/vqperceptual.py
@@ -0,0 +1,167 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+def l1(x, y):
+ return torch.abs(x-y)
+
+
+def l2(x, y):
+ return torch.pow((x-y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+ pixel_loss="l1"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ assert perceptual_loss in ["lpips", "clips", "dists"]
+ assert pixel_loss in ["l1", "l2"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ if perceptual_loss == "lpips":
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
+ self.perceptual_loss = LPIPS().eval()
+ else:
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+ self.perceptual_weight = perceptual_weight
+
+ if pixel_loss == "l1":
+ self.pixel_loss = l1
+ else:
+ self.pixel_loss = l2
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.n_classes = n_classes
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
+ if not exists(codebook_loss):
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if predicted_indices is not None:
+ assert self.n_classes is not None
+ with torch.no_grad():
+ perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
+ log[f"{split}/perplexity"] = perplexity
+ log[f"{split}/cluster_usage"] = cluster_usage
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py
new file mode 100755
index 0000000..5fc15bf
--- /dev/null
+++ b/ldm/modules/x_transformer.py
@@ -0,0 +1,641 @@
+"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat, reduce
+
+# constants
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple('Intermediates', [
+ 'pre_softmax_attn',
+ 'post_softmax_attn'
+])
+
+LayerIntermediates = namedtuple('Intermediates', [
+ 'hiddens',
+ 'attn_intermediates'
+])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+
+# classes
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, 'b n d -> (b n) d'),
+ rearrange(residual, 'b n d -> (b n) d')
+ )
+
+ return gated_output.reshape_as(x)
+
+
+# feedforward
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# attention.
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ # talking heads
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ # explicit topk sparse attention
+ self.sparse_topk = sparse_topk
+
+ # entmax
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ self.attn_fn = F.softmax
+
+ # add memory key / values
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ # attention on attention
+ self.attn_on_attn = on_attn
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None
+ ):
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ # in shortformer, the query would start at a position offset depending on the past cached memory
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
+ self.rotary_pos_emb = always(None)
+
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ('a', 'c', 'f')
+ elif cross_attend and only_cross:
+ default_block = ('c', 'f')
+ else:
+ default_block = ('a', 'f')
+
+ if macaron:
+ default_block = ('f',) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+ default_block = tuple(filter(not_equals('f'), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+ par_block = default_block + ('f',) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == 'a':
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == 'c':
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == 'f':
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f'invalid layer type {layer_type}')
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([
+ norm_fn(),
+ layer,
+ residual_fn
+ ]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == 'a':
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == 'a':
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
+ prev_attn=prev_attn, mem=layer_mem)
+ elif layer_type == 'c':
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
+ elif layer_type == 'f':
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ('a', 'c'):
+ intermediates.append(inter)
+
+ if layer_type == 'a' and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == 'c' and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens,
+ attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
+ super().__init__(causal=False, **kwargs)
+
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
+ ):
+ super().__init__()
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+
+ # memory tokens (like [cls]) from Memory Transformers paper
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ # let funnel encoder know number of memory tokens, if specified
+ if hasattr(attn_layers, 'num_memory_tokens'):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+ return out, attn_maps
+
+ return out
+
diff --git a/ldm/thirdp/psp/helpers.py b/ldm/thirdp/psp/helpers.py
new file mode 100755
index 0000000..983baaa
--- /dev/null
+++ b/ldm/thirdp/psp/helpers.py
@@ -0,0 +1,121 @@
+# https://github.com/eladrich/pixel2style2pixel
+
+from collections import namedtuple
+import torch
+from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+"""
+ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Flatten(Module):
+ def forward(self, input):
+ return input.view(input.size(0), -1)
+
+
+def l2_norm(input, axis=1):
+ norm = torch.norm(input, 2, axis, True)
+ output = torch.div(input, norm)
+ return output
+
+
+class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+ """ A named tuple describing a ResNet block. """
+
+
+def get_block(in_channel, depth, num_units, stride=2):
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+def get_blocks(num_layers):
+ if num_layers == 50:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=4),
+ get_block(in_channel=128, depth=256, num_units=14),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 100:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=13),
+ get_block(in_channel=128, depth=256, num_units=30),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ elif num_layers == 152:
+ blocks = [
+ get_block(in_channel=64, depth=64, num_units=3),
+ get_block(in_channel=64, depth=128, num_units=8),
+ get_block(in_channel=128, depth=256, num_units=36),
+ get_block(in_channel=256, depth=512, num_units=3)
+ ]
+ else:
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+ return blocks
+
+
+class SEModule(Module):
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2d(1)
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+ self.relu = ReLU(inplace=True)
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+ self.sigmoid = Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = self.avg_pool(x)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
+
+
+class bottleneck_IR(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
+
+
+class bottleneck_IR_SE(Module):
+ def __init__(self, in_channel, depth, stride):
+ super(bottleneck_IR_SE, self).__init__()
+ if in_channel == depth:
+ self.shortcut_layer = MaxPool2d(1, stride)
+ else:
+ self.shortcut_layer = Sequential(
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+ BatchNorm2d(depth)
+ )
+ self.res_layer = Sequential(
+ BatchNorm2d(in_channel),
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+ PReLU(depth),
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+ BatchNorm2d(depth),
+ SEModule(depth, 16)
+ )
+
+ def forward(self, x):
+ shortcut = self.shortcut_layer(x)
+ res = self.res_layer(x)
+ return res + shortcut
\ No newline at end of file
diff --git a/ldm/thirdp/psp/id_loss.py b/ldm/thirdp/psp/id_loss.py
new file mode 100755
index 0000000..e08ee09
--- /dev/null
+++ b/ldm/thirdp/psp/id_loss.py
@@ -0,0 +1,23 @@
+# https://github.com/eladrich/pixel2style2pixel
+import torch
+from torch import nn
+from ldm.thirdp.psp.model_irse import Backbone
+
+
+class IDFeatures(nn.Module):
+ def __init__(self, model_path):
+ super(IDFeatures, self).__init__()
+ print('Loading ResNet ArcFace')
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+ self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+ self.facenet.eval()
+
+ def forward(self, x, crop=False):
+ # Not sure of the image range here
+ if crop:
+ x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
+ x = x[:, :, 35:223, 32:220]
+ x = self.face_pool(x)
+ x_feats = self.facenet(x)
+ return x_feats
diff --git a/ldm/thirdp/psp/model_irse.py b/ldm/thirdp/psp/model_irse.py
new file mode 100755
index 0000000..21cedd2
--- /dev/null
+++ b/ldm/thirdp/psp/model_irse.py
@@ -0,0 +1,86 @@
+# https://github.com/eladrich/pixel2style2pixel
+
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+"""
+Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+"""
+
+
+class Backbone(Module):
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+ super(Backbone, self).__init__()
+ assert input_size in [112, 224], "input_size should be 112 or 224"
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+ blocks = get_blocks(num_layers)
+ if mode == 'ir':
+ unit_module = bottleneck_IR
+ elif mode == 'ir_se':
+ unit_module = bottleneck_IR_SE
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+ BatchNorm2d(64),
+ PReLU(64))
+ if input_size == 112:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 7 * 7, 512),
+ BatchNorm1d(512, affine=affine))
+ else:
+ self.output_layer = Sequential(BatchNorm2d(512),
+ Dropout(drop_ratio),
+ Flatten(),
+ Linear(512 * 14 * 14, 512),
+ BatchNorm1d(512, affine=affine))
+
+ modules = []
+ for block in blocks:
+ for bottleneck in block:
+ modules.append(unit_module(bottleneck.in_channel,
+ bottleneck.depth,
+ bottleneck.stride))
+ self.body = Sequential(*modules)
+
+ def forward(self, x):
+ x = self.input_layer(x)
+ x = self.body(x)
+ x = self.output_layer(x)
+ return l2_norm(x)
+
+
+def IR_50(input_size):
+ """Constructs a ir-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_101(input_size):
+ """Constructs a ir-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_152(input_size):
+ """Constructs a ir-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_50(input_size):
+ """Constructs a ir_se-50 model."""
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_101(input_size):
+ """Constructs a ir_se-101 model."""
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
+
+
+def IR_SE_152(input_size):
+ """Constructs a ir_se-152 model."""
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+ return model
\ No newline at end of file
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100755
index 0000000..7dcad70
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,227 @@
+import importlib
+
+import torchvision
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+import torch
+import time
+import cv2
+
+import PIL
+
+def pil_rectangle_crop(im):
+ width, height = im.size # Get dimensions
+
+ if width <= height:
+ left = 0
+ right = width
+ top = (height - width)/2
+ bottom = (height + width)/2
+ else:
+
+ top = 0
+ bottom = height
+ left = (width - height) / 2
+ bottom = (width + height) / 2
+
+ # Crop the center of the image
+ im = im.crop((left, top, right, bottom))
+ return im
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+ return loss
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..4e60676
--- /dev/null
+++ b/main.py
@@ -0,0 +1,630 @@
+import torch
+import argparse
+import sys
+import os
+import pandas as pd
+
+from nerf.provider import NeRFDataset, generate_grid_points
+from nerf.utils import *
+
+import yaml
+from easydict import EasyDict as edict
+import dnnultis
+import logging
+
+logger = logging.getLogger(__name__)
+
+# The first arg parser parses out only the --config argument, this argument is used to
+# load a yaml file containing key-values that override the defaults for the main parser below
+config_parser = parser = argparse.ArgumentParser(
+ description='Training Config', add_help=False)
+parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
+ help='YAML config file specifying default arguments')
+parser = argparse.ArgumentParser(description='3D AIGC Training')
+parser.add_argument('--workspace', type=str, default='', help='path to log')
+parser.add_argument('--text', default=None, help="text prompt")
+parser.add_argument('--negative', default='', type=str,
+ help="negative text prompt")
+parser.add_argument('--dir_texts_neg', action='store_true',
+ help="enable negative directional text")
+parser.add_argument('--check_prompt', action='store_true', help="check prompt")
+parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray")
+parser.add_argument('-O2', action='store_true',
+ help="equals --backbone vanilla")
+parser.add_argument('--test', action='store_true', help="test mode")
+parser.add_argument('--six_views', action='store_true',
+ help="six_views mode: save the images of the six views")
+parser.add_argument('--eval_interval', type=int, default=1,
+ help="evaluate on the valid set every interval epochs")
+parser.add_argument('--test_interval', type=int, default=50,
+ help="test on the test set every interval epochs")
+parser.add_argument('--seed', type=int, default=101)
+parser.add_argument('--log_every', type=int, default=20,
+ help="log losses every X iterations")
+parser.add_argument('--use_wandb', action='store_true',
+ help="log online into wandb")
+
+# guidance
+parser.add_argument('--guidance', type=str, nargs='*',
+ default=['SD'], help='guidance model')
+parser.add_argument('--guidance_scale', type=float, nargs='*', default=[100],
+ help="diffusion model classifier-free guidance scale")
+parser.add_argument('--gudiance_spatial_weighting',
+ action='store_true', help="add spatial weighting to guidance")
+parser.add_argument('--save_train_every', type=int,
+ default=-1, help="save sds guidance")
+
+# clip guidance
+# lambda_clip, set to 1 if use clip loss outside sds
+parser.add_argument('--lambda_clip', type=float, default=0,
+ help="loss scale for clip loss outside sds")
+# set to 100 if use clip guidance in sds
+parser.add_argument('--clip_version', type=str,
+ default='large', help="clip version, large is ued in stable diffusion")
+parser.add_argument('--clip_guidance', type=float, default=0,
+ help="diffusion model classifier-free guidance scale")
+parser.add_argument('--clip_t', type=float, default=0.4,
+ help="time step thresh started to use clip")
+parser.add_argument('--clip_iterative', action='store_true',
+ help="use clipd iteratively with sds")
+parser.add_argument('--clip_image_loss', action='store_true',
+ help="use image as reference in clip")
+parser.add_argument('--save_guidance_every', type=int,
+ default=-1, help="save sds guidance")
+
+# 3D prior: Shap-E. Does not work.
+parser.add_argument('--use_shape', action='store_true',
+ help="enable shap-e initization")
+parser.add_argument('--shape_guidance', type=float, default=3,
+ help="guidance scaling for shap-e prior")
+parser.add_argument('--shape_radius', type=float, default=4,
+ help="camera raidus for shap-e prior")
+parser.add_argument('--shape_fovy', type=float, default=40,
+ help="fov for shap-e prior")
+parser.add_argument('--shape_no_color', action='store_false',
+ dest='shape_init_color', help="do not use shap-E color for initization")
+parser.add_argument('--shape_rpst', type=str, default='sdf',
+ help="use sdf to init NeRF/mesh by default")
+
+# image options.
+parser.add_argument('--image', default=None, help="image prompt")
+parser.add_argument('--image_config', default=None, help="image config csv")
+parser.add_argument('--learned_embeds_path', type=str,
+ default=None, help="path to learned embeds of the given image")
+parser.add_argument('--known_iters', type=int, default=100,
+ help="loss scale for alpha entropy")
+parser.add_argument('--known_view_interval', type=int, default=4,
+ help="do reconstruction every X iterations to save on compute")
+parser.add_argument('--bg_color_known', type=str,
+ default=None, help='pixelnoise, noise, None') # pixelnoise
+parser.add_argument('--known_shading', type=str, default='lambertian')
+
+# DMTet and Mesh options
+parser.add_argument('--save_mesh', action='store_true',
+ help="export an obj mesh with texture")
+parser.add_argument('--mcubes_resolution', type=int, default=256,
+ help="mcubes resolution for extracting mesh")
+parser.add_argument('--decimate_target', type=int, default=5e4,
+ help="target face number for mesh decimation")
+parser.add_argument('--dmtet', action='store_true',
+ help="use dmtet finetuning")
+parser.add_argument('--tet_mlp', action='store_true',
+ help="use tet_mlp finetuning")
+parser.add_argument('--base_mesh', default=None,
+ help="base mesh for dmtet init")
+parser.add_argument('--tet_grid_size', type=int,
+ default=256, help="tet grid size")
+parser.add_argument('--init_ckpt', type=str, default='',
+ help="ckpt to init dmtet")
+parser.add_argument('--lock_geo', action='store_true',
+ help="disable dmtet to learn geometry")
+
+# training options
+parser.add_argument('--iters', type=int, default=5000, help="training iters")
+parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
+parser.add_argument('--lr_scale_nerf', type=float,
+ default=1, help="max learning rate")
+parser.add_argument('--lr_scale_texture', type=float,
+ default=1, help="max learning rate")
+parser.add_argument('--ckpt', type=str, default='latest')
+parser.add_argument('--cuda_ray', action='store_true',
+ help="use CUDA raymarching instead of pytorch")
+parser.add_argument('--taichi_ray', action='store_true',
+ help="use taichi raymarching")
+parser.add_argument('--max_steps', type=int, default=1024,
+ help="max num steps sampled per ray (only valid when using --cuda_ray)")
+parser.add_argument('--num_steps', type=int, default=64,
+ help="num steps sampled per ray (only valid when not using --cuda_ray)")
+parser.add_argument('--upsample_steps', type=int, default=32,
+ help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
+parser.add_argument('--update_extra_interval', type=int, default=16,
+ help="iter interval to update extra status (only valid when using --cuda_ray)")
+parser.add_argument('--max_ray_batch', type=int, default=4096,
+ help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
+parser.add_argument('--latent_iter_ratio', type=float, default=0.0,
+ help="training iters that only use latent normal shading")
+parser.add_argument('--normal_iter_ratio', type=float, default=0.0,
+ help="training iters that only use normal shading")
+parser.add_argument('--textureless_iter_ratio', type=float, default=0.0,
+ help="training iters that only use textureless shading")
+parser.add_argument('--albedo_iter_ratio', type=float, default=0,
+ help="training iters that only use albedo shading")
+parser.add_argument('--warmup_bg_color', type=str, default=None,
+ help="bg color [None | pixelnoise | noise | white]")
+parser.add_argument('--bg_color', type=str, default=None)
+parser.add_argument('--bg_color_test', default='white')
+parser.add_argument('--ema_decay', type=float, default=0.95,
+ help="exponential moving average of model weights")
+parser.add_argument('--jitter_pose', action='store_true',
+ help="add jitters to the randomly sampled camera poses")
+parser.add_argument('--jitter_center', type=float, default=0.2,
+ help="amount of jitter to add to sampled camera pose's center (camera location)")
+parser.add_argument('--jitter_target', type=float, default=0.2,
+ help="amount of jitter to add to sampled camera pose's target (i.e. 'look-at')")
+parser.add_argument('--jitter_up', type=float, default=0.02,
+ help="amount of jitter to add to sampled camera pose's up-axis (i.e. 'camera roll')")
+parser.add_argument('--uniform_sphere_rate', type=float, default=0.5,
+ help="likelihood of sampling camera location uniformly on the sphere surface area")
+parser.add_argument('--grad_clip', type=float, default=-1,
+ help="clip grad of all grad to this limit, negative value disables it")
+parser.add_argument('--grad_clip_rgb', type=float, default=-1,
+ help="clip grad of rgb space grad to this limit, negative value disables it")
+parser.add_argument('--grid_levels_mask', type=int, default=8,
+ help="the number of levels in the feature grid to mask (to disable use 0)")
+parser.add_argument('--grid_levels_mask_iters', type=int, default=3000,
+ help="the number of iterations for feature grid masking (to disable use 0)")
+
+# model options
+parser.add_argument('--bg_radius', type=float, default=1.4,
+ help="if positive, use a background model at sphere(bg_radius)")
+parser.add_argument('--density_activation', type=str, default='exp',
+ choices=['softplus', 'exp', 'relu'], help="density activation function")
+parser.add_argument('--density_thresh', type=float, default=10,
+ help="threshold for density grid to be occupied")
+# add more strength to the center, believe the center is more likely to have objects.
+parser.add_argument('--blob_density', type=float, default=10,
+ help="max (center) density for the density blob")
+parser.add_argument('--blob_radius', type=float, default=0.2,
+ help="control the radius for the density blob")
+# network backbone
+parser.add_argument('--backbone', type=str, default='grid',
+ choices=['grid', 'vanilla', 'grid_taichi'], help="nerf backbone")
+parser.add_argument('--grid_type', type=str,
+ default='hashgrid', help="grid type")
+parser.add_argument('--hidden_dim_bg', type=int, default=32,
+ help="channels for background network")
+parser.add_argument('--optim', type=str, default='adam',
+ choices=['adan', 'adam'], help="optimizer")
+parser.add_argument('--sd_version', type=str, default='1.5',
+ choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
+parser.add_argument('--hf_key', type=str, default=None,
+ help="hugging face Stable diffusion model key")
+# try this if CUDA OOM
+parser.add_argument('--fp16', action='store_true',
+ help="use float16 for training")
+parser.add_argument('--vram_O', action='store_true',
+ help="optimization for low VRAM usage")
+# rendering resolution in training, increase these for better quality / decrease these if CUDA OOM even if --vram_O enabled.
+parser.add_argument('--w', type=int, default=128,
+ help="render width for NeRF in training")
+parser.add_argument('--h', type=int, default=128,
+ help="render height for NeRF in training")
+parser.add_argument('--known_view_scale', type=float, default=1.5,
+ help="multiply --h/w by this for known view rendering")
+parser.add_argument('--known_view_noise_scale', type=float, default=1e-3,
+ help="random camera noise added to rays_o and rays_d")
+parser.add_argument('--noise_known_camera_annealing', action='store_true',
+ help="anneal the noise to zero over the coarse of training")
+parser.add_argument('--dmtet_reso_scale', type=float, default=8,
+ help="multiply --h/w by this for dmtet finetuning")
+parser.add_argument('--rm_edge', action='store_true',
+ help="remove edge (ideally only enale for high resolution cases)")
+parser.add_argument('--edge_threshold', type=float, default=0.1,
+ help="remove edges with value > threshold")
+parser.add_argument('--edge_width', type=float, default=5,
+ help="edge width")
+parser.add_argument('--batch_size', type=int, default=1,
+ help="images to render per batch using NeRF")
+
+# dataset options
+parser.add_argument('--bound', type=float, default=1.0,
+ help="assume the scene is bounded in box(-bound, bound)")
+parser.add_argument('--dt_gamma', type=float, default=0,
+ help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)")
+parser.add_argument('--min_near', type=float, default=0.1,
+ help="minimum near distance for camera")
+
+parser.add_argument('--radius_range', type=float, nargs='*',
+ default=[1.8, 1.8], help="training camera radius range")
+parser.add_argument('--theta_range', type=float, nargs='*',
+ default=[45, 135], help="training camera elevation/polar range, 90 is front")
+parser.add_argument('--phi_range', type=float, nargs='*',
+ default=[-180, 180], help="training camera azimuth range")
+parser.add_argument('--fovy_range', type=float, nargs='*',
+ default=[40, 40], help="training camera fovy range")
+
+parser.add_argument('--default_radius', type=float, default=1.8,
+ help="radius for the default view")
+parser.add_argument('--default_polar', type=float,
+ default=90, help="polar for the default view")
+parser.add_argument('--default_azimuth', type=float,
+ default=0, help="azimuth for the default view")
+parser.add_argument('--default_fovy', type=float, default=40,
+ help="fovy for the default view")
+
+parser.add_argument('--progressive_view', action='store_true',
+ help="progressively expand view sampling range from default to full")
+parser.add_argument('--progressive_level', action='store_true',
+ help="progressively increase gridencoder's max_level")
+
+parser.add_argument('--angle_overhead', type=float, default=30,
+ help="[0, angle_overhead] is the overhead region")
+parser.add_argument('--angle_front', type=float, default=60,
+ help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
+parser.add_argument('--t_range', type=float, nargs='*',
+ default=[0.2, 0.6], help="stable diffusion time steps range")
+
+# regularizations
+parser.add_argument('--lambda_entropy', type=float, default=1e-3,
+ help="loss scale for alpha entropy, favors 0 or 1")
+# Try increasing/decreasing lambda_opacity if your scene is stuffed with floaters/becoming empty.
+parser.add_argument('--lambda_opacity', type=float, default=0.,
+ help="loss scale for alpha value, avoid uncessary filling")
+# Try increasing/decreasing lambda_orient if you object is foggy/over-smoothed.
+parser.add_argument('--lambda_orient', type=float,
+ default=1e-2, help="loss scale for orientation")
+parser.add_argument('--lambda_tv', type=float, default=0,
+ help="loss scale for total variation of grad")
+parser.add_argument('--lambda_wd', type=float, default=0,
+ help="loss scale for weight decay of grad")
+parser.add_argument('--lambda_normal_smooth', type=float, default=0.5,
+ help="loss scale for first-order 2D normal image smoothness")
+parser.add_argument('--lambda_normal_smooth2d', type=float, default=0.5,
+ help="loss scale for second-order 2D normal image smoothness")
+parser.add_argument('--lambda_3d_normal_smooth', type=float, default=0.0,
+ help="loss scale for second-order 2D normal image smoothness")
+parser.add_argument('--lambda_guidance', type=float, nargs='*',
+ default=[1], help="loss scale for SDS")
+parser.add_argument('--lambda_rgb', type=float,
+ default=5, help="loss scale for RGB")
+parser.add_argument('--lambda_mask', type=float, default=0.5,
+ help="loss scale for mask (alpha)")
+parser.add_argument('--lambda_depth', type=float, default=0.01,
+ help="loss scale for relative depth of the known view")
+parser.add_argument('--lambda_normal', type=float,
+ default=0.0, help="loss scale for normals of the known view")
+parser.add_argument('--lambda_depth_mse', type=float, default=0.0,
+ help="loss scale for depth of the known view")
+parser.add_argument('--no_normalize_depth', action='store_false', dest='normalize_depth', help="normalize depth")
+
+# for DMTet
+parser.add_argument('--lambda_mesh_normal', type=float,
+ default=0.1, help="loss scale for mesh normal smoothness")
+parser.add_argument('--lambda_mesh_lap', type=float,
+ default=0.1, help="loss scale for mesh laplacian")
+
+# GUI options
+parser.add_argument('--gui', action='store_true', help="start a GUI")
+parser.add_argument('--W', type=int, default=800, help="GUI width")
+parser.add_argument('--H', type=int, default=800, help="GUI height")
+parser.add_argument('--radius', type=float, default=1.8,
+ help="default GUI camera radius from center")
+parser.add_argument('--fovy', type=float, default=40,
+ help="default GUI camera fovy")
+parser.add_argument('--light_theta', type=float, default=60,
+ help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]")
+parser.add_argument('--light_phi', type=float, default=0,
+ help="default GUI light direction in [0, 360), azimuth")
+parser.add_argument('--max_spp', type=int, default=1,
+ help="GUI rendering max sample per pixel")
+parser.add_argument('--zero123_config', type=str,
+ default='./pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml', help="config file for zero123")
+parser.add_argument('--zero123_ckpt', type=str,
+ default='./pretrained/zero123/105000.ckpt', help="ckpt for zero123")
+parser.add_argument('--zero123_grad_scale', type=str, default='angle',
+ help="whether to scale the gradients based on 'angle' or 'None'")
+
+parser.add_argument('--dataset_size_train', type=int, default=100,
+ help="Length of train dataset i.e. # of iterations per epoch")
+parser.add_argument('--dataset_size_valid', type=int, default=8,
+ help="# of frames to render in the turntable video in validation")
+parser.add_argument('--dataset_size_test', type=int, default=100,
+ help="# of frames to render in the turntable video at test time")
+
+
+def _parse_args():
+ args_config, remaining = config_parser.parse_known_args()
+ if args_config.config:
+ with open(args_config.config, 'r') as f:
+ cfg = yaml.safe_load(f)
+ parser.set_defaults(**cfg)
+
+ # The main arg parser parses the rest of the args, the usual
+ # defaults will have been overridden if config file specified.
+ args = parser.parse_args(remaining)
+
+ # Cache the args as a text string to save them in the output dir later
+ args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
+ return args, args_text
+
+
+if __name__ == '__main__':
+ args, args_text = _parse_args()
+ opt = edict(vars(args))
+
+ if opt.O:
+ opt.fp16 = True
+ opt.cuda_ray = True
+
+ elif opt.O2:
+ opt.fp16 = True
+ opt.backbone = 'vanilla'
+
+ opt.images, opt.ref_radii, opt.ref_polars, opt.ref_azimuths, opt.zero123_ws = [], [], [], [], []
+ opt.default_zero123_w = 1
+
+ # parameters for image-conditioned generation
+ if opt.image is not None or opt.image_config is not None:
+ if 'zero123' in opt.guidance:
+ # fix fov as zero123 doesn't support changing fov
+ opt.fovy_range = [opt.default_fovy, opt.default_fovy]
+ else:
+ opt.known_view_interval = 2
+
+ if 'SD' in opt.guidance:
+ opt.t_range = [0.2, 0.6]
+ opt.bg_radius = -1
+
+ # latent warmup is not needed
+ opt.latent_iter_ratio = 0
+ opt.albedo_iter_ratio = 0
+
+ if opt.image is not None:
+ opt.images += [opt.image]
+ opt.ref_radii += [opt.default_radius]
+ opt.ref_polars += [opt.default_polar]
+ opt.ref_azimuths += [opt.default_azimuth]
+ opt.zero123_ws += [opt.default_zero123_w]
+
+ if opt.image_config is not None:
+ # for multiview (zero123)
+ conf = pd.read_csv(opt.image_config, skipinitialspace=True)
+ opt.images += list(conf.image)
+ opt.ref_radii += list(conf.radius)
+ opt.ref_polars += list(conf.polar)
+ opt.ref_azimuths += list(conf.azimuth)
+ opt.zero123_ws += list(conf.zero123_weight)
+ if opt.image is None:
+ opt.default_radius = opt.ref_radii[0]
+ opt.default_polar = opt.ref_polars[0]
+ opt.default_azimuth = opt.ref_azimuths[0]
+ opt.default_zero123_w = opt.zero123_ws[0]
+
+ # reset to None
+ if len(opt.images) == 0:
+ opt.images = None
+
+ # default parameters for finetuning
+ if opt.dmtet:
+ opt.h = int(opt.h * opt.dmtet_reso_scale)
+ opt.w = int(opt.w * opt.dmtet_reso_scale)
+ opt.known_view_scale = 1
+ opt.grid_levels_mask = -1 # disable corse nerf (fine to keep, not necesary)
+ opt.t_range = [0.02, 0.50] # ref: magic3D
+
+ if opt.images is not None:
+ opt.lambda_normal = 0
+ opt.lambda_depth = 0
+
+ # assume finetuning
+ opt.latent_iter_ratio = 0
+ opt.textureless_iter_ratio = 0
+ opt.albedo_iter_ratio = 0
+ opt.normal_iter_ratio = 0
+ opt.progressive_view = False
+ opt.progressive_level = False
+
+ # record full range for progressive view expansion
+ if opt.progressive_view:
+ # disable as they disturb progressive view
+ opt.jitter_pose = False
+ opt.uniform_sphere_rate = 0
+ # back up full range
+ opt.full_radius_range = opt.radius_range
+ opt.full_theta_range = opt.theta_range
+ opt.full_phi_range = opt.phi_range
+ opt.full_fovy_range = opt.fovy_range
+
+ opt.use_clip = opt.clip_guidance > 0 or opt.lambda_clip > 0
+ # Do not support Shap-E for NeRF yet.
+ opt.use_shape = False if not opt.dmtet else opt.use_shape
+
+ # workspace prepare
+ setup_workspace(opt)
+ dnnultis.setup_logging(opt.log_path)
+
+ if opt.seed < 0:
+ opt.seed = random.randint(0, 10000)
+ seed_everything(int(opt.seed))
+
+ if opt.backbone == 'vanilla':
+ from nerf.network import NeRFNetwork
+ elif opt.backbone == 'grid':
+ from nerf.network_grid import NeRFNetwork
+ elif opt.backbone == 'grid_tcnn':
+ from nerf.network_grid_tcnn import NeRFNetwork
+ elif opt.backbone == 'grid_taichi':
+ opt.cuda_ray = False
+ opt.taichi_ray = True
+ import taichi as ti
+ from nerf.network_grid_taichi import NeRFNetwork
+ taichi_half2_opt = True
+ taichi_init_args = {"arch": ti.cuda, "device_memory_GB": 4.0}
+ if taichi_half2_opt:
+ taichi_init_args["half2_vectorization"] = True
+ ti.init(**taichi_init_args)
+ else:
+ raise NotImplementedError(
+ f'--backbone {opt.backbone} is not implemented!')
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ opt.device = device
+ model = NeRFNetwork(opt).to(device)
+
+ if opt.init_ckpt != '':
+ if not os.path.exists(opt.init_ckpt):
+ logger.warning(f'ckpt {opt.init_ckpt} is not found')
+ else:
+ # load pretrained weights to init dmtet
+ state_dict = torch.load(opt.init_ckpt, map_location=device)
+ model.load_state_dict(state_dict['model'], strict=False)
+ if opt.cuda_ray:
+ model.mean_density = state_dict['mean_density']
+ logger.info(f'init from {opt.init_ckpt}...')
+ # if init ckpt is provided, we assume the color network is well learned and do not need base_mesh init
+ opt.shape_init_color = False
+ opt.base_mesh = None
+
+ if opt.use_shape and opt.dmtet:
+ # now only supports shape for dmtet init
+ from guidance.shape_utils import get_shape_from_image
+
+ opt.points = generate_grid_points(
+ 128, device=device) if not opt.dmtet else model.dmtet.verts
+ opt.rpsts, opt.colors = get_shape_from_image(
+ opt.image.replace('rgba', 'rgb'),
+ opt.points,
+ rpst_type=opt.shape_rpst,
+ get_color=opt.shape_init_color,
+ shape_guidance=opt.shape_guidance, device=device)
+ scale = opt.default_radius / opt.shape_radius * \
+ np.tan(np.deg2rad(opt.default_fovy / 2)) / \
+ np.tan(np.deg2rad(opt.shape_fovy / 2))
+ if opt.dmtet:
+ model.dmtet.reset_tet_scale(scale)
+ else:
+ opt.points *= scale
+ logger.info(f'Got sdf from Shap-E init...')
+
+ logger.info(model)
+
+ if opt.six_views:
+ guidance = None # no need to load guidance model at test
+
+ trainer = Trainer(' '.join(sys.argv), 'df', opt, model, guidance, device=device,
+ workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
+
+ test_loader = NeRFDataset(
+ opt, device=device, type='six_views', H=opt.H, W=opt.W, size=6).dataloader(batch_size=1)
+ trainer.test(test_loader, write_video=False)
+
+ if opt.save_mesh:
+ trainer.save_mesh()
+
+ elif opt.test:
+ guidance = None # no need to load guidance model at test
+ trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace), opt, model, guidance,
+ device=device, workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt)
+ if opt.gui:
+ from nerf.gui import NeRFGUI
+ gui = NeRFGUI(opt, trainer)
+ gui.render()
+
+ else:
+ test_loader = NeRFDataset(
+ opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
+ trainer.test(test_loader)
+ trainer.test(test_loader, shading='normal') # save normal
+ if opt.save_mesh:
+ try:
+ trainer.save_mesh()
+ except:
+ pass
+ else:
+ train_loader = NeRFDataset(
+ opt, device=device, type='train', H=opt.h, W=opt.w, size=opt.dataset_size_train * opt.batch_size).dataloader()
+
+ if opt.optim == 'adan':
+ from optimizer import Adan
+ # Adan usually requires a larger LR
+
+ def optimizer(model): return Adan(model.get_params(
+ 5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
+ else: # adam
+ def optimizer(model): return torch.optim.Adam(
+ model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+
+ if opt.backbone == 'vanilla':
+ def scheduler(optimizer): return optim.lr_scheduler.LambdaLR(
+ optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
+ else:
+ def scheduler(optimizer): return optim.lr_scheduler.LambdaLR(
+ optimizer, lambda iter: 1) # fixed
+
+ guidance = nn.ModuleDict()
+ lambda_guidance, guidance_scale = {}, {}
+ for idx, guidance_type in enumerate(opt.guidance):
+ lambda_guidance[guidance_type] = opt.lambda_guidance[idx] if idx < len(
+ opt.lambda_guidance) else opt.lambda_guidance[-1]
+ guidance_scale[guidance_type] = opt.guidance_scale[idx] if idx < len(
+ opt.guidance_scale) else opt.guidance_scale[-1]
+ if 'SD' == guidance_type:
+ from guidance.sd_utils import StableDiffusion, token_replace
+ guidance['SD'] = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key, opt.t_range,
+ learned_embeds_path=opt.learned_embeds_path,
+ use_clip=opt.use_clip, clip_t=opt.clip_t, clip_iterative=opt.clip_iterative, clip_version=opt.clip_version,
+ )
+ if opt.learned_embeds_path is not None and os.path.exists(opt.learned_embeds_path): # add textual inversion tokens to model
+ opt.text, opt.negative = token_replace(
+ opt.text, opt.negative, opt.learned_embeds_path)
+ logger.info(
+ f'prompt: {opt.text}, negative: {opt.negative}')
+ if opt.check_prompt:
+ guidance['SD'].check_prompt(opt)
+ else:
+ opt.text = opt.text.replace('', os.path.basename(os.path.dirname(opt.image)))
+ logger.warning('No learned_embeds_path provided, using the folowing pure text prompt with degraded performance: ' + opt.text)
+
+ if 'IF' == guidance_type:
+ from guidance.if_utils import IF
+ guidance['IF'] = IF(device, opt.vram_O, opt.t_range)
+
+ if 'zero123' == guidance_type:
+ from guidance.zero123_utils import Zero123
+ guidance['zero123'] = Zero123(device=device, fp16=opt.fp16, config=opt.zero123_config,
+ ckpt=opt.zero123_ckpt, vram_O=opt.vram_O, t_range=opt.t_range, opt=opt)
+
+ if 'clip' == guidance_type:
+ from guidance.clip_utils import CLIP
+ guidance['clip'] = CLIP(device)
+ opt.lambda_guidance = lambda_guidance
+ opt.guidance_scale = guidance_scale
+
+ logger.info(opt)
+ trainer = Trainer(' '.join(sys.argv), os.path.basename(opt.workspace), opt, model,
+ guidance,
+ device=device, workspace=opt.workspace, optimizer=optimizer,
+ ema_decay=opt.ema_decay, fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, scheduler_update_every_step=True)
+ trainer.default_view_data = train_loader._data.get_default_view_data()
+
+ if opt.gui:
+ from nerf.gui import NeRFGUI
+ gui = NeRFGUI(opt, trainer, train_loader)
+ gui.render()
+
+ else:
+ valid_loader = NeRFDataset(
+ opt, device=device, type='val', H=opt.H, W=opt.W, size=opt.dataset_size_valid).dataloader()
+ test_loader = NeRFDataset(
+ opt, device=device, type='test', H=opt.H, W=opt.W, size=opt.dataset_size_test).dataloader()
+ max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
+
+ trainer.train(train_loader, valid_loader, test_loader, max_epoch)
+
+ trainer.test(test_loader)
+ trainer.test(test_loader, shading='normal') # save normal
+ if opt.save_mesh:
+ try:
+ trainer.save_mesh()
+ except:
+ pass
\ No newline at end of file
diff --git a/meshutils.py b/meshutils.py
new file mode 100644
index 0000000..4d1c53d
--- /dev/null
+++ b/meshutils.py
@@ -0,0 +1,117 @@
+import numpy as np
+import pymeshlab as pml
+
+def poisson_mesh_reconstruction(points, normals=None):
+ # points/normals: [N, 3] np.ndarray
+
+ import open3d as o3d
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points)
+
+ # outlier removal
+ pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)
+
+ # normals
+ if normals is None:
+ pcd.estimate_normals()
+ else:
+ pcd.normals = o3d.utility.Vector3dVector(normals[ind])
+
+ # visualize
+ o3d.visualization.draw_geometries([pcd], point_show_normal=False)
+
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9)
+ vertices_to_remove = densities < np.quantile(densities, 0.1)
+ mesh.remove_vertices_by_mask(vertices_to_remove)
+
+ # visualize
+ o3d.visualization.draw_geometries([mesh])
+
+ vertices = np.asarray(mesh.vertices)
+ triangles = np.asarray(mesh.triangles)
+
+ print(f'[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}')
+
+ return vertices, triangles
+
+
+def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True):
+ # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect.
+
+ _ori_vert_shape = verts.shape
+ _ori_face_shape = faces.shape
+
+ if backend == 'pyfqmr':
+ import pyfqmr
+ solver = pyfqmr.Simplify()
+ solver.setMesh(verts, faces)
+ solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
+ verts, faces, normals = solver.getMesh()
+ else:
+
+ m = pml.Mesh(verts, faces)
+ ms = pml.MeshSet()
+ ms.add_mesh(m, 'mesh') # will copy!
+
+ # filters
+ # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
+ ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement)
+
+ if remesh:
+ # ms.apply_coord_taubin_smoothing()
+ ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1))
+
+ # extract mesh
+ m = ms.current_mesh()
+ verts = m.vertex_matrix()
+ faces = m.face_matrix()
+
+ print(f'[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
+
+ return verts, faces
+
+
+def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01):
+ # verts: [N, 3]
+ # faces: [N, 3]
+
+ _ori_vert_shape = verts.shape
+ _ori_face_shape = faces.shape
+
+ m = pml.Mesh(verts, faces)
+ ms = pml.MeshSet()
+ ms.add_mesh(m, 'mesh') # will copy!
+
+ # filters
+ ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces
+
+ if v_pct > 0:
+ ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct)) # 1/10000 of bounding box diagonal
+
+ ms.meshing_remove_duplicate_faces() # faces defined by the same verts
+ ms.meshing_remove_null_faces() # faces with area == 0
+
+ if min_d > 0:
+ ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d))
+
+ if min_f > 0:
+ ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
+
+ if repair:
+ # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
+ ms.meshing_repair_non_manifold_edges(method=0)
+ ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
+
+ if remesh:
+ # ms.apply_coord_taubin_smoothing()
+ ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size))
+
+ # extract mesh
+ m = ms.current_mesh()
+ verts = m.vertex_matrix()
+ faces = m.face_matrix()
+
+ print(f'[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}')
+
+ return verts, faces
\ No newline at end of file
diff --git a/midas/__init__.py b/midas/__init__.py
new file mode 100644
index 0000000..07398fb
--- /dev/null
+++ b/midas/__init__.py
@@ -0,0 +1 @@
+from .model_loader import load_model, default_models
\ No newline at end of file
diff --git a/midas/backbones/beit.py b/midas/backbones/beit.py
new file mode 100644
index 0000000..7a24e02
--- /dev/null
+++ b/midas/backbones/beit.py
@@ -0,0 +1,196 @@
+import timm
+import torch
+import types
+
+import numpy as np
+import torch.nn.functional as F
+
+from .utils import forward_adapted_unflatten, make_backbone_default
+from timm.models.beit import gen_relative_position_index
+from torch.utils.checkpoint import checkpoint
+from typing import Optional
+
+
+def forward_beit(pretrained, x):
+ return forward_adapted_unflatten(pretrained, x, "forward_features")
+
+
+def patch_embed_forward(self, x):
+ """
+ Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes.
+ """
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ return x
+
+
+def _get_rel_pos_bias(self, window_size):
+ """
+ Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
+ """
+ old_height = 2 * self.window_size[0] - 1
+ old_width = 2 * self.window_size[1] - 1
+
+ new_height = 2 * window_size[0] - 1
+ new_width = 2 * window_size[1] - 1
+
+ old_relative_position_bias_table = self.relative_position_bias_table
+
+ old_num_relative_distance = self.num_relative_distance
+ new_num_relative_distance = new_height * new_width + 3
+
+ old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]
+
+ old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
+ new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear")
+ new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
+
+ new_relative_position_bias_table = torch.cat(
+ [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]])
+
+ key = str(window_size[1]) + "," + str(window_size[0])
+ if key not in self.relative_position_indices.keys():
+ self.relative_position_indices[key] = gen_relative_position_index(window_size)
+
+ relative_position_bias = new_relative_position_bias_table[
+ self.relative_position_indices[key].view(-1)].view(
+ window_size[0] * window_size[1] + 1,
+ window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ return relative_position_bias.unsqueeze(0)
+
+
+def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
+ """
+ Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.
+ """
+ B, N, C = x.shape
+
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ window_size = tuple(np.array(resolution) // 16)
+ attn = attn + self._get_rel_pos_bias(window_size)
+ if shared_rel_pos_bias is not None:
+ attn = attn + shared_rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
+ """
+ Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
+ """
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
+ shared_rel_pos_bias=shared_rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+def beit_forward_features(self, x):
+ """
+ Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
+ """
+ resolution = x.shape[2:]
+
+ x = self.patch_embed(x)
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
+ else:
+ x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
+ x = self.norm(x)
+ return x
+
+
+def _make_beit_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[0, 4, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+ start_index_readout=1,
+):
+ backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
+ start_index_readout)
+
+ backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed)
+ backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model)
+
+ for block in backbone.model.blocks:
+ attn = block.attn
+ attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
+ attn.forward = types.MethodType(attention_forward, attn)
+ attn.relative_position_indices = {}
+
+ block.forward = types.MethodType(block_forward, block)
+
+ return backbone
+
+
+def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("beit_large_patch16_512", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
+
+ features = [256, 512, 1024, 1024]
+
+ return _make_beit_backbone(
+ model,
+ features=features,
+ size=[512, 512],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("beit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
+ return _make_beit_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("beit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
+ return _make_beit_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ )
diff --git a/midas/backbones/levit.py b/midas/backbones/levit.py
new file mode 100644
index 0000000..6d023a9
--- /dev/null
+++ b/midas/backbones/levit.py
@@ -0,0 +1,106 @@
+import timm
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .utils import activations, get_activation, Transpose
+
+
+def forward_levit(pretrained, x):
+ pretrained.model.forward_features(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+
+ layer_1 = pretrained.act_postprocess1(layer_1)
+ layer_2 = pretrained.act_postprocess2(layer_2)
+ layer_3 = pretrained.act_postprocess3(layer_3)
+
+ return layer_1, layer_2, layer_3
+
+
+def _make_levit_backbone(
+ model,
+ hooks=[3, 11, 21],
+ patch_grid=[14, 14]
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+
+ pretrained.activations = activations
+
+ patch_grid_size = np.array(patch_grid, dtype=int)
+
+ pretrained.act_postprocess1 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
+ )
+ pretrained.act_postprocess3 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
+ )
+
+ return pretrained
+
+
+class ConvTransposeNorm(nn.Sequential):
+ """
+ Modification of
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
+ such that ConvTranspose2d is used instead of Conv2d.
+ """
+
+ def __init__(
+ self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
+ groups=1, bn_weight_init=1):
+ super().__init__()
+ self.add_module('c',
+ nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
+ self.add_module('bn', nn.BatchNorm2d(out_chs))
+
+ nn.init.constant_(self.bn.weight, bn_weight_init)
+
+ @torch.no_grad()
+ def fuse(self):
+ c, bn = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
+ m = nn.ConvTranspose2d(
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
+ padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+
+def stem_b4_transpose(in_chs, out_chs, activation):
+ """
+ Modification of
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
+ such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
+ """
+ return nn.Sequential(
+ ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
+ activation(),
+ ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
+ activation())
+
+
+def _make_pretrained_levit_384(pretrained, hooks=None):
+ model = timm.create_model("levit_384", pretrained=pretrained)
+
+ hooks = [3, 11, 21] if hooks == None else hooks
+ return _make_levit_backbone(
+ model,
+ hooks=hooks
+ )
diff --git a/midas/backbones/next_vit.py b/midas/backbones/next_vit.py
new file mode 100644
index 0000000..8afdd8b
--- /dev/null
+++ b/midas/backbones/next_vit.py
@@ -0,0 +1,39 @@
+import timm
+
+import torch.nn as nn
+
+from pathlib import Path
+from .utils import activations, forward_default, get_activation
+
+from ..external.next_vit.classification.nextvit import *
+
+
+def forward_next_vit(pretrained, x):
+ return forward_default(pretrained, x, "forward")
+
+
+def _make_next_vit_backbone(
+ model,
+ hooks=[2, 6, 36, 39],
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ return pretrained
+
+
+def _make_pretrained_next_vit_large_6m(hooks=None):
+ model = timm.create_model("nextvit_large")
+
+ hooks = [2, 6, 36, 39] if hooks == None else hooks
+ return _make_next_vit_backbone(
+ model,
+ hooks=hooks,
+ )
diff --git a/midas/backbones/swin.py b/midas/backbones/swin.py
new file mode 100644
index 0000000..f8c7136
--- /dev/null
+++ b/midas/backbones/swin.py
@@ -0,0 +1,13 @@
+import timm
+
+from .swin_common import _make_swin_backbone
+
+
+def _make_pretrained_swinl12_384(pretrained, hooks=None):
+ model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)
+
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks
+ )
diff --git a/midas/backbones/swin2.py b/midas/backbones/swin2.py
new file mode 100644
index 0000000..ce4c8f1
--- /dev/null
+++ b/midas/backbones/swin2.py
@@ -0,0 +1,34 @@
+import timm
+
+from .swin_common import _make_swin_backbone
+
+
+def _make_pretrained_swin2l24_384(pretrained, hooks=None):
+ model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)
+
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks
+ )
+
+
+def _make_pretrained_swin2b24_384(pretrained, hooks=None):
+ model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)
+
+ hooks = [1, 1, 17, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks
+ )
+
+
+def _make_pretrained_swin2t16_256(pretrained, hooks=None):
+ model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)
+
+ hooks = [1, 1, 5, 1] if hooks == None else hooks
+ return _make_swin_backbone(
+ model,
+ hooks=hooks,
+ patch_grid=[64, 64]
+ )
diff --git a/midas/backbones/swin_common.py b/midas/backbones/swin_common.py
new file mode 100644
index 0000000..94d63d4
--- /dev/null
+++ b/midas/backbones/swin_common.py
@@ -0,0 +1,52 @@
+import torch
+
+import torch.nn as nn
+import numpy as np
+
+from .utils import activations, forward_default, get_activation, Transpose
+
+
+def forward_swin(pretrained, x):
+ return forward_default(pretrained, x)
+
+
+def _make_swin_backbone(
+ model,
+ hooks=[1, 1, 17, 1],
+ patch_grid=[96, 96]
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ if hasattr(model, "patch_grid"):
+ used_patch_grid = model.patch_grid
+ else:
+ used_patch_grid = patch_grid
+
+ patch_grid_size = np.array(used_patch_grid, dtype=int)
+
+ pretrained.act_postprocess1 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
+ )
+ pretrained.act_postprocess3 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
+ )
+ pretrained.act_postprocess4 = nn.Sequential(
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
+ )
+
+ return pretrained
diff --git a/midas/backbones/utils.py b/midas/backbones/utils.py
new file mode 100644
index 0000000..0558899
--- /dev/null
+++ b/midas/backbones/utils.py
@@ -0,0 +1,249 @@
+import torch
+
+import torch.nn as nn
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index:]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index:] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
+ features = torch.cat((x[:, self.start_index:], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def forward_default(pretrained, x, function_name="forward_features"):
+ exec(f"pretrained.model.{function_name}(x)")
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ if hasattr(pretrained, "act_postprocess1"):
+ layer_1 = pretrained.act_postprocess1(layer_1)
+ if hasattr(pretrained, "act_postprocess2"):
+ layer_2 = pretrained.act_postprocess2(layer_2)
+ if hasattr(pretrained, "act_postprocess3"):
+ layer_3 = pretrained.act_postprocess3(layer_3)
+ if hasattr(pretrained, "act_postprocess4"):
+ layer_4 = pretrained.act_postprocess4(layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
+ b, c, h, w = x.shape
+
+ exec(f"glob = pretrained.model.{function_name}(x)")
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def make_backbone_default(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+ start_index_readout=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ return pretrained
diff --git a/midas/backbones/vit.py b/midas/backbones/vit.py
new file mode 100644
index 0000000..413f969
--- /dev/null
+++ b/midas/backbones/vit.py
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper,
+ make_backbone_default, Transpose)
+
+
+def forward_vit(pretrained, x):
+ return forward_adapted_unflatten(pretrained, x, "forward_flex")
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index:],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ if self.no_embed_class:
+ x = x + pos_embed
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ if not self.no_embed_class:
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+ start_index_readout=1,
+):
+ pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
+ start_index_readout)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ patch_size=[16, 16],
+ number_stages=2,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ used_number_stages = 0 if use_vit_only else number_stages
+ for s in range(used_number_stages):
+ pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
+ get_activation(str(s + 1))
+ )
+ for s in range(used_number_stages, 4):
+ pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ for s in range(used_number_stages):
+ value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
+ exec(f"pretrained.act_postprocess{s + 1}=value")
+ for s in range(used_number_stages, 4):
+ if s < number_stages:
+ final_layer = nn.ConvTranspose2d(
+ in_channels=features[s],
+ out_channels=features[s],
+ kernel_size=4 // (2 ** s),
+ stride=4 // (2 ** s),
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ )
+ elif s > number_stages:
+ final_layer = nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ )
+ else:
+ final_layer = None
+
+ layers = [
+ readout_oper[s],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[s],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ ]
+ if final_layer is not None:
+ layers.append(final_layer)
+
+ value = nn.Sequential(*layers)
+ exec(f"pretrained.act_postprocess{s + 1}=value")
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = patch_size
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/midas/base_model.py b/midas/base_model.py
new file mode 100644
index 0000000..5cf4302
--- /dev/null
+++ b/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/midas/blocks.py b/midas/blocks.py
new file mode 100644
index 0000000..6d87a00
--- /dev/null
+++ b/midas/blocks.py
@@ -0,0 +1,439 @@
+import torch
+import torch.nn as nn
+
+from .backbones.beit import (
+ _make_pretrained_beitl16_512,
+ _make_pretrained_beitl16_384,
+ _make_pretrained_beitb16_384,
+ forward_beit,
+)
+from .backbones.swin_common import (
+ forward_swin,
+)
+from .backbones.swin2 import (
+ _make_pretrained_swin2l24_384,
+ _make_pretrained_swin2b24_384,
+ _make_pretrained_swin2t16_256,
+)
+from .backbones.swin import (
+ _make_pretrained_swinl12_384,
+)
+from .backbones.levit import (
+ _make_pretrained_levit_384,
+ forward_levit,
+)
+from .backbones.vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
+ use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
+ if backbone == "beitl16_512":
+ pretrained = _make_pretrained_beitl16_512(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # BEiT_512-L (backbone)
+ elif backbone == "beitl16_384":
+ pretrained = _make_pretrained_beitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # BEiT_384-L (backbone)
+ elif backbone == "beitb16_384":
+ pretrained = _make_pretrained_beitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # BEiT_384-B (backbone)
+ elif backbone == "swin2l24_384":
+ pretrained = _make_pretrained_swin2l24_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
+ ) # Swin2-L/12to24 (backbone)
+ elif backbone == "swin2b24_384":
+ pretrained = _make_pretrained_swin2b24_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [128, 256, 512, 1024], features, groups=groups, expand=expand
+ ) # Swin2-B/12to24 (backbone)
+ elif backbone == "swin2t16_256":
+ pretrained = _make_pretrained_swin2t16_256(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # Swin2-T/16 (backbone)
+ elif backbone == "swinl12_384":
+ pretrained = _make_pretrained_swinl12_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
+ ) # Swin-L/12 (backbone)
+ elif backbone == "next_vit_large_6m":
+ from .backbones.next_vit import _make_pretrained_next_vit_large_6m
+ pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
+ scratch = _make_scratch(
+ in_features, features, groups=groups, expand=expand
+ ) # Next-ViT-L on ImageNet-1K-6M (backbone)
+ elif backbone == "levit_384":
+ pretrained = _make_pretrained_levit_384(
+ use_pretrained, hooks=hooks
+ )
+ scratch = _make_scratch(
+ [384, 512, 768], features, groups=groups, expand=expand
+ ) # LeViT 384 (backbone)
+ elif backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-H/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
diff --git a/midas/dpt_depth.py b/midas/dpt_depth.py
new file mode 100644
index 0000000..3129d09
--- /dev/null
+++ b/midas/dpt_depth.py
@@ -0,0 +1,166 @@
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_beit,
+ forward_swin,
+ forward_levit,
+ forward_vit,
+)
+from .backbones.levit import stem_b4_transpose
+from timm.models.layers import get_act_layer
+
+
+def _make_fusion_block(features, use_bn, size = None):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ **kwargs
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
+ # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
+ hooks = {
+ "beitl16_512": [5, 11, 17, 23],
+ "beitl16_384": [5, 11, 17, 23],
+ "beitb16_384": [2, 5, 8, 11],
+ "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
+ "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
+ "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
+ "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
+ "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
+ "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }[backbone]
+
+ if "next_vit" in backbone:
+ in_features = {
+ "next_vit_large_6m": [96, 256, 512, 1024],
+ }[backbone]
+ else:
+ in_features = None
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks,
+ use_readout=readout,
+ in_features=in_features,
+ )
+
+ self.number_layers = len(hooks) if hooks is not None else 4
+ size_refinenet3 = None
+ self.scratch.stem_transpose = None
+
+ if "beit" in backbone:
+ self.forward_transformer = forward_beit
+ elif "swin" in backbone:
+ self.forward_transformer = forward_swin
+ elif "next_vit" in backbone:
+ from .backbones.next_vit import forward_next_vit
+ self.forward_transformer = forward_next_vit
+ elif "levit" in backbone:
+ self.forward_transformer = forward_levit
+ size_refinenet3 = 7
+ self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
+ else:
+ self.forward_transformer = forward_vit
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
+ if self.number_layers >= 4:
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+ x.contiguous(memory_format=torch.channels_last)
+
+ layers = self.forward_transformer(self.pretrained, x)
+ if self.number_layers == 3:
+ layer_1, layer_2, layer_3 = layers
+ else:
+ layer_1, layer_2, layer_3, layer_4 = layers
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ if self.number_layers >= 4:
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ if self.number_layers == 3:
+ path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
+ else:
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ if self.scratch.stem_transpose is not None:
+ path_1 = self.scratch.stem_transpose(path_1)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+ head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
+ head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
+ kwargs.pop("head_features_1", None)
+ kwargs.pop("head_features_2", None)
+
+ head = nn.Sequential(
+ nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
diff --git a/midas/midas_net.py b/midas/midas_net.py
new file mode 100644
index 0000000..8a95497
--- /dev/null
+++ b/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/midas/midas_net_custom.py b/midas/midas_net_custom.py
new file mode 100644
index 0000000..50e4acb
--- /dev/null
+++ b/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ backbone (str, optional): Backbone network for encoder. Defaults to resnet50
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+ x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/midas/model_loader.py b/midas/model_loader.py
new file mode 100644
index 0000000..f1cd1f2
--- /dev/null
+++ b/midas/model_loader.py
@@ -0,0 +1,242 @@
+import cv2
+import torch
+
+from midas.dpt_depth import DPTDepthModel
+from midas.midas_net import MidasNet
+from midas.midas_net_custom import MidasNet_small
+from midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+from torchvision.transforms import Compose
+
+default_models = {
+ "dpt_beit_large_512": "weights/dpt_beit_large_512.pt",
+ "dpt_beit_large_384": "weights/dpt_beit_large_384.pt",
+ "dpt_beit_base_384": "weights/dpt_beit_base_384.pt",
+ "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt",
+ "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt",
+ "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt",
+ "dpt_swin_large_384": "weights/dpt_swin_large_384.pt",
+ "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt",
+ "dpt_levit_224": "weights/dpt_levit_224.pt",
+ "dpt_large_384": "weights/dpt_large_384.pt",
+ "dpt_hybrid_384": "weights/dpt_hybrid_384.pt",
+ "midas_v21_384": "weights/midas_v21_384.pt",
+ "midas_v21_small_256": "weights/midas_v21_small_256.pt",
+ "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml",
+}
+
+
+def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
+ """Load the specified network.
+
+ Args:
+ device (device): the torch device used
+ model_path (str): path to saved model
+ model_type (str): the type of the model to be loaded
+ optimize (bool): optimize the model to half-integer on CUDA?
+ height (int): inference encoder image height
+ square (bool): resize to a square resolution?
+
+ Returns:
+ The loaded network, the transform which prepares images as input to the network and the dimensions of the
+ network input
+ """
+ if "openvino" in model_type:
+ from openvino.runtime import Core
+
+ keep_aspect_ratio = not square
+
+ if model_type == "dpt_beit_large_512":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="beitl16_512",
+ non_negative=True,
+ )
+ net_w, net_h = 512, 512
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_beit_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="beitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_beit_base_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="beitb16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin2_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swin2l24_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin2_base_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swin2b24_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin2_tiny_256":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swin2t16_256",
+ non_negative=True,
+ )
+ net_w, net_h = 256, 256
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_swin_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="swinl12_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_next_vit_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="next_vit_large_6m",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
+ # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
+ # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
+ # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
+ elif model_type == "dpt_levit_224":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="levit_384",
+ non_negative=True,
+ head_features_1=64,
+ head_features_2=8,
+ )
+ net_w, net_h = 224, 224
+ keep_aspect_ratio = False
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_large_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid_384":
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21_384":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small_256":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "openvino_midas_v21_small_256":
+ ie = Core()
+ uncompiled_model = ie.read_model(model=model_path)
+ model = ie.compile_model(uncompiled_model, "CPU")
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ if not "openvino" in model_type:
+ print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
+ else:
+ print("Model loaded, optimized with OpenVINO")
+
+ if "openvino" in model_type:
+ keep_aspect_ratio = False
+
+ if height is not None:
+ net_w, net_h = height, height
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=keep_aspect_ratio,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ if not "openvino" in model_type:
+ model.eval()
+
+ if optimize and (device == torch.device("cuda")):
+ if not "openvino" in model_type:
+ model = model.to(memory_format=torch.channels_last)
+ model = model.half()
+ else:
+ print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
+ exit()
+
+ if not "openvino" in model_type:
+ model.to(device)
+
+ return model, transform, net_w, net_h
diff --git a/midas/transforms.py b/midas/transforms.py
new file mode 100644
index 0000000..350cbc1
--- /dev/null
+++ b/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as least as possbile
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/nerf/clip.py b/nerf/clip.py
new file mode 100644
index 0000000..2895474
--- /dev/null
+++ b/nerf/clip.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+# import clip
+from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer, CLIPProcessor
+from torchvision import transforms
+
+import torch.nn.functional as F
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ # print(x.shape, y.shape)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+class CLIP(nn.Module):
+ def __init__(self, device,
+ # clip_name = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
+ clip_name = 'openai/clip-vit-large-patch14'
+ ):
+ super().__init__()
+
+ self.device = device
+
+ clip_name = clip_name
+
+ # self.feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_name)
+ self.clip_model = CLIPModel.from_pretrained(clip_name).cuda()
+ self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
+ # self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+ # self.normalize = transforms.Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
+
+ # self.resize = transforms.Resize(224)
+
+ # # image augmentation
+ # self.aug = T.Compose([
+ # T.Resize((224, 224)),
+ # T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ # ])
+
+
+ def get_text_embeds(self, prompt, neg_prompt=None, dir=None):
+
+ clip_text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids.cuda()
+ text_z = self.clip_model.get_text_features(clip_text_input)
+ # text = clip.tokenize(prompt).to(self.device)
+ # text_z = self.clip_model.encode_text(text)
+ text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+
+ return text_z
+
+ def set_epoch(self, epoch):
+ pass
+
+ def get_img_embeds(self, img):
+ img = self.aug(img)
+ image_z = self.clip_model.get_image_features(img)
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+ return image_z
+
+
+ # def train_step(self, text_z, pred_rgb, image_ref_clip, **kwargs):
+
+ # pred_rgb = self.resize(pred_rgb)
+ # pred_rgb = self.normalize(pred_rgb)
+
+ # image_z = self.clip_model.get_image_features(pred_rgb)
+ # image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+
+ # # print(image_z.shape, text_z.shape)
+ # loss = spherical_dist_loss(image_z, image_ref_clip)
+
+ # # loss = - (image_z * text_z).sum(-1).mean()
+
+ # return loss
+
+ def train_step(self, text_z, pred_rgb):
+
+ pred_rgb = self.aug(pred_rgb)
+
+ image_z = self.clip_model.encode_image(pred_rgb)
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+
+ loss = - (image_z * text_z).sum(-1).mean()
+ # loss = spherical_dist_loss(image_z, text_z)
+ return loss
+
+ def text_loss(self, text_z, pred_rgb):
+
+ pred_rgb = self.resize(pred_rgb)
+ pred_rgb = self.normalize(pred_rgb)
+
+ image_z = self.clip_model.get_image_features(pred_rgb)
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+
+ # print(image_z.shape, text_z.shape)
+ loss = spherical_dist_loss(image_z, text_z)
+
+ # loss = - (image_z * text_z).sum(-1).mean()
+
+ return loss
+
+ def img_loss(self, img_ref_z, pred_rgb):
+ # pred_rgb = self.aug(pred_rgb)
+ pred_rgb = self.resize(pred_rgb)
+ pred_rgb = self.normalize(pred_rgb)
+
+ image_z = self.clip_model.get_image_features(pred_rgb)
+ image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features
+
+ # loss = - (image_z * img_ref_z).sum(-1).mean()
+ loss = spherical_dist_loss(image_z, img_ref_z)
+
+ return loss
diff --git a/nerf/gui.py b/nerf/gui.py
new file mode 100644
index 0000000..65faa5c
--- /dev/null
+++ b/nerf/gui.py
@@ -0,0 +1,485 @@
+import math
+import torch
+import numpy as np
+import dearpygui.dearpygui as dpg
+from scipy.spatial.transform import Rotation as R
+
+from nerf.utils import *
+
+
+class OrbitCamera:
+ def __init__(self, W, H, r=2, fovy=60):
+ self.W = W
+ self.H = H
+ self.radius = r # camera distance from center
+ self.fovy = fovy # in degree
+ self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
+ self.rot = R.from_matrix(np.eye(3))
+ self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
+ self.near = 0.001
+ self.far = 1000
+
+ # pose
+ @property
+ def pose(self):
+ # first move camera to radius
+ res = np.eye(4, dtype=np.float32)
+ res[2, 3] = self.radius
+ # rotate
+ rot = np.eye(4, dtype=np.float32)
+ rot[:3, :3] = self.rot.as_matrix()
+ res = rot @ res
+ # translate
+ res[:3, 3] -= self.center
+ return res
+
+ # intrinsics
+ @property
+ def intrinsics(self):
+ focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
+ return np.array([focal, focal, self.W // 2, self.H // 2])
+
+ @property
+ def mvp(self):
+ focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2))
+ projection = np.array([
+ [2*focal/self.W, 0, 0, 0],
+ [0, -2*focal/self.H, 0, 0],
+ [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
+ [0, 0, -1, 0]
+ ], dtype=np.float32)
+
+ return projection @ np.linalg.inv(self.pose) # [4, 4]
+
+ def orbit(self, dx, dy):
+ # rotate along camera up/side axis!
+ side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized.
+ rotvec_x = self.up * np.deg2rad(-0.1 * dx)
+ rotvec_y = side * np.deg2rad(-0.1 * dy)
+ self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
+
+ def scale(self, delta):
+ self.radius *= 1.1 ** (-delta)
+
+ def pan(self, dx, dy, dz=0):
+ # pan in camera coordinate system (careful on the sensitivity!)
+ self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz])
+
+
+class NeRFGUI:
+ def __init__(self, opt, trainer, loader=None, debug=True):
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
+ self.W = opt.W
+ self.H = opt.H
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
+ self.debug = debug
+ self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg
+ self.training = False
+ self.step = 0 # training step
+
+ self.trainer = trainer
+ self.loader = loader
+ self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
+ self.need_update = True # camera moved, should reset accumulation
+ self.spp = 1 # sample per pixel
+ self.light_dir = np.array([opt.light_theta, opt.light_phi])
+ self.ambient_ratio = 1.0
+ self.mode = 'image' # choose from ['image', 'depth']
+ self.shading = 'albedo'
+
+ self.dynamic_resolution = True if not self.opt.dmtet else False
+ self.downscale = 1
+ self.train_steps = 16
+
+ dpg.create_context()
+ self.register_dpg()
+ self.test_step()
+
+
+ def __del__(self):
+ dpg.destroy_context()
+
+
+ def train_step(self):
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ outputs = self.trainer.train_gui(self.loader, step=self.train_steps)
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ self.step += self.train_steps
+ self.need_update = True
+
+ dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
+ dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')
+
+ # dynamic train steps
+ # max allowed train time per-frame is 500 ms
+ full_t = t / self.train_steps * 16
+ train_steps = min(16, max(4, int(16 * 500 / full_t)))
+ if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
+ self.train_steps = train_steps
+
+
+ def prepare_buffer(self, outputs):
+ if self.mode == 'image':
+ return outputs['image'].astype(np.float32)
+ else:
+ depth = outputs['depth'].astype(np.float32)
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
+ return np.expand_dims(depth, -1).repeat(3, -1)
+
+
+ def test_step(self):
+
+ if self.need_update or self.spp < self.opt.max_spp:
+
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+ starter.record()
+
+ outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading)
+
+ ender.record()
+ torch.cuda.synchronize()
+ t = starter.elapsed_time(ender)
+
+ # update dynamic resolution
+ if self.dynamic_resolution:
+ # max allowed infer time per-frame is 200 ms
+ full_t = t / (self.downscale ** 2)
+ downscale = min(1, max(1/4, math.sqrt(200 / full_t)))
+ if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8:
+ self.downscale = downscale
+
+ if self.need_update:
+ self.render_buffer = self.prepare_buffer(outputs)
+ self.spp = 1
+ self.need_update = False
+ else:
+ self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
+ self.spp += 1
+
+ dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
+ dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
+ dpg.set_value("_log_spp", self.spp)
+ dpg.set_value("_texture", self.render_buffer)
+
+
+ def register_dpg(self):
+
+ ### register texture
+
+ with dpg.texture_registry(show=False):
+ dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
+
+ ### register window
+
+ # the rendered image, as the primary window
+ with dpg.window(tag="_primary_window", width=self.W, height=self.H):
+
+ # add the texture
+ dpg.add_image("_texture")
+
+ dpg.set_primary_window("_primary_window", True)
+
+ # control window
+ with dpg.window(label="Control", tag="_control_window", width=400, height=300):
+
+ # text prompt
+ if self.opt.text is not None:
+ dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text")
+
+ if self.opt.negative != '':
+ dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text")
+
+ # button theme
+ with dpg.theme() as theme_button:
+ with dpg.theme_component(dpg.mvButton):
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
+
+ # time
+ if not self.opt.test:
+ with dpg.group(horizontal=True):
+ dpg.add_text("Train time: ")
+ dpg.add_text("no data", tag="_log_train_time")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("Infer time: ")
+ dpg.add_text("no data", tag="_log_infer_time")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("SPP: ")
+ dpg.add_text("1", tag="_log_spp")
+
+ # train button
+ if not self.opt.test:
+ with dpg.collapsing_header(label="Train", default_open=True):
+ with dpg.group(horizontal=True):
+ dpg.add_text("Train: ")
+
+ def callback_train(sender, app_data):
+ if self.training:
+ self.training = False
+ dpg.configure_item("_button_train", label="start")
+ else:
+ self.training = True
+ dpg.configure_item("_button_train", label="stop")
+
+ dpg.add_button(label="start", tag="_button_train", callback=callback_train)
+ dpg.bind_item_theme("_button_train", theme_button)
+
+ def callback_reset(sender, app_data):
+ @torch.no_grad()
+ def weight_reset(m: nn.Module):
+ reset_parameters = getattr(m, "reset_parameters", None)
+ if callable(reset_parameters):
+ m.reset_parameters()
+ self.trainer.model.apply(fn=weight_reset)
+ self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter
+ self.need_update = True
+
+ dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset)
+ dpg.bind_item_theme("_button_reset", theme_button)
+
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("Checkpoint: ")
+
+ def callback_save(sender, app_data):
+ self.trainer.save_checkpoint(full=True, best=False)
+ dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1]))
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
+
+ dpg.add_button(label="save", tag="_button_save", callback=callback_save)
+ dpg.bind_item_theme("_button_save", theme_button)
+
+ dpg.add_text("", tag="_log_ckpt")
+
+ # save mesh
+ with dpg.group(horizontal=True):
+ dpg.add_text("Marching Cubes: ")
+
+ def callback_mesh(sender, app_data):
+ self.trainer.save_mesh()
+ dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply')
+ self.trainer.epoch += 1 # use epoch to indicate different calls.
+
+ dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh)
+ dpg.bind_item_theme("_button_mesh", theme_button)
+
+ dpg.add_text("", tag="_log_mesh")
+
+ with dpg.group(horizontal=True):
+ dpg.add_text("", tag="_log_train_log")
+
+
+ # rendering options
+ with dpg.collapsing_header(label="Options", default_open=True):
+
+ # dynamic rendering resolution
+ with dpg.group(horizontal=True):
+
+ def callback_set_dynamic_resolution(sender, app_data):
+ if self.dynamic_resolution:
+ self.dynamic_resolution = False
+ self.downscale = 1
+ else:
+ self.dynamic_resolution = True
+ self.need_update = True
+
+ dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
+ dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")
+
+ # mode combo
+ def callback_change_mode(sender, app_data):
+ self.mode = app_data
+ self.need_update = True
+
+ dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)
+
+ # bg_color picker
+ def callback_change_bg(sender, app_data):
+ self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
+ self.need_update = True
+
+ dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg)
+
+ # fov slider
+ def callback_set_fovy(sender, app_data):
+ self.cam.fovy = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy)
+
+ # dt_gamma slider
+ def callback_set_dt_gamma(sender, app_data):
+ self.opt.dt_gamma = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)
+
+ # max_steps slider
+ def callback_set_max_steps(sender, app_data):
+ self.opt.max_steps = app_data
+ self.need_update = True
+
+ dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)
+
+ # aabb slider
+ def callback_set_aabb(sender, app_data, user_data):
+ # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
+ self.trainer.model.aabb_infer[user_data] = app_data
+
+ # also change train aabb ? [better not...]
+ #self.trainer.model.aabb_train[user_data] = app_data
+
+ self.need_update = True
+
+ dpg.add_separator()
+ dpg.add_text("Axis-aligned bounding box:")
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2)
+ dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5)
+
+ # light dir
+ def callback_set_light_dir(sender, app_data, user_data):
+ self.light_dir[user_data] = app_data
+ self.need_update = True
+
+ dpg.add_separator()
+ dpg.add_text("Plane Light Direction:")
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0)
+
+ with dpg.group(horizontal=True):
+ dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1)
+
+ # ambient ratio
+ def callback_set_abm_ratio(sender, app_data):
+ self.ambient_ratio = app_data
+ self.need_update = True
+
+ dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio)
+
+ # shading mode
+ def callback_change_shading(sender, app_data):
+ self.shading = app_data
+ self.need_update = True
+
+ dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading)
+
+
+ # debug info
+ if self.debug:
+ with dpg.collapsing_header(label="Debug"):
+ # pose
+ dpg.add_separator()
+ dpg.add_text("Camera Pose:")
+ dpg.add_text(str(self.cam.pose), tag="_log_pose")
+
+
+ ### register camera handler
+
+ def callback_camera_drag_rotate(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.orbit(dx, dy)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ def callback_camera_wheel_scale(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ delta = app_data
+
+ self.cam.scale(delta)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ def callback_camera_drag_pan(sender, app_data):
+
+ if not dpg.is_item_focused("_primary_window"):
+ return
+
+ dx = app_data[1]
+ dy = app_data[2]
+
+ self.cam.pan(dx, dy)
+ self.need_update = True
+
+ if self.debug:
+ dpg.set_value("_log_pose", str(self.cam.pose))
+
+
+ with dpg.handler_registry():
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate)
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
+ dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan)
+
+
+ dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False)
+
+ # TODO: seems dearpygui doesn't support resizing texture...
+ # def callback_resize(sender, app_data):
+ # self.W = app_data[0]
+ # self.H = app_data[1]
+ # # how to reload texture ???
+
+ # dpg.set_viewport_resize_callback(callback_resize)
+
+ ### global theme
+ with dpg.theme() as theme_no_padding:
+ with dpg.theme_component(dpg.mvAll):
+ # set all padding to 0 to avoid scroll bar
+ dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
+ dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
+
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
+
+ dpg.setup_dearpygui()
+
+ #dpg.show_metrics()
+
+ dpg.show_viewport()
+
+
+ def render(self):
+
+ while dpg.is_dearpygui_running():
+ # update texture every frame
+ if self.training:
+ self.train_step()
+ self.test_step()
+ dpg.render_dearpygui_frame()
\ No newline at end of file
diff --git a/nerf/network.py b/nerf/network.py
new file mode 100644
index 0000000..685e8f9
--- /dev/null
+++ b/nerf/network.py
@@ -0,0 +1,238 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from activation import trunc_exp
+from .renderer import NeRFRenderer
+
+import numpy as np
+from encoding import get_encoder
+
+from .utils import safe_normalize
+from tqdm import tqdm
+
+
+class ResBlock(nn.Module):
+ def __init__(self, dim_in, dim_out, bias=True):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+
+ self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
+ self.norm = nn.LayerNorm(self.dim_out)
+ self.activation = nn.SiLU(inplace=True)
+
+ if self.dim_in != self.dim_out:
+ self.skip = nn.Linear(self.dim_in, self.dim_out, bias=False)
+ else:
+ self.skip = None
+
+ def forward(self, x):
+ # x: [B, C]
+ identity = x
+
+ out = self.dense(x)
+ out = self.norm(out)
+
+ if self.skip is not None:
+ identity = self.skip(identity)
+
+ out += identity
+ out = self.activation(out)
+
+ return out
+
+class BasicBlock(nn.Module):
+ def __init__(self, dim_in, dim_out, bias=True):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+
+ self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias)
+ self.activation = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ # x: [B, C]
+
+ out = self.dense(x)
+ out = self.activation(out)
+
+ return out
+
+class MLP(nn.Module):
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True, block=BasicBlock):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.dim_hidden = dim_hidden
+ self.num_layers = num_layers
+
+ net = []
+ for l in range(num_layers):
+ if l == 0:
+ net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias))
+ elif l != num_layers - 1:
+ net.append(block(self.dim_hidden, self.dim_hidden, bias=bias))
+ else:
+ net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias))
+
+ self.net = nn.ModuleList(net)
+
+ def reset_parameters(self):
+ @torch.no_grad()
+ def weight_init(m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
+ nn.init.zeros_(m.bias)
+ self.apply(weight_init)
+
+ def forward(self, x):
+
+ for l in range(self.num_layers):
+ x = self.net[l](x)
+
+ return x
+
+
+class NeRFNetwork(NeRFRenderer):
+ def __init__(self,
+ opt,
+ num_layers=5, # 5 in paper
+ hidden_dim=64, # 128 in paper
+ num_layers_bg=2, # 3 in paper
+ hidden_dim_bg=32, # 64 in paper
+ encoding='frequency_torch', # pure pytorch
+ output_dim=4, # 7 for DMTet (sdf 1 + color 3 + deform 3), 4 for NeRF
+ ):
+
+ super().__init__(opt)
+ self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
+ self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
+ num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
+ hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg
+
+ self.encoder, self.in_dim = get_encoder(encoding, input_dim=3, multires=6)
+ self.sigma_net = MLP(self.in_dim, output_dim, hidden_dim, num_layers, bias=True, block=ResBlock)
+
+ self.grid_levels_mask = 0
+
+ # background network
+ if self.opt.bg_radius > 0:
+ self.num_layers_bg = num_layers_bg
+ self.hidden_dim_bg = hidden_dim_bg
+ self.encoder_bg, self.in_dim_bg = get_encoder(encoding, input_dim=3, multires=4)
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+
+ else:
+ self.bg_net = None
+
+ def common_forward(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ # sigma
+ h = self.encoder(x, bound=self.bound, max_level=self.max_level)
+
+ # Feature masking for coarse-to-fine training
+ if self.grid_levels_mask > 0:
+ h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim # (self.in_dim)
+ h_mask = h_mask.reshape(1, self.in_dim).float() # (1, self.in_dim)
+ h = h * h_mask # (N, self.in_dim)
+
+ h = self.sigma_net(h)
+
+ sigma = self.density_activation(h[..., 0] + self.density_blob(x))
+ albedo = torch.sigmoid(h[..., 1:])
+
+ return sigma, albedo
+
+ def normal(self, x):
+
+ with torch.enable_grad():
+ x.requires_grad_(True)
+ sigma, albedo = self.common_forward(x)
+ # query gradient
+ normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
+
+ # normal = self.finite_difference_normal(x)
+ normal = safe_normalize(normal)
+ # normal = torch.nan_to_num(normal)
+
+ return normal
+
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
+ # x: [N, 3], in [-bound, bound]
+ # d: [N, 3], view direction, nomalized in [-1, 1]
+ # l: [3], plane light direction, nomalized in [-1, 1]
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
+
+ if shading == 'albedo':
+ # no need to query normal
+ sigma, color = self.common_forward(x)
+ normal = None
+
+ else:
+ # query normal
+
+ # sigma, albedo = self.common_forward(x)
+ # normal = self.normal(x)
+
+ with torch.enable_grad():
+ x.requires_grad_(True)
+ sigma, albedo = self.common_forward(x)
+ # query gradient
+ normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
+ normal = safe_normalize(normal)
+ # normal = torch.nan_to_num(normal)
+ # normal = normal.detach()
+
+ # lambertian shading
+ lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
+
+ if shading == 'textureless':
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
+ elif shading == 'normal':
+ color = (normal + 1) / 2
+ else: # 'lambertian'
+ color = albedo * lambertian.unsqueeze(-1)
+
+ return sigma, color, normal
+
+
+ def density(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ sigma, albedo = self.common_forward(x)
+
+ return {
+ 'sigma': sigma,
+ 'albedo': albedo,
+ }
+
+
+ def background(self, d):
+
+ h = self.encoder_bg(d) # [N, C]
+
+ h = self.bg_net(h)
+
+ # sigmoid activation for rgb
+ rgbs = torch.sigmoid(h)
+
+ return rgbs
+
+ # optimizer utils
+ def get_params(self, lr):
+
+ params = [
+ # {'params': self.encoder.parameters(), 'lr': lr * 10},
+ {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
+ ]
+
+ if self.opt.bg_radius > 0:
+ # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
+
+ if self.opt.dmtet:
+ params.append({'params': self.dmtet.parameters(), 'lr': lr})
+
+ return params
\ No newline at end of file
diff --git a/nerf/network_grid.py b/nerf/network_grid.py
new file mode 100644
index 0000000..3ad56d1
--- /dev/null
+++ b/nerf/network_grid.py
@@ -0,0 +1,216 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from activation import trunc_exp, biased_softplus
+from .renderer import NeRFRenderer, MLP
+
+import numpy as np
+from encoding import get_encoder
+
+from .utils import safe_normalize
+from tqdm import tqdm
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class NeRFNetwork(NeRFRenderer):
+ def __init__(self,
+ opt,
+ num_layers=3,
+ hidden_dim=64,
+ num_layers_bg=2,
+ hidden_dim_bg=32,
+ level_dim=2
+ ):
+
+ super().__init__(opt)
+
+ self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
+ self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
+ self.level_dim = opt.level_dim if hasattr(opt, 'level_dim') else level_dim
+ num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
+ hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg
+
+ if self.opt.grid_type == 'hashgrid':
+ self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
+ elif self.opt.grid_type == 'tilegrid':
+ self.encoder, self.in_dim = get_encoder(
+ 'tiledgrid',
+ input_dim=3,
+ level_dim=self.level_dim,
+ log2_hashmap_size=16,
+ num_levels=16,
+ desired_resolution= 2048 * self.bound,
+ )
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
+ # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)
+
+ # masking
+ self.grid_levels_mask = 0
+
+ # background network
+ if self.opt.bg_radius > 0:
+ self.num_layers_bg = num_layers_bg
+ self.hidden_dim_bg = hidden_dim_bg
+
+ # use a very simple network to avoid it learning the prompt...
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+
+ else:
+ self.bg_net = None
+
+ def common_forward(self, x):
+
+ # sigma
+ h = self.encoder(x, bound=self.bound, max_level=self.max_level)
+
+ # Feature masking for coarse-to-fine training
+ if self.grid_levels_mask > 0:
+ h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim # (self.in_dim)
+ h_mask = h_mask.reshape(1, self.in_dim).float() # (1, self.in_dim)
+ h = h * h_mask # (N, self.in_dim)
+
+ h = self.sigma_net(h)
+
+ sigma = self.density_activation(h[..., 0] + self.density_blob(x))
+ albedo = torch.sigmoid(h[..., 1:])
+
+ return sigma, albedo
+
+
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
+ # x: [N, 3], in [-bound, bound]
+ # d: [N, 3], view direction, nomalized in [-1, 1]
+ # l: [3], plane light direction, nomalized in [-1, 1]
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
+
+ sigma, albedo = self.common_forward(x)
+
+ if shading == 'albedo':
+ normal = None
+ color = albedo
+
+ else: # lambertian shading
+
+ normal = self.normal(x)
+ if shading == 'normal':
+ color = (normal + 1) / 2
+ else:
+ lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
+ if shading == 'textureless':
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
+ else: # 'lambertian'
+ color = albedo * lambertian.unsqueeze(-1)
+
+ return sigma, color, normal
+
+
+ def density(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ sigma, albedo = self.common_forward(x)
+
+ return {
+ 'sigma': sigma,
+ 'albedo': albedo,
+ }
+
+
+ def background(self, d):
+
+ h = self.encoder_bg(d) # [N, C]
+
+ h = self.bg_net(h)
+
+ # sigmoid activation for rgb
+ rgbs = torch.sigmoid(h)
+
+ return rgbs
+
+ # optimizer utils
+ def get_params(self, lr):
+
+ params = [
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
+ {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
+ # {'params': self.normal_net.parameters(), 'lr': lr},
+ ]
+
+ if self.opt.bg_radius > 0:
+ # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
+
+ if self.opt.dmtet:
+ params.append({'params': self.dmtet.parameters(), 'lr': lr})
+
+ return params
+
+ def reset_sigmanet(self):
+ self.sigma_net.reset_parameters()
+
+ def init_nerf_from_sdf_color(self, rpst, albedo,
+ points=None, pretrain_iters=10000, lr=0.001, rpst_type='sdf',
+ ):
+ self.reset_sigmanet()
+ # matching optimization
+ self.train()
+ self.grid_levels_mask = 0
+ loss_fn = torch.nn.MSELoss()
+ optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
+
+ milestones = [int(0.4 * pretrain_iters), int(0.8 * pretrain_iters)]
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
+
+ rpst = rpst.squeeze().clamp(0, 1)
+
+ # rpst = torch.ones_like(rpst) * 0.4
+ pbar = tqdm(range(pretrain_iters), desc="NeRF sigma optimization")
+ rgb_loss = torch.tensor(0, device=rpst.device)
+ for i in pbar:
+ output = self.density(points)
+ if rpst_type == 'sdf':
+ pred_rpst = output['sigma'] - self.density_thresh
+ else:
+ pred_rpst = output['sigma']
+ sdf_loss = loss_fn(pred_rpst, rpst)
+
+ if albedo is not None:
+ pred_albedo = output['albedo']
+ rgb_loss = loss_fn(pred_albedo, albedo)
+ loss = 10 * sdf_loss + rgb_loss
+ else:
+ loss = sdf_loss
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+ pbar.set_postfix(loss=loss.item(), rgb_loss=rgb_loss.item(), sdf_loss=sdf_loss.item())
+ logger.info(f'lr: {lr} Accuracy: (pred_rpst[rpst>0]>0).sum() / (rpst>0).sum()')
+ pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")
+
+
+ def init_tet_from_sdf_color(self, sdf, colors=None, pretrain_iters=5000, lr=0.01):
+ self.train()
+ self.grid_levels_mask = 0
+
+ self.dmtet.reset_tet(reset_scale=False)
+ self.dmtet.init_tet_from_sdf(sdf, pretrain_iters=pretrain_iters, lr=lr)
+
+ if colors is not None:
+ self.reset_sigmanet()
+ loss_fn = torch.nn.MSELoss()
+ pretrain_iters = 5000
+ optimizer = torch.optim.Adam(list(self.parameters()), lr=0.01)
+ pbar = tqdm(range(pretrain_iters), desc="NeRF color optimization")
+ for i in pbar:
+ pred_albedo = self.density(self.dmtet.verts)['albedo']
+ loss = loss_fn(pred_albedo, colors)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ pbar.set_postfix(loss=loss.item())
diff --git a/nerf/network_grid_taichi.py b/nerf/network_grid_taichi.py
new file mode 100644
index 0000000..586faff
--- /dev/null
+++ b/nerf/network_grid_taichi.py
@@ -0,0 +1,161 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from activation import trunc_exp
+from .renderer import NeRFRenderer
+
+import numpy as np
+from encoding import get_encoder
+
+from .utils import safe_normalize
+from tqdm import tqdm
+
+
+class MLP(nn.Module):
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.dim_hidden = dim_hidden
+ self.num_layers = num_layers
+
+ net = []
+ for l in range(num_layers):
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
+
+ self.net = nn.ModuleList(net)
+
+ def forward(self, x):
+ for l in range(self.num_layers):
+ x = self.net[l](x)
+ if l != self.num_layers - 1:
+ x = F.relu(x, inplace=True)
+ return x
+
+ def reset_parameters(self):
+ @torch.no_grad()
+ def weight_init(m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
+ nn.init.zeros_(m.bias)
+ self.apply(weight_init)
+
+
+class NeRFNetwork(NeRFRenderer):
+ def __init__(self,
+ opt,
+ num_layers=2,
+ hidden_dim=32,
+ num_layers_bg=2,
+ hidden_dim_bg=16,
+ ):
+
+ super().__init__(opt)
+ self.num_layers = opt.num_layers if hasattr(opt, 'num_layers') else num_layers
+ self.hidden_dim = opt.hidden_dim if hasattr(opt, 'hidden_dim') else hidden_dim
+ num_layers_bg = opt.num_layers_bg if hasattr(opt, 'num_layers_bg') else num_layers_bg
+ hidden_dim_bg = opt.hidden_dim_bg if hasattr(opt, 'hidden_dim_bg') else hidden_dim_bg
+
+ self.encoder, self.in_dim = get_encoder('hashgrid_taichi', input_dim=3, log2_hashmap_size=19, desired_resolution=2048 * self.bound, interpolation='smoothstep')
+
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
+ # self.normal_net = MLP(self.in_dim, 3, hidden_dim, num_layers, bias=True)
+
+ self.grid_levels_mask = 0
+
+ # background network
+ if self.opt.bg_radius > 0:
+ self.num_layers_bg = num_layers_bg
+ self.hidden_dim_bg = hidden_dim_bg
+ # use a very simple network to avoid it learning the prompt...
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency_torch', input_dim=3, multires=4) # TODO: freq encoder can be replaced by a Taichi implementation
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+
+ else:
+ self.bg_net = None
+
+ def common_forward(self, x):
+
+ # sigma
+ h = self.encoder(x, bound=self.bound)
+
+ # Feature masking for coarse-to-fine training
+ if self.grid_levels_mask > 0:
+ h_mask: torch.Tensor = torch.arange(self.in_dim, device=h.device) < self.in_dim - self.grid_levels_mask * self.level_dim # (self.in_dim)
+ h_mask = h_mask.reshape(1, self.in_dim).float() # (1, self.in_dim)
+ h = h * h_mask # (N, self.in_dim)
+
+ h = self.sigma_net(h)
+
+ sigma = self.density_activation(h[..., 0] + self.density_blob(x))
+ albedo = torch.sigmoid(h[..., 1:])
+
+ return sigma, albedo
+
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
+ # x: [N, 3], in [-bound, bound]
+ # d: [N, 3], view direction, nomalized in [-1, 1]
+ # l: [3], plane light direction, nomalized in [-1, 1]
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
+
+ sigma, albedo = self.common_forward(x)
+
+ if shading == 'albedo':
+ normal = None
+ color = albedo
+
+ else: # lambertian shading
+ normal = self.normal(x)
+
+ lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
+
+ if shading == 'textureless':
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
+ elif shading == 'normal':
+ color = (normal + 1) / 2
+ else: # 'lambertian'
+ color = albedo * lambertian.unsqueeze(-1)
+
+ return sigma, color, normal
+
+
+ def density(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ sigma, albedo = self.common_forward(x)
+
+ return {
+ 'sigma': sigma,
+ 'albedo': albedo,
+ }
+
+
+ def background(self, d):
+
+ h = self.encoder_bg(d) # [N, C]
+
+ h = self.bg_net(h)
+
+ # sigmoid activation for rgb
+ rgbs = torch.sigmoid(h)
+
+ return rgbs
+
+ # optimizer utils
+ def get_params(self, lr):
+
+ params = [
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
+ {'params': self.sigma_net.parameters(), 'lr': lr * self.opt.lr_scale_nerf},
+ # {'params': self.normal_net.parameters(), 'lr': lr},
+ ]
+
+ if self.opt.bg_radius > 0:
+ # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10})
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
+
+ if self.opt.dmtet:
+ params.append({'params': self.dmtet.parameters(), 'lr': lr})
+
+ return params
diff --git a/nerf/network_grid_tcnn.py b/nerf/network_grid_tcnn.py
new file mode 100644
index 0000000..22ae1ff
--- /dev/null
+++ b/nerf/network_grid_tcnn.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from activation import trunc_exp, biased_softplus
+from .renderer import NeRFRenderer
+
+import numpy as np
+from encoding import get_encoder
+
+from .utils import safe_normalize
+
+import tinycudann as tcnn
+
+class MLP(nn.Module):
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.dim_hidden = dim_hidden
+ self.num_layers = num_layers
+
+ net = []
+ for l in range(num_layers):
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
+
+ self.net = nn.ModuleList(net)
+
+ def forward(self, x):
+ for l in range(self.num_layers):
+ x = self.net[l](x)
+ if l != self.num_layers - 1:
+ x = F.relu(x, inplace=True)
+ return x
+
+
+class NeRFNetwork(NeRFRenderer):
+ def __init__(self,
+ opt,
+ num_layers=3,
+ hidden_dim=64,
+ num_layers_bg=2,
+ hidden_dim_bg=32,
+ ):
+
+ super().__init__(opt)
+
+ self.num_layers = num_layers
+ self.hidden_dim = hidden_dim
+
+ self.encoder = tcnn.Encoding(
+ n_input_dims=3,
+ encoding_config={
+ "otype": "HashGrid",
+ "n_levels": 16,
+ "n_features_per_level": 2,
+ "log2_hashmap_size": 19,
+ "base_resolution": 16,
+ "interpolation": "Smoothstep",
+ "per_level_scale": np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)),
+ },
+ dtype=torch.float32, # ENHANCE: default float16 seems unstable...
+ )
+ self.in_dim = self.encoder.n_output_dims
+ # use torch MLP, as tcnn MLP doesn't impl second-order derivative
+ self.sigma_net = MLP(self.in_dim, 4, hidden_dim, num_layers, bias=True)
+
+ self.density_activation = trunc_exp if self.opt.density_activation == 'exp' else biased_softplus
+
+ # background network
+ if self.opt.bg_radius > 0:
+ self.num_layers_bg = num_layers_bg
+ self.hidden_dim_bg = hidden_dim_bg
+
+ # use a very simple network to avoid it learning the prompt...
+ self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=6)
+ self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True)
+
+ else:
+ self.bg_net = None
+
+ def common_forward(self, x):
+
+ # sigma
+ enc = self.encoder((x + self.bound) / (2 * self.bound)).float()
+ h = self.sigma_net(enc)
+
+ sigma = self.density_activation(h[..., 0] + self.density_blob(x))
+ albedo = torch.sigmoid(h[..., 1:])
+
+ return sigma, albedo
+
+ def normal(self, x):
+
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x.requires_grad_(True)
+ sigma, albedo = self.common_forward(x)
+ # query gradient
+ normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
+
+ # normal = self.finite_difference_normal(x)
+ normal = safe_normalize(normal)
+ normal = torch.nan_to_num(normal)
+
+ return normal
+
+ def forward(self, x, d, l=None, ratio=1, shading='albedo'):
+ # x: [N, 3], in [-bound, bound]
+ # d: [N, 3], view direction, nomalized in [-1, 1]
+ # l: [3], plane light direction, nomalized in [-1, 1]
+ # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless)
+
+
+ if shading == 'albedo':
+ sigma, albedo = self.common_forward(x)
+ normal = None
+ color = albedo
+
+ else: # lambertian shading
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x.requires_grad_(True)
+ sigma, albedo = self.common_forward(x)
+ normal = - torch.autograd.grad(torch.sum(sigma), x, create_graph=True)[0] # [N, 3]
+ normal = safe_normalize(normal)
+ normal = torch.nan_to_num(normal)
+
+ lambertian = ratio + (1 - ratio) * (normal * l).sum(-1).clamp(min=0) # [N,]
+
+ if shading == 'textureless':
+ color = lambertian.unsqueeze(-1).repeat(1, 3)
+ elif shading == 'normal':
+ color = (normal + 1) / 2
+ else: # 'lambertian'
+ color = albedo * lambertian.unsqueeze(-1)
+
+ return sigma, color, normal
+
+
+ def density(self, x):
+ # x: [N, 3], in [-bound, bound]
+
+ sigma, albedo = self.common_forward(x)
+
+ return {
+ 'sigma': sigma,
+ 'albedo': albedo,
+ }
+
+
+ def background(self, d):
+
+ h = self.encoder_bg(d) # [N, C]
+
+ h = self.bg_net(h)
+
+ # sigmoid activation for rgb
+ rgbs = torch.sigmoid(h)
+
+ return rgbs
+
+ # optimizer utils
+ def get_params(self, lr):
+
+ params = [
+ {'params': self.encoder.parameters(), 'lr': lr * 10},
+ {'params': self.sigma_net.parameters(), 'lr': lr},
+ ]
+
+ if self.opt.bg_radius > 0:
+ params.append({'params': self.bg_net.parameters(), 'lr': lr})
+
+ if self.opt.dmtet:
+ params.append({'params': self.sdf, 'lr': lr})
+ params.append({'params': self.deform, 'lr': lr})
+
+ return params
\ No newline at end of file
diff --git a/nerf/provider.py b/nerf/provider.py
new file mode 100644
index 0000000..219b6ec
--- /dev/null
+++ b/nerf/provider.py
@@ -0,0 +1,329 @@
+import random
+import numpy as np
+from scipy.spatial.transform import Slerp, Rotation
+
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from .utils import get_rays, safe_normalize
+
+DIR_COLORS = np.array([
+ [255, 0, 0, 255], # front
+ [0, 255, 0, 255], # side
+ [0, 0, 255, 255], # back
+ [255, 255, 0, 255], # side
+ [255, 0, 255, 255], # overhead
+ [0, 255, 255, 255], # bottom
+], dtype=np.uint8)
+
+def visualize_poses(poses, dirs, size=0.1):
+ # poses: [B, 4, 4], dirs: [B]
+ import trimesh
+ axes = trimesh.creation.axis(axis_length=4)
+ sphere = trimesh.creation.icosphere(radius=1)
+ objects = [axes, sphere]
+
+ for pose, dir in zip(poses, dirs):
+ # a camera is visualized with 8 line segments.
+ pos = pose[:3, 3]
+ a = pos + size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]
+ b = pos - size * pose[:3, 0] + size * pose[:3, 1] - size * pose[:3, 2]
+ c = pos - size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]
+ d = pos + size * pose[:3, 0] - size * pose[:3, 1] - size * pose[:3, 2]
+
+ segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a]])
+ segs = trimesh.load_path(segs)
+
+ # different color for different dirs
+ segs.colors = DIR_COLORS[[dir]].repeat(len(segs.entities), 0)
+
+ objects.append(segs)
+
+ trimesh.Scene(objects).show()
+
+def get_view_direction(thetas, phis, overhead, front):
+ # phis [B,]; thetas: [B,]
+ # front = 0 [0, front)
+ # side (right) = 1 [front, 180)
+ # back = 2 [180, 180+front)
+ # side (left) = 3 [180+front, 360)
+ # top = 4 [0, overhead]
+ # bottom = 5 [180-overhead, 180]
+ res = torch.zeros(thetas.shape[0], dtype=torch.long)
+ # first determine by phis
+ phis = phis % (2 * np.pi)
+ res[(phis < front / 2) | (phis >= 2 * np.pi - front / 2)] = 0
+ res[(phis >= front / 2) & (phis < np.pi - front / 2)] = 1
+ res[(phis >= np.pi - front / 2) & (phis < np.pi + front / 2)] = 2
+ res[(phis >= np.pi + front / 2) & (phis < 2 * np.pi - front / 2)] = 3
+ # override by thetas
+ res[thetas <= overhead] = 4
+ res[thetas >= (np.pi - overhead)] = 5
+ return res
+
+
+def rand_poses(size, device, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5):
+ ''' generate random poses from an orbit camera
+ Args:
+ size: batch size of generated poses.
+ device: where to allocate the output.
+ radius: camera radius
+ theta_range: [min, max], should be in [0, pi]
+ phi_range: [min, max], should be in [0, 2 * pi]
+ Return:
+ poses: [size, 4, 4]
+ '''
+
+ theta_range = np.array(theta_range) / 180 * np.pi
+ phi_range = np.array(phi_range) / 180 * np.pi
+ angle_overhead = angle_overhead / 180 * np.pi
+ angle_front = angle_front / 180 * np.pi
+
+ radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
+
+ if random.random() < uniform_sphere_rate:
+ unit_centers = F.normalize(
+ torch.stack([
+ (torch.rand(size, device=device) - 0.5) * 2.0,
+ torch.rand(size, device=device),
+ (torch.rand(size, device=device) - 0.5) * 2.0,
+ ], dim=-1), p=2, dim=1
+ )
+ thetas = torch.acos(unit_centers[:,1])
+ phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
+ phis[phis < 0] += 2 * np.pi
+ centers = unit_centers * radius.unsqueeze(-1)
+ else:
+ thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
+ phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
+ phis[phis < 0] += 2 * np.pi
+
+ centers = torch.stack([
+ radius * torch.sin(thetas) * torch.sin(phis),
+ radius * torch.cos(thetas),
+ radius * torch.sin(thetas) * torch.cos(phis),
+ ], dim=-1) # [B, 3]
+
+ targets = 0
+
+ # jitters
+ if opt.jitter_pose:
+ jit_center = opt.jitter_center # 0.015 # was 0.2
+ jit_target = opt.jitter_target
+ centers += torch.rand_like(centers) * jit_center - jit_center/2.0
+ targets += torch.randn_like(centers) * jit_target
+
+ # lookat
+ forward_vector = safe_normalize(centers - targets)
+ up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
+
+ if opt.jitter_pose:
+ up_noise = torch.randn_like(up_vector) * opt.jitter_up
+ else:
+ up_noise = 0
+
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
+
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+ poses[:, :3, 3] = centers
+
+ if return_dirs:
+ dirs = get_view_direction(thetas, phis, angle_overhead, angle_front)
+ else:
+ dirs = None
+
+ # back to degree
+ thetas = thetas / np.pi * 180
+ phis = phis / np.pi * 180
+
+ return poses, dirs, thetas, phis, radius
+
+
+def circle_poses(device, radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), return_dirs=False, angle_overhead=30, angle_front=60):
+
+ theta = theta / 180 * np.pi
+ phi = phi / 180 * np.pi
+ angle_overhead = angle_overhead / 180 * np.pi
+ angle_front = angle_front / 180 * np.pi
+
+ centers = torch.stack([
+ radius * torch.sin(theta) * torch.sin(phi),
+ radius * torch.cos(theta),
+ radius * torch.sin(theta) * torch.cos(phi),
+ ], dim=-1) # [B, 3]
+
+ # lookat
+ forward_vector = safe_normalize(centers)
+ up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(len(centers), 1)
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
+
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(len(centers), 1, 1)
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
+ poses[:, :3, 3] = centers
+
+ if return_dirs:
+ dirs = get_view_direction(theta, phi, angle_overhead, angle_front)
+ else:
+ dirs = None
+
+ return poses, dirs
+
+
+class NeRFDataset:
+ def __init__(self, opt, device, type='train', H=256, W=256, size=100):
+ super().__init__()
+
+ self.opt = opt
+ self.device = device
+ self.type = type # train, val, test
+
+ self.H = H
+ self.W = W
+ self.size = size
+
+ self.training = self.type in ['train', 'all']
+
+ self.cx = self.H / 2
+ self.cy = self.W / 2
+
+ self.near = self.opt.min_near
+ self.far = 1000 # infinite
+
+ # [debug] visualize poses
+ # poses, dirs, _, _, _ = rand_poses(100, self.device, opt, radius_range=self.opt.radius_range, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose, uniform_sphere_rate=1)
+ # visualize_poses(poses.detach().cpu().numpy(), dirs.detach().cpu().numpy())
+
+ def get_default_view_data(self):
+
+ H = int(self.opt.known_view_scale * self.H)
+ W = int(self.opt.known_view_scale * self.W)
+ cx = H / 2
+ cy = W / 2
+
+ radii = torch.FloatTensor(self.opt.ref_radii).to(self.device)
+ thetas = torch.FloatTensor(self.opt.ref_polars).to(self.device)
+ phis = torch.FloatTensor(self.opt.ref_azimuths).to(self.device)
+ poses, dirs = circle_poses(self.device, radius=radii, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
+ fov = self.opt.default_fovy
+ focal = H / (2 * np.tan(np.deg2rad(fov) / 2))
+ intrinsics = np.array([focal, focal, cx, cy])
+
+ projection = torch.tensor([
+ [2*focal/W, 0, 0, 0],
+ [0, -2*focal/H, 0, 0],
+ [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
+ [0, 0, -1, 0]
+ ], dtype=torch.float32, device=self.device).unsqueeze(0).repeat(len(radii), 1, 1)
+
+ mvp = projection @ torch.inverse(poses) # [B, 4, 4]
+
+ # sample a low-resolution but full image
+ rays = get_rays(poses, intrinsics, H, W, -1)
+
+ data = {
+ 'H': H,
+ 'W': W,
+ 'rays_o': rays['rays_o'],
+ 'rays_d': rays['rays_d'],
+ 'dir': dirs,
+ 'mvp': mvp,
+ 'polar': self.opt.ref_polars,
+ 'azimuth': self.opt.ref_azimuths,
+ 'radius': self.opt.ref_radii,
+ }
+
+ return data
+
+ def collate(self, index):
+
+ B = len(index)
+
+ if self.training:
+ # random pose on the fly
+ poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate)
+
+ # random focal
+ fov = random.random() * (self.opt.fovy_range[1] - self.opt.fovy_range[0]) + self.opt.fovy_range[0]
+
+ elif self.type == 'six_views':
+ # six views
+ thetas_six = [90]*4 + [1e-6] + [180]
+ phis_six = [0, 90, 180, -90, 0, 0]
+ thetas = torch.FloatTensor([thetas_six[index[0]]]).to(self.device)
+ phis = torch.FloatTensor([phis_six[index[0]]]).to(self.device)
+ radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
+ poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
+
+ # fixed focal
+ fov = self.opt.default_fovy
+
+ else:
+ # circle pose
+ thetas = torch.FloatTensor([self.opt.default_polar]).to(self.device)
+ phis = torch.FloatTensor([(index[0] / self.size) * 360]).to(self.device)
+ radius = torch.FloatTensor([self.opt.default_radius]).to(self.device)
+ poses, dirs = circle_poses(self.device, radius=radius, theta=thetas, phi=phis, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
+
+ # fixed focal
+ fov = self.opt.default_fovy
+
+ focal = self.H / (2 * np.tan(np.deg2rad(fov) / 2))
+ intrinsics = np.array([focal, focal, self.cx, self.cy])
+
+ projection = torch.tensor([
+ [2*focal/self.W, 0, 0, 0],
+ [0, -2*focal/self.H, 0, 0],
+ [0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)],
+ [0, 0, -1, 0]
+ ], dtype=torch.float32, device=self.device).unsqueeze(0)
+
+ mvp = projection @ torch.inverse(poses) # [1, 4, 4]
+
+ # sample a low-resolution but full image
+ rays = get_rays(poses, intrinsics, self.H, self.W, -1)
+
+ # delta polar/azimuth/radius to default view
+ delta_polar = thetas - self.opt.default_polar
+ delta_azimuth = phis - self.opt.default_azimuth
+ delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
+ delta_radius = radius - self.opt.default_radius
+
+ data = {
+ 'H': self.H,
+ 'W': self.W,
+ 'rays_o': rays['rays_o'],
+ 'rays_d': rays['rays_d'],
+ 'dir': dirs,
+ 'mvp': mvp,
+ 'polar': delta_polar,
+ 'azimuth': delta_azimuth,
+ 'radius': delta_radius,
+ }
+
+ return data
+
+ def dataloader(self, batch_size=None):
+ batch_size = batch_size or self.opt.batch_size
+ loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0)
+ loader._data = self
+ return loader
+
+
+def generate_grid_points(resolution=128, device='cuda'):
+ # resolution: number of points along each dimension
+ # Generate the grid points
+ x = torch.linspace(0, 1, resolution)
+ y = torch.linspace(0, 1, resolution)
+ z = torch.linspace(0, 1, resolution)
+ # Create the meshgrid
+ grid_x, grid_y, grid_z = torch.meshgrid(x, y, z)
+
+ # Flatten the grid points if needed
+ grid_points = torch.stack((grid_x.flatten(), grid_y.flatten(), grid_z.flatten()), dim=1).to(device)
+ return grid_points
+
diff --git a/nerf/renderer.py b/nerf/renderer.py
new file mode 100644
index 0000000..1fb9475
--- /dev/null
+++ b/nerf/renderer.py
@@ -0,0 +1,1575 @@
+import os
+import math
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from encoding import get_encoder
+import nvdiffrast.torch as dr
+
+import mcubes
+import raymarching
+from .utils import custom_meshgrid, safe_normalize
+import logging
+from activation import trunc_exp, biased_softplus
+
+
+logger = logging.getLogger(__name__)
+
+
+class MLP(nn.Module):
+ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+ super().__init__()
+ self.dim_in = dim_in
+ self.dim_out = dim_out
+ self.dim_hidden = dim_hidden
+ self.num_layers = num_layers
+
+ net = []
+ for l in range(num_layers):
+ net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden,
+ self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
+
+ self.net = nn.ModuleList(net)
+
+ def forward(self, x):
+ for l in range(self.num_layers):
+ x = self.net[l](x)
+ if l != self.num_layers - 1:
+ x = F.relu(x, inplace=True)
+ return x
+
+ def reset_parameters(self):
+ @torch.no_grad()
+ def weight_init(m):
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(
+ m.weight, gain=nn.init.calculate_gain('relu'))
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ self.apply(weight_init)
+
+
+def sample_pdf(bins, weights, n_samples, det=False):
+ # This implementation is from NeRF
+ # bins: [B, T], old_z_vals
+ # weights: [B, T - 1], bin weights.
+ # return: [B, n_samples], new_z_vals
+
+ # Get pdf
+ weights = weights + 1e-5 # prevent nans
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
+ cdf = torch.cumsum(pdf, -1)
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
+ # Take uniform samples
+ if det:
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
+ else:
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
+
+ # Invert CDF
+ u = u.contiguous()
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
+ inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
+
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
+
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+ t = (u - cdf_g[..., 0]) / denom
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
+
+ return samples
+
+@torch.cuda.amp.autocast(enabled=False)
+def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05):
+ # rays: [B, N, 3], [B, N, 3]
+ # bound: int, radius for ball or half-edge-length for cube
+ # return near [B, N, 1], far [B, N, 1]
+
+ radius = rays_o.norm(dim=-1, keepdim=True)
+
+ if type == 'sphere':
+ near = radius - bound # [B, N, 1]
+ far = radius + bound
+
+ elif type == 'cube':
+ tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3]
+ tmax = (bound - rays_o) / (rays_d + 1e-15)
+ near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
+ far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
+ # if far < near, means no intersection, set both near and far to inf (1e9 here)
+ mask = far < near
+ near[mask] = 1e9
+ far[mask] = 1e9
+ # restrict near to a minimal value
+ near = torch.clamp(near, min=min_near)
+
+ return near, far
+
+
+def plot_pointcloud(pc, color=None):
+ import trimesh
+ # pc: [N, 3]
+ # color: [N, 3/4]
+ logger.info('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
+ pc = trimesh.PointCloud(pc, color)
+ # axis
+ axes = trimesh.creation.axis(axis_length=4)
+ # sphere
+ sphere = trimesh.creation.icosphere(radius=1)
+ trimesh.Scene([pc, axes, sphere]).show()
+
+
+class DMTet:
+ def __init__(self, device='cuda'):
+ self.device = device
+ self.triangle_table = torch.tensor([
+ [-1, -1, -1, -1, -1, -1],
+ [1, 0, 2, -1, -1, -1],
+ [4, 0, 3, -1, -1, -1],
+ [1, 4, 2, 1, 3, 4],
+ [3, 1, 5, -1, -1, -1],
+ [2, 3, 0, 2, 5, 3],
+ [1, 4, 0, 1, 5, 4],
+ [4, 2, 5, -1, -1, -1],
+ [4, 5, 2, -1, -1, -1],
+ [4, 1, 0, 4, 5, 1],
+ [3, 2, 0, 3, 5, 2],
+ [1, 3, 5, -1, -1, -1],
+ [4, 1, 2, 4, 3, 1],
+ [3, 0, 4, -1, -1, -1],
+ [2, 0, 1, -1, -1, -1],
+ [-1, -1, -1, -1, -1, -1]
+ ], dtype=torch.long, device=self.device)
+
+ self.num_triangles_table = torch.tensor(
+ [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=self.device)
+ self.base_tet_edges = torch.tensor(
+ [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
+
+ ###############################################################################
+ # Utility functions
+ ###############################################################################
+
+ def sort_edges(self, edges_ex2):
+ with torch.no_grad():
+ order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
+ order = order.unsqueeze(dim=1)
+
+ a = torch.gather(input=edges_ex2, index=order, dim=1)
+ b = torch.gather(input=edges_ex2, index=1-order, dim=1)
+
+ return torch.stack([a, b], -1)
+
+ def map_uv(self, faces, face_gidx, max_idx):
+ N = int(np.ceil(np.sqrt((max_idx+1)//2)))
+ tex_y, tex_x = torch.meshgrid(
+ torch.linspace(0, 1 - (1 / N), N,
+ dtype=torch.float32, device=self.device),
+ torch.linspace(0, 1 - (1 / N), N,
+ dtype=torch.float32, device=self.device),
+ ) # indexing='ij')
+
+ pad = 0.9 / N
+
+ uvs = torch.stack([
+ tex_x, tex_y,
+ tex_x + pad, tex_y,
+ tex_x + pad, tex_y + pad,
+ tex_x, tex_y + pad
+ ], dim=-1).view(-1, 2)
+
+ def _idx(tet_idx, N):
+ x = tet_idx % N
+ y = torch.div(tet_idx, N, rounding_mode='trunc')
+ return y * N + x
+
+ tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
+ tri_idx = face_gidx % 2
+
+ uv_idx = torch.stack((
+ tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
+ ), dim=-1). view(-1, 3)
+
+ return uvs, uv_idx
+
+ ###############################################################################
+ # Marching tets implementation
+ ###############################################################################
+
+ def __call__(self, pos_nx3, sdf_n, tet_fx4, return_uv=True):
+ # pos_nx3: [N, 3]
+ # sdf_n: [N]
+ # tet_fx4: [F, 4]
+
+ with torch.no_grad():
+ occ_n = sdf_n > 0
+ occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
+ occ_sum = torch.sum(occ_fx4, -1) # [F,]
+
+ # a valid tets not all positive (out space) and not all negative (inner)
+ valid_tets = (occ_sum > 0) & (occ_sum < 4)
+ occ_sum = occ_sum[valid_tets]
+
+ # find all vertices
+ all_edges = tet_fx4[valid_tets][:,
+ self.base_tet_edges].reshape(-1, 2)
+ all_edges = self.sort_edges(all_edges)
+ unique_edges, idx_map = torch.unique(
+ all_edges, dim=0, return_inverse=True)
+
+ # find out the edges across the surface to interpolate and refine
+ unique_edges = unique_edges.long()
+ mask_edges = occ_n[unique_edges.reshape(-1)
+ ].reshape(-1, 2).sum(-1) == 1
+ mapping = torch.ones(
+ (unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
+ mapping[mask_edges] = torch.arange(
+ mask_edges.sum(), dtype=torch.long, device=self.device)
+ idx_map = mapping[idx_map] # map edges to verts
+
+ interp_v = unique_edges[mask_edges]
+ edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
+ edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
+ edges_to_interp_sdf[:, -1] *= -1
+ denominator = edges_to_interp_sdf.sum(1, keepdim=True)
+ edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
+
+ # interpolate edges by sdf
+ verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
+
+ idx_map = idx_map.reshape(-1, 6)
+
+ v_id = torch.pow(2, torch.arange(
+ 4, dtype=torch.long, device=self.device))
+ tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
+ num_triangles = self.num_triangles_table[tetindex]
+
+ # Generate triangle indices
+ faces = torch.cat((
+ torch.gather(input=idx_map[num_triangles == 1], dim=1,
+ index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
+ torch.gather(input=idx_map[num_triangles == 2], dim=1,
+ index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
+ ), dim=0)
+
+ if return_uv:
+ # Get global face index (static, does not depend on topology)
+ num_tets = tet_fx4.shape[0]
+ tet_gidx = torch.arange(num_tets, dtype=torch.long, device=self.device)[
+ valid_tets]
+ face_gidx = torch.cat((
+ tet_gidx[num_triangles == 1]*2,
+ torch.stack(
+ (tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
+ ), dim=0)
+
+ uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
+ else:
+ uvs, uv_idx = None, None
+ return verts, faces, uvs, uv_idx
+
+###############################################################################
+# Regularizer
+###############################################################################
+
+
+def sdf_reg_loss(sdf, all_edges):
+ sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2)
+ mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
+ sdf_f1x6x2 = sdf_f1x6x2[mask]
+ sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
+ torch.nn.functional.binary_cross_entropy_with_logits(
+ sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
+ return sdf_diff
+
+###############################################################################
+# Geometry interface
+###############################################################################
+
+
+class DMTetGeometry(torch.nn.Module):
+ def __init__(self, grid_res, tet_mlp, opt, device='cuda'):
+ super(DMTetGeometry, self).__init__()
+
+ self.opt = opt
+ self.device = device
+ self.tet_scale = torch.ones(3, device=device)
+ self.grid_res = grid_res
+ self.marching_tets = DMTet()
+
+ tets = np.load('data/tets/{}_tets.npz'.format(self.grid_res))
+ # for 64/128, [N=36562/277410, 3], in [-0.5, 0.5]^3
+ self.verts = torch.tensor(
+ tets['vertices'], dtype=torch.float32, device=self.device) * 2
+ # for 64/128, [M=192492/1524684, 4], vert indices for each tetrahetron
+ self.indices = torch.tensor(
+ tets['indices'], dtype=torch.long, device=self.device)
+ self.generate_edges()
+
+ self.tet_mlp = tet_mlp
+ if tet_mlp:
+ self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3)
+ self.encoder = self.encoder.to(device)
+ self.mlp = MLP(self.in_dim, 4, 32, 3, False).to(device)
+ self.sdf = None
+ else:
+ sdf = torch.nn.Parameter(torch.zeros_like(
+ self.verts[..., 0]), requires_grad=True)
+ self.register_parameter('sdf', sdf)
+ deform = torch.nn.Parameter(
+ torch.zeros_like(self.verts), requires_grad=True)
+ self.register_parameter('deform', deform)
+
+ if opt.base_mesh and os.path.exists(opt.base_mesh):
+ self.init_tet_from_mesh(opt.base_mesh)
+
+ def reset_tet(self, reset_scale=True):
+ if self.tet_mlp:
+ self.mlp.reset_parameters()
+ else:
+ self.sdf.data = torch.zeros_like(self.verts[..., 0])
+ self.deform.data = torch.zeros_like(self.verts)
+ if reset_scale:
+ self.reset_tet_scale()
+
+ def get_sdf_from_mesh(self, base_mesh):
+ logger.info(f'[INFO] init sdf from base mesh: {base_mesh}')
+
+ import cubvh
+ import trimesh
+ mesh = trimesh.load(base_mesh, force='mesh')
+
+ scale = 1.5 / np.array(mesh.bounds[1] - mesh.bounds[0]).max()
+ center = np.array(mesh.bounds[1] + mesh.bounds[0]) / 2
+ mesh.vertices = (mesh.vertices - center) * scale
+
+ # build with numpy.ndarray/torch.Tensor
+ BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)
+ sdf, face_id, _ = BVH.signed_distance(
+ self.verts, return_uvw=False, mode='watertight')
+ sdf *= -1 # INNER is POSITIVE
+ return sdf
+
+ def init_tet_from_mesh(self, base_mesh):
+ sdf = self.get_sdf_from_mesh(base_mesh)
+ self.init_tet_from_sdf(sdf)
+ # visualize
+ # sdf_np_gt = sdf.cpu().numpy()
+ # sdf_np = self.mlp(self.encoder(self.verts)).detach().cpu().numpy()[..., 0]
+ # verts_np = self.verts.cpu().numpy()
+ # color = np.zeros_like(verts_np)
+ # color[sdf_np < 0] = [1, 0, 0]
+ # color[sdf_np > 0] = [0, 0, 1]
+ # color = (color * 255).astype(np.uint8)
+ # pc = trimesh.PointCloud(verts_np, color)
+ # axes = trimesh.creation.axis(axis_length=4)
+ # box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
+ # trimesh.Scene([mesh, pc, axes, box]).show()
+
+ def init_tet_from_sdf(self, sdf, pretrain_iters=5000, lr=1e-3):
+ if self.tet_mlp:
+ self.mlp.reset_parameters()
+ # pretraining
+ loss_fn = torch.nn.MSELoss()
+ optimizer = torch.optim.Adam(list(self.parameters()), lr=lr)
+
+ #batch_size = min(10240, self.verts.shape[0])
+ batch_size = self.verts.shape[0]
+ pbar = tqdm(range(pretrain_iters), desc="init dmtet mlp from sdf")
+ for i in pbar:
+ rand_idx = torch.randint(0, self.verts.shape[0], (batch_size,))
+ p = self.verts[rand_idx]
+ ref_value = sdf[rand_idx]
+ output = self.mlp(self.encoder(p))
+ loss = loss_fn(output[..., 0], ref_value)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ pbar.set_postfix(loss=loss.item())
+ else:
+ self.sdf.data = sdf.squeeze()
+
+ @torch.no_grad()
+ def reset_tet_scale(self, tet_scale=1.):
+ if isinstance(tet_scale, float):
+ tet_scale = torch.ones(3, device=self.device) * tet_scale
+ self.tet_scale = tet_scale
+ self.verts = self.verts * tet_scale
+
+ @torch.no_grad()
+ def generate_edges(self):
+ # six edges for each tetrahedron.
+ edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
+ dtype=torch.long, device=self.device)
+ all_edges = self.indices[:, edges].reshape(-1, 2) # [M * 6, 2]
+ all_edges_sorted = torch.sort(all_edges, dim=1)[0]
+ self.all_edges = torch.unique(all_edges_sorted, dim=0)
+
+ def get_sdf_deform(self):
+ if self.tet_mlp:
+ # predict SDF and per-vertex deformation
+ pred = self.mlp(self.encoder(self.verts))
+ sdf, deform = pred[:, 0], pred[:, 1:]
+ return sdf, torch.tanh(deform) / (self.grid_res)
+ else:
+ return self.sdf, torch.tanh(self.deform) / (self.grid_res)
+
+ def get_verts_face(self):
+ sdf, deform = self.get_sdf_deform()
+ verts, faces, _, _ = self.marching_tets(
+ self.verts + deform, sdf, self.indices, return_uv=False)
+ return verts, faces
+
+ # def getAABB(self):
+ # return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
+
+ # def getMesh(self, material):
+
+ # pred = self.mlp(self.encoder(self.verts)) # predict SDF and per-vertex deformation
+ # sdf, deform = pred[:, 0], pred[:, 1:]
+
+ # v_deformed = self.verts + torch.tanh(deform) / (self.grid_res)
+
+ # verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, sdf, self.indices)
+
+ # imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
+
+ # # Run mesh operations to generate tangent space
+ # imesh = mesh.auto_normals(imesh)
+ # imesh = mesh.compute_tangents(imesh)
+
+ # return imesh, sdf
+
+ # def render(self, glctx, target, lgt, opt_material, bsdf=None):
+
+ # # return rendered buffers, keys: ['shaded', 'kd_grad', 'occlusion'].
+ # opt_mesh, sdf = self.getMesh(opt_material)
+ # buffers = render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
+ # msaa=True, background=None, bsdf=bsdf)
+ # buffers['mesh'] = opt_mesh
+ # buffers['sdf'] = sdf
+
+ # return buffers
+
+ # def tick(self, glctx, target, lgt, opt_material, loss_fn, guidance_model, text_z, iteration):
+
+ # # ==============================================================================================
+ # # Render optimizable object with identical conditions
+ # # ==============================================================================================
+ # buffers = self.render(glctx, target, lgt, opt_material)
+
+ # mesh = buffers['mesh']
+
+ # # ==============================================================================================
+ # # Compute loss
+ # # ==============================================================================================
+ # t_iter = iteration / self.opt.iter
+
+ # if iteration < int(self.opt.iter * 0.2):
+ # # mode = 'normal_latent'
+ # pred_rgb = buffers['normal'][..., 0:4].permute(0, 3, 1, 2).contiguous()
+ # as_latent = True
+ # elif iteration < int(self.opt.iter * 0.6):
+ # # mode = 'normal'
+ # pred_rgb = buffers['normal'][..., 0:3].permute(0, 3, 1, 2).contiguous()
+ # as_latent = False
+ # else:
+ # # mode = 'rgb'
+ # pred_rgb = buffers['shaded'][..., 0:3].permute(0, 3, 1, 2).contiguous()
+ # pred_ws = buffers['shaded'][..., 3].unsqueeze(1) # [B, 1, H, W]
+ # pred_rgb = pred_rgb * pred_ws + (1 - pred_ws) * 1 # white bg
+ # as_latent = False
+
+ # # torch_vis_2d(pred_rgb[0])
+ # # torch_vis_2d(pred_normal[0])
+ # # torch_vis_2d(pred_ws[0])
+
+ # if self.opt.directional_text:
+ # all_pos = []
+ # all_neg = []
+ # for emb in text_z[target['direction']]: # list of [2, S, -1]
+ # pos, neg = emb.chunk(2) # [1, S, -1]
+ # all_pos.append(pos)
+ # all_neg.append(neg)
+ # text_embedding = torch.cat(all_pos + all_neg, dim=0) # [2b, S, -1]
+ # else:
+ # text_embedding = text_z
+
+ # img_loss = guidance_model.train_step(text_embedding, pred_rgb.half(), as_latent=as_latent)
+
+ # # img_loss = torch.tensor(0.0, device = self.device)
+
+ # # below are lots of regularizations...
+ # reg_loss = torch.tensor(0.0, device = self.device)
+
+ # if iteration < int(self.opt.iter * 0.6):
+ # # SDF regularizer
+ # sdf_weight = self.opt.sdf_regularizer - (self.opt.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter)
+ # sdf_loss = sdf_reg_loss(buffers['sdf'], self.all_edges).mean() * sdf_weight # Dropoff to 0.01
+ # reg_loss = reg_loss + sdf_loss
+
+ # # directly regularize mesh smoothness in finetuning...
+ # if iteration > int(self.opt.iter * 0.2):
+ # lap_loss = regularizer.laplace_regularizer_const(mesh.v_pos, mesh.t_pos_idx) * self.opt.laplace_scale #* min(1.0, iteration / 500)
+ # reg_loss = reg_loss + lap_loss
+
+ # # normal_loss = regularizer.normal_consistency(mesh.v_pos, mesh.t_pos_idx) * self.opt.laplace_scale * min(1.0, iteration / 500)
+ # # reg_loss = reg_loss + normal_loss
+
+ # else:
+ # # Albedo (k_d) smoothnesss regularizer
+ # # reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, (iteration - int(self.opt.iter * 0.6)) / 500)
+
+ # # # Visibility regularizer
+ # # reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, (iteration - int(self.opt.iter * 0.6)) / 500)
+
+ # # # Light white balance regularizer
+ # reg_loss += lgt.regularizer() * 0.005
+
+ # return img_loss, reg_loss
+
+
+def compute_edge_to_face_mapping(attr_idx):
+ with torch.no_grad():
+ # Get unique edges
+ # Create all edges, packed by triangle
+ all_edges = torch.cat((
+ torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
+ torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
+ torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
+ ), dim=-1).view(-1, 2)
+
+ # Swap edge order so min index is always first
+ order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
+ sorted_edges = torch.cat((
+ torch.gather(all_edges, 1, order),
+ torch.gather(all_edges, 1, 1 - order)
+ ), dim=-1)
+
+ # Elliminate duplicates and return inverse mapping
+ unique_edges, idx_map = torch.unique(
+ sorted_edges, dim=0, return_inverse=True)
+
+ tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()
+
+ tris_per_edge = torch.zeros(
+ (unique_edges.shape[0], 2), dtype=torch.int64).cuda()
+
+ # Compute edge to face table
+ mask0 = order[:, 0] == 0
+ mask1 = order[:, 0] == 1
+ tris_per_edge[idx_map[mask0], 0] = tris[mask0]
+ tris_per_edge[idx_map[mask1], 1] = tris[mask1]
+
+ return tris_per_edge
+
+
+@torch.cuda.amp.autocast(enabled=False)
+def normal_consistency(face_normals, t_pos_idx):
+
+ tris_per_edge = compute_edge_to_face_mapping(t_pos_idx)
+
+ # Fetch normals for both faces sharind an edge
+ n0 = face_normals[tris_per_edge[:, 0], :]
+ n1 = face_normals[tris_per_edge[:, 1], :]
+
+ # Compute error metric based on normal difference
+ term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0)
+ term = (1.0 - term)
+
+ return torch.mean(torch.abs(term))
+
+
+def laplacian_uniform(verts, faces):
+
+ V = verts.shape[0]
+ F = faces.shape[0]
+
+ # Neighbor indices
+ ii = faces[:, [1, 2, 0]].flatten()
+ jj = faces[:, [2, 0, 1]].flatten()
+ adj = torch.stack(
+ [torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1)
+ adj_values = torch.ones(
+ adj.shape[1], device=verts.device, dtype=torch.float)
+
+ # Diagonal indices
+ diag_idx = adj[0]
+
+ # Build the sparse matrix
+ idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
+ values = torch.cat((-adj_values, adj_values))
+
+ # The coalesce operation sums the duplicate indices, resulting in the
+ # correct diagonal
+ return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
+
+
+@torch.cuda.amp.autocast(enabled=False)
+def laplacian_smooth_loss(verts, faces):
+ with torch.no_grad():
+ L = laplacian_uniform(verts, faces.long())
+ loss = L.mm(verts)
+ loss = loss.norm(dim=1)
+ loss = loss.mean()
+ return loss
+
+
+class NeRFRenderer(nn.Module):
+ def __init__(self, opt):
+ super().__init__()
+
+ self.opt = opt
+ self.bound = opt.bound
+ self.cascade = 1 + math.ceil(math.log2(opt.bound))
+ self.grid_size = 128
+ self.max_level = None
+ self.dmtet = opt.dmtet
+ self.cuda_ray = opt.cuda_ray
+ self.taichi_ray = opt.taichi_ray
+ self.min_near = opt.min_near
+ self.density_thresh = opt.density_thresh
+
+ # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
+ # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
+ aabb_train = torch.FloatTensor(
+ [-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound])
+ aabb_infer = aabb_train.clone()
+ self.register_buffer('aabb_train', aabb_train)
+ self.register_buffer('aabb_infer', aabb_infer)
+
+ self.glctx = None
+
+ # extra state for cuda raymarching
+ if self.cuda_ray:
+ # density grid
+ density_grid = torch.zeros(
+ [self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
+ density_bitfield = torch.zeros(
+ self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
+ self.register_buffer('density_grid', density_grid)
+ self.register_buffer('density_bitfield', density_bitfield)
+ self.mean_density = 0
+ self.iter_density = 0
+
+ # load dmtet vertices
+ if self.opt.dmtet:
+ self.dmtet = DMTetGeometry(opt.tet_grid_size, opt.tet_mlp, opt).to(opt.device)
+ if self.opt.h <= 2048 and self.opt.w <= 2048:
+ self.glctx = dr.RasterizeCudaContext()
+ else:
+ self.glctx = dr.RasterizeGLContext()
+
+ if self.taichi_ray:
+ from einops import rearrange
+ from taichi_modules import RayMarcherTaichi
+ from taichi_modules import VolumeRendererTaichi
+ from taichi_modules import RayAABBIntersector as RayAABBIntersectorTaichi
+ from taichi_modules import raymarching_test as raymarching_test_taichi
+ from taichi_modules import composite_test as composite_test_fw
+ from taichi_modules import packbits as packbits_taichi
+ self.rearrange = rearrange
+ self.packbits_taichi = packbits_taichi
+ self.ray_aabb_intersector = RayAABBIntersectorTaichi
+ self.raymarching_test_taichi = raymarching_test_taichi
+ self.composite_test_fw = composite_test_fw
+ self.ray_marching = RayMarcherTaichi(
+ batch_size=4096) # TODO: hard encoded batch size
+ self.volume_render = VolumeRendererTaichi(
+ batch_size=4096) # TODO: hard encoded batch size
+ # density grid
+ density_grid = torch.zeros(
+ [self.cascade, self.grid_size ** 3]) # [CAS, H * H * H]
+ density_bitfield = torch.zeros(
+ self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8]
+ self.register_buffer('density_grid', density_grid)
+ self.register_buffer('density_bitfield', density_bitfield)
+ self.mean_density = 0
+ self.iter_density = 0
+
+ if self.opt.density_activation == 'exp':
+ self.density_activation = trunc_exp
+ elif self.opt.density_activation == 'softplus':
+ self.density_activation = F.softplus
+ elif self.opt.density_activation == 'relu':
+ self.density_activation = F.relu
+
+ # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192
+ def finite_difference_normal(self, x, epsilon=1e-2):
+ # x: [N, 3]
+ dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound))
+ dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound))
+ dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound))
+
+ normal = torch.stack([
+ 0.5 * (dx_pos - dx_neg) / epsilon,
+ 0.5 * (dy_pos - dy_neg) / epsilon,
+ 0.5 * (dz_pos - dz_neg) / epsilon
+ ], dim=-1)
+
+ return -normal
+
+ def normal(self, x):
+ normal = self.finite_difference_normal(x)
+ normal = safe_normalize(normal)
+ normal = torch.nan_to_num(normal)
+ return normal
+
+ @torch.no_grad()
+ def density_blob(self, x):
+ # x: [B, N, 3]
+
+ d = (x ** 2).sum(-1)
+
+ if self.opt.density_activation == 'exp':
+ g = self.opt.blob_density * \
+ torch.exp(- d / (2 * self.opt.blob_radius ** 2))
+ else:
+ g = self.opt.blob_density * \
+ (1 - torch.sqrt(d) / self.opt.blob_radius)
+
+ return g
+
+ def forward(self, x, d):
+ raise NotImplementedError()
+
+ def density(self, x):
+ raise NotImplementedError()
+
+ def reset_extra_state(self):
+ if not (self.cuda_ray or self.taichi_ray):
+ return
+ # density grid
+ self.density_grid.zero_()
+ self.mean_density = 0
+ self.iter_density = 0
+
+ @torch.no_grad()
+ def export_mesh(self, path, resolution=None, decimate_target=-1, S=128):
+ from meshutils import decimate_mesh, clean_mesh, poisson_mesh_reconstruction
+ if self.opt.dmtet:
+ vertices, triangles = self.dmtet.get_verts_face()
+ vertices = vertices.detach().cpu().numpy()
+ triangles = triangles.detach().cpu().numpy()
+
+ else:
+
+ if resolution is None:
+ resolution = self.grid_size
+
+ if self.cuda_ray:
+ density_thresh = min(self.mean_density, self.density_thresh) \
+ if np.greater(self.mean_density, 0) else self.density_thresh
+ else:
+ density_thresh = self.density_thresh
+
+ sigmas = np.zeros(
+ [resolution, resolution, resolution], dtype=np.float32)
+
+ # query
+ X = torch.linspace(-1, 1, resolution).split(S)
+ Y = torch.linspace(-1, 1, resolution).split(S)
+ Z = torch.linspace(-1, 1, resolution).split(S)
+
+ for xi, xs in enumerate(X):
+ for yi, ys in enumerate(Y):
+ for zi, zs in enumerate(Z):
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ pts = torch.cat(
+ [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
+ val = self.density(pts.to(self.aabb_train.device))
+ sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(
+ zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
+
+ logger.info(
+ f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
+
+ vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
+ vertices = vertices / (resolution - 1.0) * 2 - 1
+
+ # clean
+ vertices = vertices.astype(np.float32)
+ triangles = triangles.astype(np.int32)
+ vertices, triangles = clean_mesh(
+ vertices, triangles, remesh=True, remesh_size=0.01)
+
+ # decimation
+ if decimate_target > 0 and triangles.shape[0] > decimate_target:
+ vertices, triangles = decimate_mesh(
+ vertices, triangles, decimate_target)
+
+ v = torch.from_numpy(vertices).contiguous(
+ ).float().to(self.aabb_train.device)
+ f = torch.from_numpy(triangles).contiguous().int().to(
+ self.aabb_train.device)
+
+ # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
+ # mesh.export(os.path.join(path, f'mesh.ply'))
+
+ def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):
+ # v, f: torch Tensor
+ device = v.device
+ v_np = v.cpu().numpy() # [N, 3]
+ f_np = f.cpu().numpy() # [M, 3]
+
+ logger.info(
+ f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
+
+ # unwrap uvs
+ import xatlas
+ import nvdiffrast.torch as dr
+ from sklearn.neighbors import NearestNeighbors
+ from scipy.ndimage import binary_dilation, binary_erosion
+
+ atlas = xatlas.Atlas()
+ atlas.add_mesh(v_np, f_np)
+ chart_options = xatlas.ChartOptions()
+ chart_options.max_iterations = 4 # for faster unwrap...
+ atlas.generate(chart_options=chart_options)
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
+
+ # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2]
+
+ vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device)
+ ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device)
+
+ # render uv maps
+ uv = vt * 2.0 - 1.0 # uvs to range [-1, 1]
+ uv = torch.cat((uv, torch.zeros_like(
+ uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
+
+ if ssaa > 1:
+ h = int(h0 * ssaa)
+ w = int(w0 * ssaa)
+ else:
+ h, w = h0, w0
+
+ if self.glctx is None:
+ if h <= 2048 and w <= 2048:
+ self.glctx = dr.RasterizeCudaContext()
+ else:
+ self.glctx = dr.RasterizeGLContext()
+
+ rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(
+ 0), ft, (h, w)) # [1, h, w, 4]
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3]
+ mask, _ = dr.interpolate(torch.ones_like(
+ v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1]
+
+ # masked query
+ xyzs = xyzs.view(-1, 3)
+ mask = (mask > 0).view(-1)
+
+ feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32)
+
+ if mask.any():
+ xyzs = xyzs[mask] # [M, 3]
+
+ # batched inference to avoid OOM
+ all_feats = []
+ head = 0
+ while head < xyzs.shape[0]:
+ tail = min(head + 640000, xyzs.shape[0])
+ results_ = self.density(xyzs[head:tail])
+ all_feats.append(results_['albedo'].float())
+ head += 640000
+
+ feats[mask] = torch.cat(all_feats, dim=0)
+
+ feats = feats.view(h, w, -1)
+ mask = mask.view(h, w)
+
+ # quantize [0.0, 1.0] to [0, 255]
+ feats = feats.cpu().numpy()
+ feats = (feats * 255).astype(np.uint8)
+
+ ### NN search as an antialiasing ...
+ mask = mask.cpu().numpy()
+
+ inpaint_region = binary_dilation(mask, iterations=3)
+ inpaint_region[mask] = 0
+
+ search_region = mask.copy()
+ not_search_region = binary_erosion(search_region, iterations=2)
+ search_region[not_search_region] = 0
+
+ search_coords = np.stack(np.nonzero(search_region), axis=-1)
+ inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
+
+ knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
+ _, indices = knn.kneighbors(inpaint_coords)
+
+ feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
+
+ feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR)
+
+ # do ssaa after the NN search, in numpy
+ if ssaa > 1:
+ feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR)
+
+ cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats)
+
+ # save obj (v, vt, f /)
+ obj_file = os.path.join(path, f'{name}mesh.obj')
+ mtl_file = os.path.join(path, f'{name}mesh.mtl')
+
+ logger.info(f'[INFO] writing obj mesh to {obj_file}')
+ with open(obj_file, "w") as fp:
+ fp.write(f'mtllib {name}mesh.mtl \n')
+
+ logger.info(f'[INFO] writing vertices {v_np.shape}')
+ for v in v_np:
+ fp.write(f'v {v[0]} {v[1]} {v[2]} \n')
+
+ logger.info(
+ f'[INFO] writing vertices texture coords {vt_np.shape}')
+ for v in vt_np:
+ fp.write(f'vt {v[0]} {1 - v[1]} \n')
+
+ logger.info(f'[INFO] writing faces {f_np.shape}')
+ fp.write(f'usemtl mat0 \n')
+ for i in range(len(f_np)):
+ fp.write(
+ f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n")
+
+ with open(mtl_file, "w") as fp:
+ fp.write(f'newmtl mat0 \n')
+ fp.write(f'Ka 1.000000 1.000000 1.000000 \n')
+ fp.write(f'Kd 1.000000 1.000000 1.000000 \n')
+ fp.write(f'Ks 0.000000 0.000000 0.000000 \n')
+ fp.write(f'Tr 1.000000 \n')
+ fp.write(f'illum 1 \n')
+ fp.write(f'Ns 0.000000 \n')
+ fp.write(f'map_Kd {name}albedo.png \n')
+
+ _export(v, f)
+
+ def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
+ # rays_o, rays_d: [B, N, 3]
+ # bg_color: [BN, 3] in range [0, 1]
+ # return: image: [B, N, 3], depth: [B, N]
+
+ prefix = rays_o.shape[:-1]
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # N = B * N, in fact
+ device = rays_o.device
+
+ results = {}
+
+ # choose aabb
+ aabb = self.aabb_train if self.training else self.aabb_infer
+
+ # sample steps
+ # nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
+ # nears.unsqueeze_(-1)
+ # fars.unsqueeze_(-1)
+ nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near)
+
+ # random sample light_d if not provided
+ if light_d is None:
+ # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
+ if self.training:
+ light_d = safe_normalize(rays_o + torch.randn(3, device=rays_o.device)) # [N, 3]
+ else:
+ light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device)) # [N, 3]
+
+ #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
+
+ z_vals = torch.linspace(0.0, 1.0, self.opt.num_steps, device=device).unsqueeze(0) # [1, T]
+ z_vals = z_vals.expand((N, self.opt.num_steps)) # [N, T]
+ z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
+
+ # perturb z_vals
+ sample_dist = (fars - nears) / self.opt.num_steps
+ if perturb:
+ z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
+ #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
+
+ # generate xyzs
+ xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
+ xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
+
+ #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
+
+ # query SDF and RGB
+ density_outputs = self.density(xyzs.reshape(-1, 3))
+
+ #sigmas = density_outputs['sigma'].view(N, self.opt.num_steps) # [N, T]
+ for k, v in density_outputs.items():
+ density_outputs[k] = v.view(N, self.opt.num_steps, -1)
+
+ # upsample z_vals (nerf-like)
+ if self.opt.upsample_steps > 0:
+ with torch.no_grad():
+
+ deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
+ deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
+
+ alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T]
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
+ weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]
+
+ # sample new z_vals
+ z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
+ new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], self.opt.upsample_steps, det=not self.training).detach() # [N, t]
+
+ new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
+ new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.
+
+ # only forward new points to save computation
+ new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
+ #new_sigmas = new_density_outputs['sigma'].view(N, self.opt.upsample_steps) # [N, t]
+ for k, v in new_density_outputs.items():
+ new_density_outputs[k] = v.view(N, self.opt.upsample_steps, -1)
+
+ # re-order
+ z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
+ z_vals, z_index = torch.sort(z_vals, dim=1)
+
+ xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
+ xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))
+
+ for k in density_outputs:
+ tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
+ density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))
+
+ deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
+ deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
+ alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
+ alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
+ weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
+
+ dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
+ light_d = light_d.view(-1, 1, 3).expand_as(xyzs)
+ for k, v in density_outputs.items():
+ density_outputs[k] = v.view(-1, v.shape[-1])
+
+ dirs = safe_normalize(dirs)
+ sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
+ rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
+ if normals is not None:
+ normals = normals.view(N, -1, 3)
+
+ # calculate weight_sum (mask)
+ weights_sum = weights.sum(dim=-1) # [N]
+
+ # calculate depth
+ depth = torch.sum(weights * z_vals, dim=-1)
+
+ # calculate color
+ image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
+
+ # mix background color
+ if bg_color is None:
+ if self.opt.bg_radius > 0:
+ # use the bg model to calculate bg_color
+ bg_color = self.background(rays_d) # [N, 3]
+ else:
+ bg_color = 1
+
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
+
+ image = image.view(*prefix, 3)
+ depth = depth.view(*prefix)
+ weights_sum = weights_sum.reshape(*prefix)
+
+ if self.training:
+ if self.opt.lambda_orient > 0 and normals is not None:
+ # orientation loss
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
+ results['loss_orient'] = loss_orient.sum(-1).mean()
+
+ if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
+ normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
+ results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()
+
+ if normals is not None:
+ normal_image = torch.sum(
+ weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2) # [N, 3], in [0, 1]
+ results['normal_image'] = normal_image
+
+ results['image'] = image
+ results['depth'] = depth
+ results['weights'] = weights
+ results['weights_sum'] = weights_sum
+
+ return results
+
+
+ def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, binarize=False, **kwargs):
+ # rays_o, rays_d: [B, N, 3]
+ # return: image: [B, N, 3], depth: [B, N]
+
+ prefix = rays_o.shape[:-1]
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # B * N, in fact
+ device = rays_o.device
+
+ # pre-calculate near far
+ nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer)
+
+ # random sample light_d if not provided
+ if light_d is None:
+ # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
+ if self.training:
+ light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device)) # [N, 3]
+ else:
+ light_d = safe_normalize(rays_o[0:1] + torch.randn(3, device=rays_o.device)) # [N, 3]
+
+ results = {}
+
+ if self.training:
+ xyzs, dirs, ts, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb, self.opt.dt_gamma, self.opt.max_steps)
+ dirs = safe_normalize(dirs)
+
+ if light_d.shape[0] > 1:
+ flatten_rays = raymarching.flatten_rays(rays, xyzs.shape[0]).long()
+ light_d = light_d[flatten_rays]
+
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
+ weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, rays, T_thresh, binarize)
+
+ # normals related regularizations
+ if self.opt.lambda_orient > 0 and normals is not None:
+ # orientation loss
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
+ results['loss_orient'] = loss_orient.mean()
+
+ if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
+ normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
+ results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()
+
+ if normals is not None:
+ _, _, _, normal_image = raymarching.composite_rays_train(sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize)
+ results['normal_image'] = normal_image
+
+ # weights normalization
+ results['weights'] = weights
+
+ else:
+
+ # allocate outputs
+ dtype = torch.float32
+
+ weights_sum = torch.zeros(N, dtype=dtype, device=device)
+ depth = torch.zeros(N, dtype=dtype, device=device)
+ image = torch.zeros(N, 3, dtype=dtype, device=device)
+
+ n_alive = N
+ rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
+ rays_t = nears.clone() # [N]
+
+ step = 0
+
+ while step < self.opt.max_steps: # hard coded max step
+
+ # count alive rays
+ n_alive = rays_alive.shape[0]
+
+ # exit loop
+ if n_alive <= 0:
+ break
+
+ # decide compact_steps
+ n_step = max(min(N // n_alive, 8), 1)
+
+ xyzs, dirs, ts = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb if step == 0 else False, self.opt.dt_gamma, self.opt.max_steps)
+ dirs = safe_normalize(dirs)
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
+ raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh, binarize)
+
+ rays_alive = rays_alive[rays_alive >= 0]
+ #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
+
+ step += n_step
+
+ # mix background color
+ if bg_color is None:
+ if self.opt.bg_radius > 0:
+ # use the bg model to calculate bg_color
+ bg_color = self.background(rays_d) # [N, 3]
+ else:
+ bg_color = 1
+
+ image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
+ image = image.view(*prefix, 3)
+
+ depth = depth.view(*prefix)
+
+ weights_sum = weights_sum.reshape(*prefix)
+
+ results['image'] = image
+ results['depth'] = depth
+ results['weights_sum'] = weights_sum
+
+ return results
+
+ def get_sdf_albedo_for_init(self, points=None):
+ output = self.density(self.dmtet.verts if points is None else points)
+ sigma, albedo = output['sigma'], output['albedo']
+ return sigma - self.density_thresh, albedo
+
+ def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, **kwargs):
+ # mvp: [B, 4, 4]
+
+ device = mvp.device
+ campos = rays_o[:, 0, :] # only need one ray per batch
+
+ # random sample light_d if not provided
+ if light_d is None:
+ # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
+ light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3) # [B, 1, 1, 3]
+
+ results = {}
+
+ verts, faces = self.dmtet.get_verts_face()
+
+ # get normals
+ i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2]
+ v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :]
+
+ faces = faces.int()
+
+ face_normals = torch.cross(v1 - v0, v2 - v0)
+ face_normals = safe_normalize(face_normals)
+
+ vn = torch.zeros_like(verts)
+ vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
+ vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
+ vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
+
+ vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
+
+ # rasterization
+ verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1),
+ mvp.permute(0,2,1)).float() # [B, N, 4]
+ rast, rast_db = dr.rasterize(self.glctx, verts_clip, faces, (h, w))
+
+ alpha, _ = dr.interpolate(torch.ones_like(verts[:, :1]).unsqueeze(0), rast, faces) # [B, H, W, 1]
+ xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3]
+ normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces)
+ normal = safe_normalize(normal)
+
+ xyzs = xyzs.view(-1, 3)
+ mask = (alpha > 0).view(-1).detach()
+
+ # do the lighting here since we have normal from mesh now.
+ albedo = torch.zeros_like(xyzs, dtype=torch.float32)
+ if mask.any():
+ masked_albedo = self.density(xyzs[mask])['albedo']
+ albedo[mask] = masked_albedo.float()
+ albedo = albedo.view(-1, h, w, 3)
+
+ if shading == 'albedo':
+ color = albedo
+ elif shading == 'textureless':
+ lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0)
+ color = lambertian.unsqueeze(-1).repeat(1, 1, 1, 3)
+ elif shading == 'normal':
+ color = (normal + 1) / 2
+ else: # 'lambertian'
+ lambertian = ambient_ratio + (1 - ambient_ratio) * (normal * light_d).sum(-1).float().clamp(min=0)
+ color = albedo * lambertian.unsqueeze(-1)
+
+ color = dr.antialias(color, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3]
+ alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 1]
+
+ # mix background color
+ if bg_color is None:
+ if self.opt.bg_radius > 0:
+ # use the bg model to calculate bg_color
+ bg_color = self.background(rays_d) # [N, 3]
+ else:
+ bg_color = 1
+
+ if torch.is_tensor(bg_color) and len(bg_color.shape) > 1:
+ bg_color = bg_color.view(-1, h, w, 3)
+
+ depth = rast[:, :, :, [2]] # [B, H, W]
+ color = color + (1 - alpha) * bg_color
+
+ results['depth'] = depth
+ results['image'] = color
+ results['weights_sum'] = alpha.squeeze(-1)
+
+ normal_image = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3]
+ results['normal_image'] = normal_image
+
+ # regularizations
+ if self.training:
+ if self.opt.lambda_mesh_normal > 0:
+ results['loss_normal'] = normal_consistency(
+ face_normals, faces)
+ if self.opt.lambda_mesh_lap > 0:
+ results['loss_lap'] = laplacian_smooth_loss(verts, faces)
+
+ return results
+
+ def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, T_thresh=1e-4, **kwargs):
+ # rays_o, rays_d: [B, N, 3], assumes B == 1
+ # return: image: [B, N, 3], depth: [B, N]
+
+ prefix = rays_o.shape[:-1]
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # N = B * N, in fact
+ device = rays_o.device
+
+ # pre-calculate near far
+ exp_step_factor = kwargs.get('exp_step_factor', 0.)
+ MAX_SAMPLES = 1024
+ NEAR_DISTANCE = 0.01
+ center = torch.zeros(1, 3)
+ half_size = torch.ones(1, 3)
+ _, hits_t, _ = self.ray_aabb_intersector.apply(rays_o, rays_d, center, half_size, 1)
+ hits_t[(hits_t[:, 0, 0] >= 0) & (hits_t[:, 0, 0] < NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE
+
+ # TODO: should sample different light_d for each batch... but taichi end doesn't have a flatten_ray implemented currently...
+ # random sample light_d if not provided
+ if light_d is None:
+ # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face)
+ light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float))
+ light_d = safe_normalize(light_d)
+
+ results = {}
+
+ if self.training:
+ rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES)
+ dirs = safe_normalize(dirs)
+ # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
+ sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
+ _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))
+
+ # normals related regularizations
+ if self.opt.lambda_orient > 0 and normals is not None:
+ # orientation loss
+ loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2
+ results['loss_orient'] = loss_orient.mean()
+
+ if self.opt.lambda_3d_normal_smooth > 0 and normals is not None:
+ normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2)
+ results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean()
+
+ if normals is not None:
+ _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))
+ results['normal_image'] = normal_image
+
+ # weights normalization
+ results['weights'] = weights
+
+ else:
+
+ # allocate outputs
+ dtype = torch.float32
+
+ weights_sum = torch.zeros(N, dtype=dtype, device=device)
+ depth = torch.zeros(N, dtype=dtype, device=device)
+ image = torch.zeros(N, 3, dtype=dtype, device=device)
+
+ n_alive = N
+ rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
+ rays_t = hits_t[:, 0, 0]
+ step = 0
+
+ min_samples = 1 if exp_step_factor == 0 else 4
+
+ while step < self.opt.max_steps: # hard coded max step
+
+ # count alive rays
+ n_alive = rays_alive.shape[0]
+
+ # exit loop
+ if n_alive <= 0:
+ break
+
+ # decide compact_steps
+ # n_step = max(min(N // n_alive, 8), 1)
+ n_step = max(min(N // n_alive, 64), min_samples)
+
+ xyzs, dirs, deltas, ts, N_eff_samples = \
+ self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive,
+ self.density_bitfield, self.cascade,
+ self.bound, exp_step_factor,
+ self.grid_size, MAX_SAMPLES, n_step)
+
+ xyzs = self.rearrange(xyzs, 'n1 n2 c -> (n1 n2) c')
+ dirs = self.rearrange(dirs, 'n1 n2 c -> (n1 n2) c')
+ dirs = safe_normalize(dirs)
+ valid_mask = ~torch.all(dirs == 0, dim=1)
+ if valid_mask.sum() == 0:
+ break
+
+ sigmas = torch.zeros(len(xyzs), device=device)
+ rgbs = torch.zeros(len(xyzs), 3, device=device)
+ normals = torch.zeros(len(xyzs), 3, device=device)
+
+ sigmas[valid_mask], _rgbs, normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading)
+ rgbs[valid_mask] = _rgbs.float()
+ sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step)
+ rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step)
+ if normals is not None:
+ normals = self.rearrange(normals, '(n1 n2) c -> n1 n2 c', n2=n_step)
+
+ self.composite_test_fw(sigmas, rgbs, deltas, ts, hits_t[:,0], rays_alive,
+ kwargs.get('T_threshold', 1e-4), N_eff_samples,
+ weights_sum, depth, image)
+
+ rays_alive = rays_alive[rays_alive >= 0]
+
+ step += n_step
+
+ # mix background color
+ if bg_color is None:
+ if self.opt.bg_radius > 0:
+ # use the bg model to calculate bg_color
+ bg_color = self.background(rays_d) # [N, 3]
+ else:
+ bg_color = 1
+
+ image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color
+ image = image.view(*prefix, 3)
+
+ depth = depth.view(*prefix)
+
+ weights_sum = weights_sum.reshape(*prefix)
+
+ results['image'] = image
+ results['depth'] = depth
+ results['weights_sum'] = weights_sum
+
+ return results
+
+
+ @torch.no_grad()
+ def update_extra_state(self, decay=0.95, S=128):
+ # call before each epoch to update extra states.
+
+ if not (self.cuda_ray or self.taichi_ray):
+ return
+
+ ### update density grid
+ tmp_grid = - torch.ones_like(self.density_grid)
+
+ X = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)
+ Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)
+ Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)
+
+ for xs in X:
+ for ys in Y:
+ for zs in Z:
+
+ # construct points
+ xx, yy, zz = custom_meshgrid(xs, ys, zs)
+ coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+ indices = raymarching.morton3D(coords).long() # [N]
+ xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
+
+ # cascading
+ for cas in range(self.cascade):
+ bound = min(2 ** cas, self.bound)
+ half_grid_size = bound / self.grid_size
+ # scale to current cascade's resolution
+ cas_xyzs = xyzs * (bound - half_grid_size)
+ # add noise in [-hgs, hgs]
+ cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
+ # query density
+ sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
+ # assign
+ tmp_grid[cas, indices] = sigmas
+ # ema update
+ valid_mask = self.density_grid >= 0
+ self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
+ self.mean_density = torch.mean(self.density_grid[valid_mask]).item()
+ self.iter_density += 1
+
+ # convert to bitfield
+ density_thresh = min(self.mean_density, self.density_thresh)
+ if self.cuda_ray:
+ self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)
+ elif self.taichi_ray:
+ self.packbits_taichi(self.density_grid.reshape(-1).contiguous(), density_thresh, self.density_bitfield)
+
+ # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f}')
+
+
+ def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, **kwargs):
+ # rays_o, rays_d: [B, N, 3]
+ # return: pred_rgb: [B, N, 3]
+ B, N = rays_o.shape[:2]
+ device = rays_o.device
+
+ if self.dmtet:
+ results = self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs)
+ elif self.cuda_ray:
+ results = self.run_cuda(rays_o, rays_d, **kwargs)
+ elif self.taichi_ray:
+ results = self.run_taichi(rays_o, rays_d, **kwargs)
+ else:
+ if staged:
+ depth = torch.empty((B, N), device=device)
+ image = torch.empty((B, N, 3), device=device)
+ weights_sum = torch.empty((B, N), device=device)
+
+ for b in range(B):
+ head = 0
+ while head < N:
+ tail = min(head + max_ray_batch, N)
+ results_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs)
+ depth[b:b+1, head:tail] = results_['depth']
+ weights_sum[b:b+1, head:tail] = results_['weights_sum']
+ image[b:b+1, head:tail] = results_['image']
+ head += max_ray_batch
+
+ results = {}
+ results['depth'] = depth
+ results['image'] = image
+ results['weights_sum'] = weights_sum
+
+ else:
+ results = self.run(rays_o, rays_d, **kwargs)
+
+ return results
+
+ def init_tet_from_nerf(self, reset_scale=True):
+ sdf = self.get_sdf_from_nerf(reset_scale=reset_scale)
+ self.dmtet.init_tet_from_sdf(sdf)
+ logger.info(f'init dmtet from NeRF Done ...')
+
+
+ @torch.no_grad()
+ def get_sdf_from_nerf(self, reset_scale=True):
+ if self.cuda_ray:
+ density_thresh = min(self.mean_density, self.density_thresh)
+ else:
+ density_thresh = self.density_thresh
+
+ if reset_scale:
+ # init scale
+ sigma = self.density(self.dmtet.verts)[
+ 'sigma'] # verts covers [-1, 1] now
+ mask = sigma > density_thresh
+ valid_verts = self.dmtet.verts[mask]
+ tet_scale = valid_verts.abs().amax(dim=0) + 1e-1
+ self.dmtet.reset_tet_scale(tet_scale)
+ sdf = (self.density(self.dmtet.verts)[
+ 'sigma'] - density_thresh).clamp(-1, 1)
+ return sdf
diff --git a/nerf/utils.py b/nerf/utils.py
new file mode 100644
index 0000000..b2f100d
--- /dev/null
+++ b/nerf/utils.py
@@ -0,0 +1,1599 @@
+import os
+import glob
+import tqdm
+import random
+import logging
+import gc
+
+import numpy as np
+import imageio, imageio_ffmpeg
+import time
+
+import cv2
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+import torchvision.transforms.functional as TF
+from torchvision.utils import make_grid
+from torchmetrics.functional import pearson_corrcoef
+
+from rich.console import Console
+from torch_ema import ExponentialMovingAverage
+
+from packaging import version as pver
+
+from nerf.clip import CLIP
+from easydict import EasyDict as edict
+logger = logging.getLogger(__name__)
+
+
+class AverageMeters(object):
+ """Computes and stores the average and current value"""
+ def __init__(self, keys=['loss']):
+ self.meters = edict()
+ self.keys= keys
+ self.reset()
+
+ def reset(self):
+ for key in self.keys:
+ self.meters[key] = {}
+ self.meters[key].val = 0
+ self.meters[key].avg = 0
+ self.meters[key].sum = 0
+ self.meters[key].count = 0
+
+ def reset_by_key(self, key):
+ self.meters[key] = {}
+ self.meters[key].val = 0
+ self.meters[key].avg = 0
+ self.meters[key].sum = 0
+ self.meters[key].count = 0
+
+ def update(self, in_dict, n=1):
+ for key, val in in_dict.items():
+ if key not in self.keys:
+ self.keys.append(key)
+ self.reset_by_key(key)
+ self.meters[key].val = val
+ self.meters[key].sum += val * n
+ self.meters[key].count += n
+ self.meters[key].avg = self.meters[key].sum / self.meters[key].count
+
+
+def setup_workspace(opt):
+ if opt.workspace is None or opt.workspace == '':
+ opt.workspace = 'out/'
+ if opt.text:
+ opt.workspace += '_'.join(opt.text.split(' '))
+ if opt.image:
+ opt.workspace += '_'.join('_'.join(opt.image.split('/')
+ [-2:]).split('.')[:-1])
+ opt.workspace += '+' + time.strftime('%Y%m%d-%H%M%S')
+ opt.runname = os.path.basename(opt.workspace)
+ os.makedirs(opt.workspace, exist_ok=True)
+ opt.log_path = os.path.join(opt.workspace, f"log_{opt.runname}.txt")
+ opt.ckpt_path = os.path.join(opt.workspace, 'checkpoints')
+ opt.best_path = f"{opt.ckpt_path}/{opt.runname}.pth"
+ os.makedirs(opt.ckpt_path, exist_ok=True)
+
+
+def custom_meshgrid(*args):
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
+ return torch.meshgrid(*args)
+ else:
+ return torch.meshgrid(*args, indexing='ij')
+
+def safe_normalize(x, eps=1e-20):
+ return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
+
+@torch.cuda.amp.autocast(enabled=False)
+def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
+ ''' get rays
+ Args:
+ poses: [B, 4, 4], cam2world
+ intrinsics: [4]
+ H, W, N: int
+ error_map: [B, 128 * 128], sample probability based on training error
+ Returns:
+ rays_o, rays_d: [B, N, 3]
+ inds: [B, N]
+ '''
+
+ device = poses.device
+ B = poses.shape[0]
+ fx, fy, cx, cy = intrinsics
+
+ i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
+ i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
+ j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
+
+ results = {}
+
+ if N > 0:
+ N = min(N, H*W)
+
+ if error_map is None:
+ inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
+ inds = inds.expand([B, N])
+ else:
+
+ # weighted sample on a low-reso grid
+ inds_coarse = torch.multinomial(error_map.to(device), N, replacement=False) # [B, N], but in [0, 128*128)
+
+ # map to the original resolution with random perturb.
+ inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 # `//` will throw a warning in torch 1.10... anyway.
+ sx, sy = H / 128, W / 128
+ inds_x = (inds_x * sx + torch.rand(B, N, device=device) * sx).long().clamp(max=H - 1)
+ inds_y = (inds_y * sy + torch.rand(B, N, device=device) * sy).long().clamp(max=W - 1)
+ inds = inds_x * W + inds_y
+
+ results['inds_coarse'] = inds_coarse # need this when updating error_map
+
+ i = torch.gather(i, -1, inds)
+ j = torch.gather(j, -1, inds)
+
+ results['inds'] = inds
+
+ else:
+ inds = torch.arange(H*W, device=device).expand([B, H*W])
+
+ zs = - torch.ones_like(i)
+ xs = - (i - cx) / fx * zs
+ ys = (j - cy) / fy * zs
+ directions = torch.stack((xs, ys, zs), dim=-1)
+ # directions = safe_normalize(directions)
+ rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
+
+ rays_o = poses[..., :3, 3] # [B, 3]
+ rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
+
+ results['rays_o'] = rays_o
+ results['rays_d'] = rays_d
+
+ return results
+
+
+def seed_everything(seed):
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ #torch.backends.cudnn.deterministic = True
+ #torch.backends.cudnn.benchmark = True
+
+
+def save_tensor2image(x: torch.Tensor, path, channel_last=False, quality=75, **kwargs):
+ # assume the input x is channel last
+ if x.ndim == 4 and channel_last:
+ x = x.permute(0, 3, 1, 2)
+ TF.to_pil_image(make_grid(x, value_range=(0, 1), **kwargs)).save(path, quality=quality)
+
+@torch.jit.script
+def linear_to_srgb(x):
+ return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
+
+
+@torch.jit.script
+def srgb_to_linear(x):
+ return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
+
+
+def nonzero_normalize_depth(depth, mask=None):
+ if mask is not None:
+ if (depth[mask]>0).sum() > 0:
+ nonzero_depth_min = depth[mask][depth[mask]>0].min()
+ else:
+ nonzero_depth_min = 0
+ else:
+ if (depth>0).sum() > 0:
+ nonzero_depth_min = depth[depth>0].min()
+ else:
+ nonzero_depth_min = 0
+ if nonzero_depth_min == 0:
+ return depth
+ else:
+ depth = (depth - nonzero_depth_min) / depth.max()
+ return depth.clamp(0, 1)
+
+
+class Trainer(object):
+ def __init__(self,
+ argv, # command line args
+ name, # name of this experiment
+ opt, # extra conf
+ model, # network
+ guidance, # guidance network
+ criterion=None, # loss function, if None, assume inline implementation in train_step
+ optimizer=None, # optimizer
+ ema_decay=None, # if use EMA, set the decay
+ lr_scheduler=None, # scheduler
+ metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
+ local_rank=0, # which GPU am I
+ world_size=1, # total num of GPUs
+ device=None, # device to use, usually setting to None is OK. (auto choose device)
+ mute=False, # whether to mute all print
+ fp16=False, # amp optimize level
+ max_keep_ckpt=1, # max num of saved ckpts in disk
+ workspace='workspace', # workspace to save logs & ckpts
+ best_mode='min', # the smaller/larger result, the better
+ use_loss_as_metric=True, # use loss as the first metric
+ report_metric_at_train=False, # also report metrics at training
+ use_checkpoint="latest", # which ckpt to use at init time
+ use_tensorboard=True, # whether to use tensorboard for logging
+ scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
+ ):
+
+ self.argv = argv
+ self.name = name
+ self.opt = opt
+ self.mute = mute
+ self.metrics = metrics
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.workspace = workspace
+ self.ema_decay = ema_decay
+ self.fp16 = fp16
+ self.best_mode = best_mode
+ self.use_loss_as_metric = use_loss_as_metric
+ self.report_metric_at_train = report_metric_at_train
+ self.max_keep_ckpt = max_keep_ckpt
+ self.use_checkpoint = use_checkpoint
+ self.use_tensorboard = use_tensorboard
+ self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
+ self.scheduler_update_every_step = scheduler_update_every_step
+ self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
+ self.console = Console()
+
+ model.to(self.device)
+ if self.world_size > 1:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
+ self.model = model
+
+ # guide model
+ self.guidance = guidance
+ self.embeddings = {}
+
+ # text prompt / images
+ if self.guidance is not None:
+ for key in self.guidance:
+ for p in self.guidance[key].parameters():
+ p.requires_grad = False
+ self.embeddings[key] = {}
+ self.prepare_embeddings()
+
+ if isinstance(criterion, nn.Module):
+ criterion.to(self.device)
+ self.criterion = criterion
+
+ if optimizer is None:
+ self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
+ else:
+ self.optimizer = optimizer(self.model)
+
+ if lr_scheduler is None:
+ self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
+ else:
+ self.lr_scheduler = lr_scheduler(self.optimizer)
+
+ if ema_decay:
+ self.ema = ExponentialMovingAverage(
+ self.model.parameters(), decay=ema_decay)
+ else:
+ self.ema = None
+
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
+
+ # variable init
+ self.total_train_t = 0
+ self.epoch = 0
+ self.global_step = 0
+ self.local_step = 0
+ self.novel_view_step = 0
+ self.stats = {
+ "loss": [],
+ "valid_loss": [],
+ "results": [], # metrics[0], or valid_loss
+ "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
+ "best_result": None,
+ }
+ self.loss_meter = AverageMeters()
+ # auto fix
+ if len(metrics) == 0 or self.use_loss_as_metric:
+ self.best_mode = 'min'
+
+ logger.info(f'[INFO] cmdline: {self.argv}')
+ logger.info(f'args:\n{self.opt}')
+ logger.info(
+ f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
+ logger.info(
+ f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
+ logger.info(f'[INFO] #Optimizer: \n{self.optimizer}')
+ logger.info(f'[INFO] #Scheduler: \n{self.lr_scheduler}')
+
+ if self.workspace is not None:
+ if self.use_checkpoint == "scratch":
+ logger.info("[INFO] Training from scratch ...")
+ elif self.use_checkpoint == "latest":
+ logger.info("[INFO] Loading latest checkpoint ...")
+ self.load_checkpoint()
+ elif self.use_checkpoint == "latest_model":
+ logger.info("[INFO] Loading latest checkpoint (model only)...")
+ self.load_checkpoint(model_only=True)
+ elif self.use_checkpoint == "best":
+ if os.path.exists(self.opt.best_path):
+ logger.info("[INFO] Loading best checkpoint ...")
+ self.load_checkpoint(self.opt.best_path)
+ else:
+ logger.info(
+ f"[INFO] {self.opt.best_path} not found, loading latest ...")
+ self.load_checkpoint()
+ else: # path to ckpt
+ logger.info(f"[INFO] Loading {self.use_checkpoint} ...")
+ self.load_checkpoint(self.use_checkpoint)
+
+ # calculate the text embs.
+ @torch.no_grad()
+ def prepare_embeddings(self):
+
+ # text embeddings (stable-diffusion)
+ if self.opt.text is not None:
+
+ dir_texts = ['front', 'side', 'back']
+ if 'SD' in self.guidance:
+ self.embeddings['SD']['default'] = self.guidance['SD'].get_all_text_embeds([self.opt.text])
+ neg_embedding = self.guidance['SD'].get_all_text_embeds([self.opt.negative])
+
+ for idx, d in enumerate(dir_texts):
+ text = f"{self.opt.text}, {d} view"
+ self.embeddings['SD'][d] = self.guidance['SD'].get_all_text_embeds([text])
+ if self.opt.dir_texts_neg:
+ text_neg = self.opt.negative + ', '.join([text+' view' for i, text in enumerate(dir_texts) if i != idx])
+ logger.info(f'dir_texts of {d}\n postive text: {text},\n negative text: {text_neg}')
+ neg_embedding= self.guidance['SD'].get_all_text_embeds([text_neg])
+ self.embeddings['SD'][d] = torch.cat((neg_embedding, self.embeddings['SD'][d]), dim=0)
+
+ if 'IF' in self.guidance:
+ self.embeddings['IF']['default'] = self.guidance['IF'].get_text_embeds([self.opt.text])
+ neg_embedding = self.guidance['IF'].get_text_embeds([self.opt.negative])
+
+ for idx, d in enumerate(dir_texts):
+ text = f"{self.opt.text}, {d} view"
+ self.embeddings['IF'][d] = self.guidance['IF'].get_text_embeds([text])
+ if self.opt.dir_texts_neg:
+ text_neg = self.opt.negative + ', '.join([text+' view' for i, text in enumerate(dir_texts) if i != idx])
+ logger.info(f'dir_texts of {d}\n postive text: {text},\n negative text: {text_neg}')
+ neg_embedding= self.guidance['IF'].get_all_text_embeds([text_neg])
+ self.embeddings['IF'][d] = torch.cat((neg_embedding, self.embeddings['IF'][d]), dim=0)
+
+ # if 'clip' in self.guidance:
+ # self.embeddings['clip']['text'] = self.guidance['clip'].get_text_embeds(self.opt.text)
+
+ if self.opt.images is not None:
+
+ h = int(self.opt.known_view_scale * self.opt.h)
+ w = int(self.opt.known_view_scale * self.opt.w)
+
+ # load processed image and remove edges
+ rgbas = []
+ rgbas_hw = []
+ mask_no_edges = []
+ for image in self.opt.images:
+ rgba = cv2.cvtColor(cv2.imread(image, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
+ rgbas.append(rgba)
+ rgba_hw = cv2.resize(rgba, (w, h), interpolation=cv2.INTER_AREA).astype(np.float32) / 255
+ rgbas_hw.append(rgba_hw)
+ if self.opt.rm_edge:
+ alpha = np.uint8(rgba_hw[..., 3] * 255.)
+ dilate = cv2.dilate(alpha, np.ones((self.opt.edge_width, self.opt.edge_width), np.uint8))
+ edge = cv2.absdiff(alpha, dilate).astype(np.float32) / 255
+ mask_no_edge = rgba_hw[..., 3] > 0.5
+ mask_no_edge[edge>self.opt.edge_threshold] = False
+ mask_no_edges.append(mask_no_edge)
+ rgba_hw = np.stack(rgbas_hw)
+ mask = rgba_hw[..., 3] > 0.5
+ if len(mask_no_edges) > 0:
+ mask_no_edge = np.stack(mask_no_edges)
+ else:
+ mask_no_edge = mask
+
+ # breakpoint()
+ # rgb
+ rgb_hw = rgba_hw[..., :3] * rgba_hw[..., 3:] + (1 - rgba_hw[..., 3:])
+ self.rgb = torch.from_numpy(rgb_hw).permute(0,3,1,2).contiguous().to(self.device)
+ self.mask = torch.from_numpy(mask).to(self.device)
+ self.opacity = torch.from_numpy(mask_no_edge).to(self.device).to(torch.float32).unsqueeze(0)
+ print(f'[INFO] dataset: load image prompt {self.opt.images} {self.rgb.shape}')
+
+ # load depth
+ depth_paths = [image.replace('rgba', 'depth') for image in self.opt.images]
+ if os.path.exists(depth_paths[0]):
+ depths = [cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) for depth_path in depth_paths]
+ depth = np.stack([cv2.resize(depth, (w, h), interpolation=cv2.INTER_AREA) for depth in depths])
+ self.depth = 1 - torch.from_numpy(depth.astype(np.float32) / 255).to(self.device)
+ if len(self.depth.shape) == 4 and self.depth.shape[-1] > 1:
+ self.depth = self.depth[..., 0]
+ logger.info(f'[WARN] dataset: {depth_paths[0]} has more than one channel, only use the first channel')
+ if self.opt.normalize_depth:
+ self.depth = nonzero_normalize_depth(self.depth, self.mask)
+ save_tensor2image(self.depth, os.path.join(self.workspace, 'depth_resized.jpg'))
+ self.depth = self.depth[self.mask]
+ print(f'[INFO] dataset: load depth prompt {depth_paths} {self.depth.shape}')
+ else:
+ self.depth = None
+ logger.info(f'[WARN] dataset: {depth_paths[0]} is not found')
+
+ # load normal
+ normal_paths = [image.replace('rgba', 'normal') for image in self.opt.images]
+ if os.path.exists(normal_paths[0]):
+ normals = []
+ for normal_path in normal_paths:
+ normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
+ if normal.shape[-1] == 4:
+ normal = cv2.cvtColor(normal, cv2.COLOR_BGRA2RGB)
+ normals.append(normal)
+ normal = np.stack([cv2.resize(normal, (w, h), interpolation=cv2.INTER_AREA) for normal in normals])
+ self.normal = torch.from_numpy(normal.astype(np.float32) / 255).to(self.device)
+ save_tensor2image(self.normal, os.path.join(self.workspace, 'normal_resized.jpg'), channel_last=True)
+ print(f'[INFO] dataset: load normal prompt {normal_paths} {self.normal.shape}')
+ self.normal = self.normal[self.mask]
+ else:
+ self.normal = None
+ logger.info(f'[WARN] dataset: {normal_paths[0]} is not found')
+
+ # save for debug
+ save_tensor2image(self.rgb, os.path.join(self.workspace, 'rgb_resized.png'), channel_last=False)
+ save_tensor2image(self.opacity, os.path.join(self.workspace, 'opacity_resized.png'), channel_last=False)
+
+ # encode embeddings for zero123
+ if 'zero123' in self.guidance:
+ rgba_256 = np.stack([cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(np.float32) / 255 for rgba in rgbas])
+ rgbs_256 = rgba_256[..., :3] * rgba_256[..., 3:] + (1 - rgba_256[..., 3:])
+ rgb_256 = torch.from_numpy(rgbs_256).permute(0,3,1,2).contiguous().to(self.device)
+ guidance_embeds = self.guidance['zero123'].get_img_embeds(rgb_256)
+ self.embeddings['zero123']['default'] = {
+ 'zero123_ws' : self.opt.zero123_ws,
+ 'c_crossattn' : guidance_embeds[0],
+ 'c_concat' : guidance_embeds[1],
+ 'ref_polars' : self.opt.ref_polars,
+ 'ref_azimuths' : self.opt.ref_azimuths,
+ 'ref_radii' : self.opt.ref_radii,
+ }
+
+ # if 'clip' in self.guidance:
+ # self.embeddings['clip']['image'] = self.guidance['clip'].get_img_embeds(self.rgb)
+ # encoder image for clip
+ if self.opt.use_clip:
+ self.rgb_clip_embed = self.guidance.get_clip_img_embeds(self.rgb)
+ # debug.
+ scaler = torch.cuda.amp.GradScaler()
+ image = torch.randn((1,3,512,512), device=self.device, requires_grad=True)
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
+ loss = self.guidance.clip_loss(self.rgb_clip_embed, image)
+ scaler.scale(loss).backward()
+ else:
+ self.rgb_clip_embed = None
+
+
+ # ------------------------------
+ @torch.no_grad()
+ def match_known(self, **kwargs):
+ self.model.eval()
+ data = self.default_view_data
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ mvp = data['mvp'] # [B, 4, 4]
+
+ B, N = rays_o.shape[:2]
+ H, W = data['H'], data['W']
+
+ ambient_ratio = 1.0
+ shading = self.opt.known_shading
+ binarize = False
+ bg_color = self.get_bg_color(
+ self.opt.bg_color_known, B*N, rays_o.device)
+
+ # add camera noise to avoid grid-like artifect
+ # * (1 - self.global_step / self.opt.iters)
+ noise_scale = self.opt.known_view_noise_scale
+ rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
+ rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale
+
+ outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True,
+ bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
+ pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(
+ 0, 3, 1, 2).contiguous() # [1, 3, H, W]
+ pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)
+
+ rgb_loss = self.opt.lambda_rgb * \
+ F.mse_loss(pred_rgb*self.opacity,
+ self.rgb*self.opacity)
+ mask_loss = self.opt.lambda_mask * \
+ F.mse_loss(pred_mask, self.mask.to(torch.float32).unsqueeze(0))
+ return pred_rgb, pred_mask, rgb_loss, mask_loss
+
+ def get_bg_color(self, bg_type, N, device):
+ if bg_type is None:
+ return None
+ elif isinstance(bg_type, str):
+ if bg_type == 'pixelnoise':
+ bg_color = torch.rand((N, 3), device=device)
+ elif bg_type == 'noise':
+ bg_color = torch.rand((1, 3), device=device).repeat(N, 1)
+ elif bg_type == 'white':
+ bg_color = torch.ones((N, 3), device=device)
+ return bg_color
+ elif isinstance(bg_type, Tensor):
+ bg_color = bg_color.to(device)
+ return bg_color
+ else:
+ raise NotImplementedError(f"{bg_type} is not implemented")
+
+ # def margin_rank_loss(self, depth):
+ # # high res, only calc on fg
+ # output = depth.squeeze().view(-1)
+ # output = output[self.fg_idx]
+ # num = output.shape[0] # [n, 1]
+ # # print(num)
+ # output = output.reshape(1, -1)
+ # o1 = output.expand(num, -1).reshape(-1)
+ # o2 = output.T.expand(-1, num).reshape(-1)
+ # return F.margin_ranking_loss(o1, o2, self.rank_loss_target)
+
+ def train_step(self, data):
+ # perform RGBD loss instead of SDS if is image-conditioned
+ do_rgbd_loss = self.opt.images is not None and \
+ (self.global_step < self.opt.known_iters) or (self.global_step % self.opt.known_view_interval == 0)
+
+ # override random camera with fixed known camera
+ if do_rgbd_loss:
+ data = self.default_view_data
+
+ # progressively relaxing view range
+ if self.opt.progressive_view:
+ r = min(1.0, 0.2 + self.global_step / (0.5 * self.opt.iters))
+ self.opt.phi_range = [self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[0] * r,
+ self.opt.default_azimuth * (1 - r) + self.opt.full_phi_range[1] * r]
+ self.opt.theta_range = [self.opt.default_polar * (1 - r) + self.opt.full_theta_range[0] * r,
+ self.opt.default_polar * (1 - r) + self.opt.full_theta_range[1] * r]
+ self.opt.radius_range = [self.opt.default_radius * (1 - r) + self.opt.full_radius_range[0] * r,
+ self.opt.default_radius * (1 - r) + self.opt.full_radius_range[1] * r]
+ self.opt.fovy_range = [self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[0] * r,
+ self.opt.default_fovy * (1 - r) + self.opt.full_fovy_range[1] * r]
+
+ # progressively increase max_level
+ if self.opt.progressive_level:
+ self.model.max_level = min(1.0, 0.25 + self.global_step / (0.5 * self.opt.iters))
+
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ mvp = data['mvp'] # [B, 4, 4]
+
+ B, N = rays_o.shape[:2]
+ H, W = data['H'], data['W']
+
+ # When ref_data has B images > opt.batch_size
+ if B > self.opt.batch_size:
+ # choose batch_size images out of those B images
+ choice = torch.randperm(B)[:self.opt.batch_size]
+ B = self.opt.batch_size
+ rays_o = rays_o[choice]
+ rays_d = rays_d[choice]
+ mvp = mvp[choice]
+
+ if do_rgbd_loss:
+ ambient_ratio = 1.0
+ shading = 'lambertian' # use lambertian instead of albedo to get normal
+ as_latent = False
+ binarize = False
+ bg_color = self.get_bg_color(
+ self.opt.bg_color_known, B*N, rays_o.device)
+
+ # add camera noise to avoid grid-like artifact
+ if self.opt.known_view_noise_scale > 0:
+ noise_scale = self.opt.known_view_noise_scale #* (1 - self.global_step / self.opt.iters)
+ rays_o = rays_o + torch.randn(3, device=self.device) * noise_scale
+ rays_d = rays_d + torch.randn(3, device=self.device) * noise_scale
+
+ elif self.global_step < (self.opt.latent_iter_ratio * self.opt.iters):
+ ambient_ratio = 1.0
+ shading = 'normal'
+ as_latent = True
+ binarize = False
+ bg_color = None
+
+ else:
+ if self.global_step < (self.opt.normal_iter_ratio * self.opt.iters):
+ ambient_ratio = 1.0
+ shading = 'normal'
+ elif self.global_step < (self.opt.textureless_iter_ratio * self.opt.iters):
+ ambient_ratio = 0.1 + 0.9 * random.random()
+ shading = 'textureless'
+ elif self.global_step < (self.opt.albedo_iter_ratio * self.opt.iters):
+ ambient_ratio = 1.0
+ shading = 'albedo'
+ else:
+ # random shading
+ ambient_ratio = 0.1 + 0.9 * random.random()
+ rand = random.random()
+ if rand > 0.8:
+ shading = 'textureless'
+ else:
+ shading = 'lambertian'
+
+ as_latent = False
+
+ # random weights binarization (like mobile-nerf) [NOT WORKING NOW]
+ # binarize_thresh = min(0.5, -0.5 + self.global_step / self.opt.iters)
+ # binarize = random.random() < binarize_thresh
+ binarize = False
+
+ # random background
+ rand = random.random()
+ if self.opt.bg_radius > 0 and rand > 0.5:
+ bg_color = None # use bg_net
+ else:
+ bg_color = torch.rand(3).to(self.device) # single color random bg
+
+ outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, binarize=binarize)
+ pred_depth = outputs['depth'].reshape(B, 1, H, W)
+ if self.opt.normalize_depth:
+ pred_depth = nonzero_normalize_depth(pred_depth)
+ pred_mask = outputs['weights_sum'].reshape(B, 1, H, W)
+ if 'normal_image' in outputs:
+ pred_normal = outputs['normal_image'].reshape(B, H, W, 3)
+ else:
+ pred_normal = None
+
+ if as_latent:
+ # abuse normal & mask as latent code for faster geometry initialization (ref: fantasia3D)
+ pred_rgb = torch.cat([outputs['image'], outputs['weights_sum'].unsqueeze(-1)], dim=-1).reshape(B, H, W, 4).permute(0, 3, 1, 2).contiguous() # [B, 4, H, W]
+ else:
+ pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]
+ out_dict = {
+ 'rgb': pred_rgb,
+ 'depth': pred_depth,
+ 'mask': pred_mask,
+ 'normal': pred_normal,
+ }
+
+ # Loss
+ # known view loss
+ loss_rgb, loss_mask, loss_normal, loss_depth, loss_sds, loss_if, loss_zero123, loss_clip, loss_entropy, loss_opacity, loss_orient, loss_smooth, loss_smooth2d, loss_smooth3d, loss_mesh_normal, loss_mesh_lap = torch.zeros(16, device=self.device)
+ # known view loss
+ if do_rgbd_loss:
+ gt_mask = self.mask # [B, H, W]
+ gt_rgb = self.rgb # [B, 3, H, W]
+ gt_opacity = self.opacity # [B, 1, H, W]
+ gt_normal = self.normal # [B, H, W, 3]
+ gt_depth = self.depth # [B, H, W]
+
+ if len(gt_rgb) > self.opt.batch_size:
+ gt_mask = gt_mask[choice]
+ gt_rgb = gt_rgb[choice]
+ gt_opacity = gt_opacity[choice]
+ gt_normal = gt_normal[choice]
+ gt_depth = gt_depth[choice]
+
+ # color loss
+ loss_rgb = self.opt.lambda_rgb * \
+ F.mse_loss(pred_rgb*gt_opacity, gt_rgb*gt_opacity)
+
+ # mask loss
+ loss_mask = self.opt.lambda_mask * F.mse_loss(pred_mask, gt_mask.to(torch.float32).unsqueeze(0))
+
+ # normal loss
+ if self.opt.lambda_normal > 0 and 'normal_image' in outputs and self.normal is not None:
+ pred_normal = pred_normal[self.mask]
+ lambda_normal = self.opt.lambda_normal * \
+ min(1, self.global_step / self.opt.iters)
+ loss_normal = lambda_normal * \
+ (1 - F.cosine_similarity(pred_normal, self.normal).mean())/2
+
+ # relative depth loss
+ if self.opt.lambda_depth > 0 and self.depth is not None:
+ valid_pred_depth = pred_depth[:, 0][self.mask]
+ loss_depth = self.opt.lambda_depth * (1 - pearson_corrcoef(valid_pred_depth, self.depth))/2
+
+ loss = loss_rgb + loss_mask + loss_normal + loss_depth
+ # novel view loss
+ else:
+ save_guidance_path = os.path.join(self.opt.workspace, 'guidance', f'train_step{self.global_step}_guidance.jpg') if self.opt.save_guidance_every > 0 and self.novel_view_step % self.opt.save_guidance_every ==0 else None
+ if 'SD' in self.guidance:
+ # interpolate text_z
+ azimuth = data['azimuth'] # [-180, 180]
+
+ # ENHANCE: remove loop to handle batch size > 1
+ text_z = []
+ for b in range(azimuth.shape[0]):
+ if azimuth[b] >= -90 and azimuth[b] < 90:
+ if azimuth[b] >= 0:
+ r = 1 - azimuth[b] / 90
+ else:
+ r = 1 + azimuth[b] / 90
+ start_z = self.embeddings['SD']['front']
+ end_z = self.embeddings['SD']['side']
+ else:
+ if azimuth[b] >= 0:
+ r = 1 - (azimuth[b] - 90) / 90
+ else:
+ r = 1 + (azimuth[b] + 90) / 90
+ start_z = self.embeddings['SD']['side']
+ end_z = self.embeddings['SD']['back']
+ text_z.append(r * start_z + (1 - r) * end_z)
+ text_z = torch.stack(text_z, dim=0).transpose(0, 1).flatten(0, 1)
+ text_z_sds = text_z[:, :-1]
+ loss_sds, _ = self.guidance['SD'].train_step(text_z_sds, pred_rgb, as_latent=as_latent, guidance_scale=self.opt.guidance_scale['SD'], grad_scale=self.opt.lambda_guidance['SD'],
+ density=pred_mask if self.opt.gudiance_spatial_weighting else None,
+ save_guidance_path=save_guidance_path
+ )
+ # if self.opt.lambda_clip > 0:
+ # lambda_clip = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_clip
+ # if self.opt.clip_image_loss:
+ # loss_clip = lambda_clip * self.guidance.clip_loss(self.rgb_clip_embed, pred_rgb)
+ # else:
+ # loss_clip = lambda_clip * self.guidance.clip_loss(text_z_clip, pred_rgb)
+
+ if 'IF' in self.guidance:
+ # interpolate text_z
+ azimuth = data['azimuth'] # [-180, 180]
+
+ # ENHANCE: remove loop to handle batch size > 1
+ # ENHANCE: remove loop to handle batch size > 1
+ text_z = []
+ for b in range(azimuth.shape[0]):
+ if azimuth[b] >= -90 and azimuth[b] < 90:
+ if azimuth[b] >= 0:
+ r = 1 - azimuth[b] / 90
+ else:
+ r = 1 + azimuth[b] / 90
+ start_z = self.embeddings['IF']['front']
+ end_z = self.embeddings['IF']['side']
+ else:
+ if azimuth[b] >= 0:
+ r = 1 - (azimuth[b] - 90) / 90
+ else:
+ r = 1 + (azimuth[b] + 90) / 90
+ start_z = self.embeddings['IF']['side']
+ end_z = self.embeddings['IF']['back']
+ text_z.append(r * start_z + (1 - r) * end_z)
+ text_z = torch.stack(text_z, dim=0).transpose(0, 1).flatten(0, 1)
+ text_z = torch.cat(text_z, dim=1).reshape(B, 2, start_z.shape[-2]-1, start_z.shape[-1]).transpose(0, 1).flatten(0, 1)
+ loss_if = self.guidance['IF'].train_step(text_z, pred_rgb, guidance_scale=self.opt.guidance_scale['IF'], grad_scale=self.opt.lambda_guidance['IF'])
+
+ if 'zero123' in self.guidance:
+
+ polar = data['polar']
+ azimuth = data['azimuth']
+ radius = data['radius']
+
+ loss_zero123 = self.guidance['zero123'].train_step(self.embeddings['zero123']['default'], pred_rgb, polar, azimuth, radius, guidance_scale=self.opt.guidance_scale['zero123'],
+ as_latent=as_latent, grad_scale=self.opt.lambda_guidance['zero123'], save_guidance_path=save_guidance_path)
+
+ if 'clip' in self.guidance:
+
+ # empirical, far view should apply smaller CLIP loss
+ lambda_guidance = 10 * (1 - abs(azimuth) / 180) * self.opt.lambda_guidance['clip']
+ loss_clip = self.guidance['clip'].train_step(self.embeddings['clip'], pred_rgb, grad_scale=lambda_guidance)
+ loss = loss_sds + loss_if + loss_zero123 + loss_clip
+
+ # regularizations
+ if not self.opt.dmtet:
+
+ if self.opt.lambda_opacity > 0:
+ loss_opacity = self.opt.lambda_opacity * (outputs['weights_sum'] ** 2).mean()
+
+ if self.opt.lambda_entropy > 0:
+ lambda_entropy = self.opt.lambda_entropy * \
+ min(1, 2 * self.global_step / self.opt.iters)
+ alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5)
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
+ loss_entropy = lambda_entropy * (- alphas * torch.log2(alphas) -
+ (1 - alphas) * torch.log2(1 - alphas)).mean()
+
+ if self.opt.lambda_normal_smooth > 0 and 'normal_image' in outputs:
+ pred_vals = outputs['normal_image'].reshape(B, H, W, 3)
+ # total-variation
+ loss_smooth = (pred_vals[:, 1:, :, :] - pred_vals[:, :-1, :, :]).square().mean() + \
+ (pred_vals[:, :, 1:, :] -
+ pred_vals[:, :, :-1, :]).square().mean()
+ loss_smooth = self.opt.lambda_normal_smooth * loss_smooth
+
+ if self.opt.lambda_normal_smooth2d > 0 and 'normal_image' in outputs:
+ pred_vals = outputs['normal_image'].reshape(
+ B, H, W, 3).permute(0, 3, 1, 2).contiguous()
+ smoothed_vals = TF.gaussian_blur(pred_vals, kernel_size=9)
+ loss_smooth2d = self.opt.lambda_normal_smooth2d * F.mse_loss(pred_vals, smoothed_vals)
+
+ if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
+ loss_orient = self.opt.lambda_orient * outputs['loss_orient']
+
+ if self.opt.lambda_3d_normal_smooth > 0 and 'loss_normal_perturb' in outputs:
+ loss_smooth3d = self.opt.lambda_3d_normal_smooth * outputs['loss_normal_perturb']
+
+ loss += loss_opacity + loss_entropy + loss_smooth + loss_smooth2d + loss_orient + loss_smooth3d
+
+ else:
+ if self.opt.lambda_mesh_normal > 0:
+ loss_mesh_normal = self.opt.lambda_mesh_normal * \
+ outputs['loss_normal']
+
+ if self.opt.lambda_mesh_lap > 0:
+ loss_mesh_lap = self.opt.lambda_mesh_lap * outputs['loss_lap']
+ loss += loss_mesh_normal + loss_mesh_lap
+
+ losses_dict = {
+ 'loss': loss.item(),
+ 'loss_sds': loss_sds.item(),
+ 'loss_if': loss_if.item(),
+ 'loss_zero123': loss_zero123.item(),
+ 'loss_clip': loss_clip.item(),
+ 'loss_rgb': loss_rgb.item(),
+ 'loss_mask': loss_mask.item(),
+ 'loss_normal': loss_normal.item(),
+ 'loss_depth': loss_depth.item(),
+ 'loss_opacity': loss_opacity.item(),
+ 'loss_entropy': loss_entropy.item(),
+ 'loss_smooth': loss_smooth.item(),
+ 'loss_smooth2d': loss_smooth2d.item(),
+ 'loss_smooth3d': loss_smooth3d.item(),
+ 'loss_orient': loss_orient.item(),
+ 'loss_mesh_normal': loss_mesh_normal.item(),
+ 'loss_mesh_lap': loss_mesh_lap.item(),
+ }
+ # if loss_guidance_dict:
+ # for key, val in loss_guidance_dict.items():
+ # losses_dict[key] = val.item() if isinstance(val, torch.Tensor) else val
+
+ if 'normal' in out_dict:
+ out_dict['normal'] = out_dict['normal'].permute(0, 3, 1, 2).contiguous()
+
+ # save for debug purpose
+ if self.opt.save_train_every > 0 and self.global_step % self.opt.save_train_every == 0:
+ image_save_path = os.path.join(self.workspace, 'train_debug',)
+ os.makedirs(image_save_path, exist_ok=True)
+ for key, value in out_dict.items():
+ if value is not None:
+ value = ((value - value.min()) / (value.max() - value.min() + 1e-6)).detach().mul(255).to(torch.uint8)
+ try:
+ save_tensor2image(value, os.path.join(image_save_path, f'train_{self.global_step:06d}_{key}.jpg'), channel_last=False)
+ except:
+ pass
+ return loss, losses_dict, out_dict
+
+ def post_train_step(self):
+
+ # unscale grad before modifying it!
+ # ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
+ self.scaler.unscale_(self.optimizer)
+
+ # clip grad
+ if self.opt.grad_clip >= 0:
+ torch.nn.utils.clip_grad_value_(self.model.parameters(), self.opt.grad_clip)
+
+ if not self.opt.dmtet and self.opt.backbone == 'grid':
+
+ if self.opt.lambda_tv > 0:
+ lambda_tv = min(1.0, self.global_step / (0.5 * self.opt.iters)) * self.opt.lambda_tv
+ self.model.encoder.grad_total_variation(lambda_tv, None, self.model.bound)
+ if self.opt.lambda_wd > 0:
+ self.model.encoder.grad_weight_decay(self.opt.lambda_wd)
+
+ def eval_step(self, data):
+
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ mvp = data['mvp']
+
+ B, N = rays_o.shape[:2]
+ H, W = data['H'], data['W']
+
+ shading = data['shading'] if 'shading' in data else 'lambertian'
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
+ light_d = data['light_d'] if 'light_d' in data else None
+
+ outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading)
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
+ pred_depth = outputs['depth'].reshape(B, H, W, 1)
+ if self.opt.normalize_depth:
+ pred_depth = nonzero_normalize_depth(pred_depth)
+ if 'normal_image' in outputs:
+ pred_normal = outputs['normal_image'].reshape(B, H, W, 3)
+ else:
+ pred_normal = None
+ out_dict = {
+ shading: pred_rgb,
+ 'depth': pred_depth,
+ 'normal_image': pred_normal,
+ }
+ # dummy
+ loss = torch.zeros([1], device=pred_rgb.device, dtype=pred_rgb.dtype)
+ return out_dict, loss
+
+ def test_step(self, data, bg_color=None, perturb=False, shading='lambertian'):
+ rays_o = data['rays_o'] # [B, N, 3]
+ rays_d = data['rays_d'] # [B, N, 3]
+ mvp = data['mvp']
+
+ B, N = rays_o.shape[:2]
+ H, W = data['H'], data['W']
+
+ bg_color = self.get_bg_color(bg_color, B*N, rays_o.device)
+
+ shading = data['shading'] if 'shading' in data else shading
+ ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
+ light_d = data['light_d'] if 'light_d' in data else None
+
+ outputs = self.model.render(rays_o, rays_d, mvp, H, W, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, bg_color=bg_color)
+
+ pred_rgb = outputs['image'].reshape(B, H, W, 3)
+ pred_depth = outputs['depth'].reshape(B, H, W, 1)
+ pred_mask = outputs['weights_sum'].reshape(B, H, W, 1)
+ # if self.opt.normalize_depth:
+ pred_depth = nonzero_normalize_depth(pred_depth)
+ if 'normal_image' in outputs:
+ pred_normal = outputs['normal_image'].reshape(B, H, W, 3)
+ pred_normal = pred_normal * pred_mask + (1.0 - pred_mask)
+ else:
+ pred_normal = None
+ out_dict = {
+ shading: pred_rgb,
+ 'depth': pred_depth,
+ 'normal_image': pred_normal,
+ 'mask': pred_mask,
+ }
+ return out_dict
+
+ def save_mesh(self, loader=None, save_path=None):
+
+ if save_path is None:
+ save_path = os.path.join(self.workspace, 'mesh')
+
+ logger.info(f"==> Saving mesh to {save_path}")
+
+ os.makedirs(save_path, exist_ok=True)
+
+ self.model.export_mesh(save_path, resolution=self.opt.mcubes_resolution, decimate_target=self.opt.decimate_target)
+
+ logger.info(f"==> Finished saving mesh.")
+
+ ### ------------------------------
+
+ def train(self, train_loader, valid_loader, test_loader, max_epochs):
+
+ if self.use_tensorboard and self.local_rank == 0:
+ self.writer = SummaryWriter(
+ os.path.join(self.workspace, "run", self.name))
+
+ # init from nerf should be performed after Shap-E, since Shap-E will rescale dmtet
+ if self.opt.dmtet and (self.opt.init_ckpt and os.path.exists(self.opt.init_ckpt)):
+ reset_scale = False if self.opt.use_shape else True
+ old_sdf = self.model.get_sdf_from_nerf(reset_scale)
+ if not self.opt.tet_mlp:
+ self.model.dmtet.init_tet_from_sdf(old_sdf)
+ self.test(valid_loader, name=f'init_ckpt', write_video=False, save_each_frame=False, subfolder='check_init')
+ else:
+ old_sdf = None
+
+ if self.opt.use_shape and self.opt.dmtet:
+ os.makedirs(os.path.join(self.opt.workspace, 'shape'), exist_ok=True)
+ best_loss = torch.inf
+ best_idx = 0
+ for idx, (sdf, color) in enumerate(zip(self.opt.rpsts, self.opt.colors)):
+ self.model.init_tet_from_sdf_color(sdf)
+ pred_rgb, pred_mask, rgb_loss, mask_loss = self.match_known()
+ best_loss = min(best_loss, mask_loss)
+ if best_loss == mask_loss:
+ best_idx = idx
+ logger.info(f"==> Current best match shape known sdf idx: {best_idx}")
+ save_tensor2image(pred_mask, os.path.join(self.opt.workspace, 'shape', f"match_shape_known_{idx}_rgb.jpg"), channel_last=False)
+ self.test(valid_loader, name=f'idx_{idx}', write_video=False, save_each_frame=False, subfolder='check_init')
+
+ sdf = self.opt.rpsts[best_idx]
+ self.model.init_tet_from_sdf_color(sdf, self.opt.colors[best_idx])
+ self.test(valid_loader, name=f'shape_only', write_video=False, save_each_frame=False, subfolder='check_init')
+
+ # Enable mixture model
+ if self.opt.base_mesh:
+ logger.info(f"==> Enable mixture model with base mesh {self.opt.base_mesh}")
+ mesh_sdf = self.model.dmtet.get_sdf_from_mesh(self.opt.base_mesh)
+ sdf = (mesh_sdf.clamp(0, 1) + sdf.clamp(0,1) ).clamp(0, 1)
+
+ if old_sdf is not None:
+ sdf = (sdf.clamp(0, 1) + old_sdf.clamp(0, 1)).clamp(0, 1)
+
+ self.model.init_tet_from_sdf_color(sdf, self.opt.colors[best_idx])
+ self.test(valid_loader, name=f'shape_merge', write_video=False, save_each_frame=False, subfolder='check_init')
+
+ del best_loss, best_idx, pred_rgb, pred_mask, rgb_loss, mask_loss
+ self.opt.rpsts = None
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # init shape for NeRF. NOTE: Does not work yet.. in progress.
+ # if self.opt.use_shape and not self.opt.dmtet:
+ # os.makedirs(os.path.join(self.opt.workspace, 'shape'), exist_ok=True)
+ # best_loss = torch.inf
+ # best_idx = 0
+ # for idx, (density, color) in enumerate(zip(self.opt.rpsts, self.opt.colors)):
+ # self.model.init_nerf_from_sdf_color(density, color, self.opt.points, lr=0.001)
+ # pred_rgb, pred_mask, rgb_loss, mask_loss = self.match_known()
+ # best_loss = min(best_loss, mask_loss)
+ # if best_loss == mask_loss:
+ # best_idx = idx
+ # logger.info(f"==> Current best match shape known sdf idx: {best_idx}")
+ # save_tensor2image(pred_mask, os.path.join(self.opt.workspace, 'shape', f"match_shape_known_{idx}_rgb.jpg"), channel_last=False)
+ # self.evaluate_one_epoch(valid_loader, f'idx_{idx}')
+ # self.model.init_nerf_from_sdf_color(self.opt.rpsts[best_idx], self.opt.colors[best_idx])
+ # self.evaluate_one_epoch(valid_loader, f'init_from_shape_{idx}')
+
+ # del best_loss, best_idx, pred_rgb, pred_mask, rgb_loss, mask_loss
+ # self.opt.rpsts = None
+ # self.opt.colors = None
+ # self.opt.points = None
+ # gc.collect()
+ # torch.cuda.empty_cache()
+
+ start_t = time.time()
+
+ for epoch in range(self.epoch + 1, max_epochs + 1):
+ self.epoch = epoch
+
+ self.train_one_epoch(train_loader, max_epochs)
+
+ if self.workspace is not None and self.local_rank == 0:
+ self.save_checkpoint(full=True, best=False)
+
+ if self.epoch % self.opt.eval_interval == 0:
+ self.evaluate_one_epoch(valid_loader)
+ self.save_checkpoint(full=False, best=True)
+
+ if self.epoch % self.opt.test_interval == 0 or self.epoch == max_epochs:
+ self.test(test_loader, img_folder='images' if self.epoch == max_epochs else f'images_ep{self.epoch:04d}')
+
+ end_t = time.time()
+
+ self.total_train_t = end_t - start_t + self.total_train_t
+
+ logger.info(f"[INFO] training takes {(self.total_train_t)/ 60:.4f} minutes.")
+
+ if self.use_tensorboard and self.local_rank == 0:
+ self.writer.close()
+
+ def evaluate(self, loader, name=None):
+ self.use_tensorboard, use_tensorboard = False, self.use_tensorboard
+ self.evaluate_one_epoch(loader, name)
+ self.use_tensorboard = use_tensorboard
+
+ def test(self, loader, save_path=None, name=None,
+ write_video=True, save_each_frame=True, shading='lambertian',
+ subfolder='results', img_folder='images'
+ ):
+
+ if save_path is None:
+ save_path = os.path.join(self.workspace, subfolder)
+ image_save_path = os.path.join(self.workspace, subfolder, img_folder)
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ os.makedirs(save_path, exist_ok=True)
+ os.makedirs(image_save_path, exist_ok=True)
+
+ logger.info(f"==> Start Test, saving {shading} results to {save_path}")
+
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+ self.model.eval()
+
+ all_outputs = {}
+ with torch.no_grad():
+ for i, data in enumerate(loader):
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ outputs = self.test_step(data, bg_color=self.opt.bg_color_test, shading=shading)
+ for key, value in outputs.items():
+ if value is not None:
+ value = ((value - value.min()) / (value.max() - value.min() + 1e-6)).detach().mul(255).to(torch.uint8)
+ if save_each_frame:
+ save_tensor2image(value, os.path.join(image_save_path, f'{name}_{i:04d}_{key}.jpg'), channel_last=True)
+ if key not in all_outputs.keys():
+ all_outputs[key] = []
+ all_outputs[key].append(value)
+ pbar.update(loader.batch_size)
+
+ for key, value in all_outputs.items():
+ all_outputs[key] = torch.cat(value, dim=0)
+
+ if write_video:
+ for key, value in all_outputs.items():
+ # current version torchvision does not support writing a single-channel video
+ # torchvision.io.write_video(os.path.join(save_path, f'{name}_{key}.mp4'), all_outputs[key].detach().cpu(), fps=25)
+ imageio.mimwrite(os.path.join(save_path, f'{name}_{key}.mp4'), all_outputs[key].detach().cpu().numpy(), fps=25, quality=8, macro_block_size=1)
+ for key, value in all_outputs.items():
+ save_tensor2image(value, os.path.join(save_path, f'{name}_{key}.jpg'), channel_last=True)
+ logger.info(f"==> Finished Test.")
+
+ # [GUI] train text step.
+ def train_gui(self, train_loader, step=16):
+
+ self.model.train()
+
+ total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)
+
+ loader = iter(train_loader)
+
+ for _ in range(step):
+
+ # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
+ try:
+ data = next(loader)
+ except StopIteration:
+ loader = iter(train_loader)
+ data = next(loader)
+
+ # update grid every 16 steps
+ if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ self.model.update_extra_state()
+
+ self.global_step += 1
+
+ self.optimizer.zero_grad()
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ loss, loss_dicts, outputs = self.train_step(data)
+
+ self.scaler.scale(loss).backward()
+ self.post_train_step()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ if self.scheduler_update_every_step:
+ self.lr_scheduler.step()
+
+ self.loss_meter.update(loss_dicts)
+
+ if self.ema is not None:
+ self.ema.update()
+
+ average_loss = self.loss_meter.meters['loss'].avg
+
+ if not self.scheduler_update_every_step:
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step(average_loss)
+ else:
+ self.lr_scheduler.step()
+
+ outputs = {
+ 'loss': average_loss,
+ 'lr': self.optimizer.param_groups[0]['lr'],
+ }
+
+ return outputs
+
+
+ # [GUI] test on a single image
+ def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'):
+
+ # render resolution (may need downscale to for better frame rate)
+ rH = int(H * downscale)
+ rW = int(W * downscale)
+ intrinsics = intrinsics * downscale
+
+ pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
+ mvp = torch.from_numpy(mvp).unsqueeze(0).to(self.device)
+
+ rays = get_rays(pose, intrinsics, rH, rW, -1)
+
+ # from degree theta/phi to 3D normalized vec
+ light_d = np.deg2rad(light_d)
+ light_d = np.array([
+ np.sin(light_d[0]) * np.sin(light_d[1]),
+ np.cos(light_d[0]),
+ np.sin(light_d[0]) * np.cos(light_d[1]),
+ ], dtype=np.float32)
+ light_d = torch.from_numpy(light_d).to(self.device)
+
+ data = {
+ 'rays_o': rays['rays_o'],
+ 'rays_d': rays['rays_d'],
+ 'mvp': mvp,
+ 'H': rH,
+ 'W': rW,
+ 'light_d': light_d,
+ 'ambient_ratio': ambient_ratio,
+ 'shading': shading,
+ }
+
+ self.model.eval()
+
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ # here spp is used as perturb random seed!
+ outputs = self.test_step(
+ data, bg_color=bg_color, perturb=False if spp == 1 else spp)
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ # interpolation to the original resolution
+ if downscale != 1:
+ # have to permute twice with torch...
+ outputs[shading] = F.interpolate(outputs[shading].permute(0, 3, 1, 2), size=(
+ H, W), mode='nearest').permute(0, 2, 3, 1).contiguous()
+ outputs['depth'] = F.interpolate(outputs['depth'].unsqueeze(
+ 1), size=(H, W), mode='nearest').squeeze(1)
+
+ if outputs['normal_imagea'] is not None:
+ outputs['normal_image'] = F.interpolate(outputs['normal_image'].unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
+
+ return outputs
+
+ def train_one_epoch(self, loader, max_epochs):
+ logger.info(f"==> [{time.strftime('%Y-%m-%d_%H-%M-%S')}] Start Training {self.workspace} Epoch {self.epoch}/{max_epochs}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")
+
+ if self.local_rank == 0 and self.report_metric_at_train:
+ for metric in self.metrics:
+ metric.clear()
+
+ self.model.train()
+
+ # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
+ # ref: https://pytorch.org/docs/stable/data.html
+ if self.world_size > 1:
+ loader.sampler.set_epoch(self.epoch)
+
+ self.local_step = 0
+
+ for data in loader:
+
+ # update grid every 16 steps
+ if (self.model.cuda_ray or self.model.taichi_ray) and self.global_step % self.opt.update_extra_interval == 0:
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ self.model.update_extra_state()
+
+ # Update grid
+ if self.opt.grid_levels_mask > 0:
+ if self.global_step > self.opt.grid_levels_mask_iters:
+ self.model.grid_levels_mask = 0
+ else:
+ self.model.grid_levels_mask = self.opt.grid_levels_mask
+
+ self.local_step += 1
+ self.global_step += 1
+
+ self.optimizer.zero_grad()
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ loss, losses_dict, outputs = self.train_step(data)
+
+ # hooked grad clipping for RGB space
+ if self.opt.grad_clip_rgb >= 0:
+ def _hook(grad):
+ if self.opt.fp16:
+ # correctly handle the scale
+ grad_scale = self.scaler._get_scale_async()
+ return grad.clamp(grad_scale * -self.opt.grad_clip_rgb, grad_scale * self.opt.grad_clip_rgb)
+ else:
+ return grad.clamp(-self.opt.grad_clip_rgb, self.opt.grad_clip_rgb)
+ outputs['rgb'].register_hook(_hook)
+ # if (self.global_step <= self.opt.known_iters or self.global_step % self.opt.known_view_interval == 0) and self.opt.image is not None and self.opt.joint_known_unknown and known_rgbs is not None:
+ # known_rgbs.register_hook(_hook)
+ # pred_rgbs.retain_grad()
+
+ self.scaler.scale(loss).backward()
+
+ self.post_train_step()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+
+ if self.scheduler_update_every_step:
+ self.lr_scheduler.step()
+
+ self.loss_meter.update(losses_dict)
+ if self.local_rank == 0:
+ # if self.report_metric_at_train:
+ # for metric in self.metrics:
+ # metric.update(preds, truths)
+
+ if self.use_tensorboard:
+
+ for key, val in losses_dict.items():
+ self.writer.add_scalar(
+ f"train/{key}", val, self.global_step)
+
+ self.writer.add_scalar(
+ "train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
+
+ if self.global_step % self.opt.log_every == 0:
+ strings = f"==> Train [Step] {self.global_step}/{self.opt.iters}"
+ for key, value in losses_dict.items():
+ strings += f", {key}={value:.4f}"
+ logger.info(strings)
+ strings = f"==> Train [Avg] {self.global_step}/{self.opt.iters}"
+ for key in self.loss_meter.meters.keys():
+ strings += f", {key}={self.loss_meter.meters[key].avg:.4f}"
+ logger.info(strings)
+
+ if self.ema is not None:
+ self.ema.update()
+
+ average_loss = self.loss_meter.meters['loss'].avg
+ self.stats["loss"].append(average_loss)
+
+ if self.local_rank == 0:
+ # pbar.close()
+ if self.report_metric_at_train:
+ for metric in self.metrics:
+ logger.info(metric.report(), style="red")
+ if self.use_tensorboard:
+ metric.write(self.writer, self.epoch, prefix="train")
+ metric.clear()
+
+ if not self.scheduler_update_every_step:
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step(average_loss)
+ else:
+ self.lr_scheduler.step()
+
+ # Visualize Training
+ if self.local_rank == 0:
+ # save image
+ save_path = os.path.join(
+ self.workspace, 'training')
+ os.makedirs(save_path, exist_ok=True)
+ name = f'train_{self.name}_ep{self.epoch:04d}'
+ for key, value in outputs.items():
+ save_tensor2image(value, os.path.join(save_path, f'{name}_{key}.jpg'), channel_last=False)
+ gpu_mem = get_GPU_mem()[0]
+ logger.info(f"==> [Finished Epoch {self.epoch}/{max_epochs}. GPU={gpu_mem:.1f}GB.")
+
+ def evaluate_one_epoch(self, loader, name=None):
+ logger.info(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...")
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ total_loss = 0
+ if self.local_rank == 0:
+ for metric in self.metrics:
+ metric.clear()
+
+ self.model.eval()
+
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ if self.local_rank == 0:
+ pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
+
+ with torch.no_grad():
+ self.local_step = 0
+
+ all_outputs = {}
+ for data in loader:
+ self.local_step += 1
+
+ with torch.cuda.amp.autocast(enabled=self.fp16):
+ outputs, loss = self.eval_step(data)
+
+ # all_gather/reduce the statistics (NCCL only support all_*)
+ if self.world_size > 1:
+ dist.all_reduce(loss, op=dist.ReduceOp.SUM)
+ loss = loss / self.world_size
+
+ for key, value in outputs.items():
+ if value is not None:
+ dist.all_gather(outputs[key])
+ outputs[key] = torch.cat(outputs[key], dim=0)
+
+ loss_val = loss.item()
+ total_loss += loss_val
+
+ # only rank = 0 will perform evaluation.
+ if self.local_rank == 0:
+
+ # save image
+ save_path = os.path.join(
+ self.workspace, 'validation')
+
+ # logger.info(f"==> Saving validation image to {save_path}")
+ os.makedirs(save_path, exist_ok=True)
+
+ for key, value in outputs.items():
+ if value is not None:
+ value = ((value - value.min()) / (value.max() - value.min() + 1e-6)).detach().mul(255).to(torch.uint8)
+ # save_tensor2image(value, os.path.join(save_path, f'{name}_{self.local_step:04d}_{key}.jpg'))
+ if key not in all_outputs.keys():
+ all_outputs[key] = []
+ all_outputs[key].append(value)
+
+ pbar.set_description(
+ f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
+ pbar.update(loader.batch_size)
+
+
+ average_loss = total_loss / self.local_step
+ self.stats["valid_loss"].append(average_loss)
+
+ if self.local_rank == 0:
+ pbar.close()
+ if not self.use_loss_as_metric and len(self.metrics) > 0:
+ result = self.metrics[0].measure()
+ self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
+ else:
+ self.stats["results"].append(average_loss) # if no metric, choose best by min loss
+
+ for metric in self.metrics:
+ logger.info(metric.report(), style="blue")
+ if self.use_tensorboard:
+ metric.write(self.writer, self.epoch, prefix="evaluate")
+ metric.clear()
+
+ for key, value in all_outputs.items():
+ all_outputs[key] = torch.cat(value, dim=0)
+ save_tensor2image(all_outputs[key], os.path.join(save_path, f'{name}_{key}.jpg'), channel_last=True)
+ if self.ema is not None:
+ self.ema.restore()
+
+ logger.info(f"++> Evaluate epoch {self.epoch} Finished.")
+
+ def save_checkpoint(self, name=None, full=False, best=False):
+
+ if name is None:
+ name = f'{self.name}_ep{self.epoch:04d}'
+
+ state = {
+ 'epoch': self.epoch,
+ 'global_step': self.global_step,
+ 'stats': self.stats,
+ }
+
+ if self.model.cuda_ray:
+ state['mean_density'] = self.model.mean_density
+
+ if self.opt.dmtet:
+ state['tet_scale'] = self.model.dmtet.tet_scale.cpu().numpy()
+
+ if full:
+ state['optimizer'] = self.optimizer.state_dict()
+ state['lr_scheduler'] = self.lr_scheduler.state_dict()
+ state['scaler'] = self.scaler.state_dict()
+ if self.ema is not None:
+ state['ema'] = self.ema.state_dict()
+
+ if not best:
+
+ state['model'] = self.model.state_dict()
+
+ file_path = f"{name}.pth"
+
+ self.stats["checkpoints"].append(file_path)
+
+ if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
+ old_ckpt = os.path.join(
+ self.opt.ckpt_path, self.stats["checkpoints"].pop(0))
+ if os.path.exists(old_ckpt):
+ os.remove(old_ckpt)
+
+ torch.save(state, os.path.join(self.opt.ckpt_path, file_path))
+
+ else:
+ if len(self.stats["results"]) > 0:
+ # always save best since loss cannot reflect performance.
+ if True:
+ # logger.info(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
+ # self.stats["best_result"] = self.stats["results"][-1]
+
+ # save ema results
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ state['model'] = self.model.state_dict()
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ torch.save(state, self.opt.best_path)
+ else:
+ logger.info(
+ f"[WARN] no evaluated results found, skip saving best checkpoint.")
+
+ def load_checkpoint(self, checkpoint=None, model_only=False):
+ if checkpoint is None:
+ checkpoint_list = sorted(glob.glob(f'{self.opt.ckpt_path}/*.pth'))
+ if checkpoint_list:
+ checkpoint = checkpoint_list[-1]
+ logger.info(f"[INFO] Latest checkpoint is {checkpoint}")
+ else:
+ logger.info(
+ "[WARN] No checkpoint found, model randomly initialized.")
+ return
+
+ checkpoint_dict = torch.load(checkpoint, map_location=self.device)
+
+ if 'model' not in checkpoint_dict:
+ self.model.load_state_dict(checkpoint_dict)
+ logger.info("[INFO] loaded model.")
+ return
+
+ missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
+ logger.info("[INFO] loaded model.")
+ if len(missing_keys) > 0:
+ logger.info(f"[WARN] missing keys: {missing_keys}")
+ if len(unexpected_keys) > 0:
+ logger.info(f"[WARN] unexpected keys: {unexpected_keys}")
+
+ if self.ema is not None and 'ema' in checkpoint_dict:
+ try:
+ self.ema.load_state_dict(checkpoint_dict['ema'])
+ logger.info("[INFO] loaded EMA.")
+ except:
+ logger.info("[WARN] failed to loaded EMA.")
+
+ if self.model.cuda_ray:
+ if 'mean_density' in checkpoint_dict:
+ self.model.mean_density = checkpoint_dict['mean_density']
+
+ if self.opt.dmtet:
+ if 'tet_scale' in checkpoint_dict:
+ new_scale = torch.from_numpy(
+ checkpoint_dict['tet_scale']).to(self.device)
+ self.model.dmtet.verts *= new_scale / self.model.dmtet.tet_scale
+ self.model.dmtet.tet_scale = new_scale
+ # self.model.init_tet()
+ if model_only:
+ return
+
+ self.stats = checkpoint_dict['stats']
+ self.epoch = checkpoint_dict['epoch']
+ self.global_step = checkpoint_dict['global_step']
+ logger.info(
+ f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
+
+ if self.optimizer and 'optimizer' in checkpoint_dict:
+ try:
+ self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
+ logger.info("[INFO] loaded optimizer.")
+ except:
+ logger.info("[WARN] Failed to load optimizer.")
+
+ if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
+ try:
+ self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
+ logger.info("[INFO] loaded scheduler.")
+ except:
+ logger.info("[WARN] Failed to load scheduler.")
+
+ if self.scaler and 'scaler' in checkpoint_dict:
+ try:
+ self.scaler.load_state_dict(checkpoint_dict['scaler'])
+ logger.info("[INFO] loaded scaler.")
+ except:
+ logger.info("[WARN] Failed to load scaler.")
+
+
+def get_CPU_mem():
+ return psutil.Process(os.getpid()).memory_info().rss /1024**3
+
+
+def get_GPU_mem():
+ num = torch.cuda.device_count()
+ mem, mems = 0, []
+ for i in range(num):
+ mem_free, mem_total = torch.cuda.mem_get_info(i)
+ mems.append(int(((mem_total - mem_free)/1024**3)*1000)/1000)
+ mem += mems[-1]
+ return mem, mems
diff --git a/optimizer.py b/optimizer.py
new file mode 100644
index 0000000..f5bb64f
--- /dev/null
+++ b/optimizer.py
@@ -0,0 +1,325 @@
+# Copyright 2022 Garena Online Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List
+
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+
+class Adan(Optimizer):
+ """
+ Implements a pytorch variant of Adan
+ Adan was proposed in
+ Adan: Adaptive Nesterov Momentum Algorithm for
+ Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022.
+ https://arxiv.org/abs/2208.06677
+ Arguments:
+ params (iterable): iterable of parameters to optimize or
+ dicts defining parameter groups.
+ lr (float, optional): learning rate. (default: 1e-3)
+ betas (Tuple[float, float, flot], optional): coefficients used for
+ first- and second-order moments. (default: (0.98, 0.92, 0.99))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability. (default: 1e-8)
+ weight_decay (float, optional): decoupled weight decay
+ (L2 penalty) (default: 0)
+ max_grad_norm (float, optional): value used to clip
+ global grad norm (default: 0.0 no clip)
+ no_prox (bool): how to perform the decoupled weight decay
+ (default: False)
+ foreach (bool): if True would use torch._foreach implementation.
+ It's faster but uses slightly more memory. (default: True)
+ """
+ def __init__(self,
+ params,
+ lr=1e-3,
+ betas=(0.98, 0.92, 0.99),
+ eps=1e-8,
+ weight_decay=0.0,
+ max_grad_norm=0.0,
+ no_prox=False,
+ foreach: bool = True):
+ if not 0.0 <= max_grad_norm:
+ raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
+ if not 0.0 <= lr:
+ raise ValueError('Invalid learning rate: {}'.format(lr))
+ if not 0.0 <= eps:
+ raise ValueError('Invalid epsilon value: {}'.format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError('Invalid beta parameter at index 0: {}'.format(
+ betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError('Invalid beta parameter at index 1: {}'.format(
+ betas[1]))
+ if not 0.0 <= betas[2] < 1.0:
+ raise ValueError('Invalid beta parameter at index 2: {}'.format(
+ betas[2]))
+ defaults = dict(lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ max_grad_norm=max_grad_norm,
+ no_prox=no_prox,
+ foreach=foreach)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Adan, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('no_prox', False)
+
+ @torch.no_grad()
+ def restart_opt(self):
+ for group in self.param_groups:
+ group['step'] = 0
+ for p in group['params']:
+ if p.requires_grad:
+ state = self.state[p]
+ # State initialization
+
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ # Exponential moving average of gradient difference
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step."""
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ if self.defaults['max_grad_norm'] > 0:
+ device = self.param_groups[0]['params'][0].device
+ global_grad_norm = torch.zeros(1, device=device)
+
+ max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],
+ device=device)
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is not None:
+ grad = p.grad
+ global_grad_norm.add_(grad.pow(2).sum())
+
+ global_grad_norm = torch.sqrt(global_grad_norm)
+
+ clip_global_grad_norm = torch.clamp(
+ max_grad_norm / (global_grad_norm + group['eps']),
+ max=1.0).item()
+ else:
+ clip_global_grad_norm = 1.0
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ exp_avg_diffs = []
+ neg_pre_grads = []
+
+ beta1, beta2, beta3 = group['betas']
+ # assume same step across group now to simplify things
+ # per parameter step can be easily support
+ # by making it tensor, or pass list into kernel
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ bias_correction1 = 1.0 - beta1**group['step']
+ bias_correction2 = 1.0 - beta2**group['step']
+ bias_correction3 = 1.0 - beta3**group['step']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ grads.append(p.grad)
+
+ state = self.state[p]
+ if len(state) == 0:
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ state['exp_avg_diff'] = torch.zeros_like(p)
+
+ if 'neg_pre_grad' not in state or group['step'] == 1:
+ state['neg_pre_grad'] = p.grad.clone().mul_(
+ -clip_global_grad_norm)
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ exp_avg_diffs.append(state['exp_avg_diff'])
+ neg_pre_grads.append(state['neg_pre_grad'])
+
+ kwargs = dict(
+ params=params_with_grad,
+ grads=grads,
+ exp_avgs=exp_avgs,
+ exp_avg_sqs=exp_avg_sqs,
+ exp_avg_diffs=exp_avg_diffs,
+ neg_pre_grads=neg_pre_grads,
+ beta1=beta1,
+ beta2=beta2,
+ beta3=beta3,
+ bias_correction1=bias_correction1,
+ bias_correction2=bias_correction2,
+ bias_correction3_sqrt=math.sqrt(bias_correction3),
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ no_prox=group['no_prox'],
+ clip_global_grad_norm=clip_global_grad_norm,
+ )
+
+ if group['foreach']:
+ _multi_tensor_adan(**kwargs)
+ else:
+ _single_tensor_adan(**kwargs)
+
+ return loss
+
+
+def _single_tensor_adan(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ exp_avg_diffs: List[Tensor],
+ neg_pre_grads: List[Tensor],
+ *,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ bias_correction1: float,
+ bias_correction2: float,
+ bias_correction3_sqrt: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+ no_prox: bool,
+ clip_global_grad_norm: Tensor,
+):
+ for i, param in enumerate(params):
+ grad = grads[i]
+ exp_avg = exp_avgs[i]
+ exp_avg_sq = exp_avg_sqs[i]
+ exp_avg_diff = exp_avg_diffs[i]
+ neg_grad_or_diff = neg_pre_grads[i]
+
+ grad.mul_(clip_global_grad_norm)
+
+ # for memory saving, we use `neg_grad_or_diff`
+ # to get some temp variable in a inplace way
+ neg_grad_or_diff.add_(grad)
+
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t
+ exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff,
+ alpha=1 - beta2) # diff_t
+
+ neg_grad_or_diff.mul_(beta2).add_(grad)
+ exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff,
+ neg_grad_or_diff,
+ value=1 - beta3) # n_t
+
+ denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps)
+ step_size_diff = lr * beta2 / bias_correction2
+ step_size = lr / bias_correction1
+
+ if no_prox:
+ param.mul_(1 - lr * weight_decay)
+ param.addcdiv_(exp_avg, denom, value=-step_size)
+ param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+ else:
+ param.addcdiv_(exp_avg, denom, value=-step_size)
+ param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+ param.div_(1 + lr * weight_decay)
+
+ neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
+
+
+def _multi_tensor_adan(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ exp_avg_diffs: List[Tensor],
+ neg_pre_grads: List[Tensor],
+ *,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ bias_correction1: float,
+ bias_correction2: float,
+ bias_correction3_sqrt: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+ no_prox: bool,
+ clip_global_grad_norm: Tensor,
+):
+ if len(params) == 0:
+ return
+
+ torch._foreach_mul_(grads, clip_global_grad_norm)
+
+ # for memory saving, we use `neg_pre_grads`
+ # to get some temp variable in a inplace way
+ torch._foreach_add_(neg_pre_grads, grads)
+
+ torch._foreach_mul_(exp_avgs, beta1)
+ torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t
+
+ torch._foreach_mul_(exp_avg_diffs, beta2)
+ torch._foreach_add_(exp_avg_diffs, neg_pre_grads,
+ alpha=1 - beta2) # diff_t
+
+ torch._foreach_mul_(neg_pre_grads, beta2)
+ torch._foreach_add_(neg_pre_grads, grads)
+ torch._foreach_mul_(exp_avg_sqs, beta3)
+ torch._foreach_addcmul_(exp_avg_sqs,
+ neg_pre_grads,
+ neg_pre_grads,
+ value=1 - beta3) # n_t
+
+ denom = torch._foreach_sqrt(exp_avg_sqs)
+ torch._foreach_div_(denom, bias_correction3_sqrt)
+ torch._foreach_add_(denom, eps)
+
+ step_size_diff = lr * beta2 / bias_correction2
+ step_size = lr / bias_correction1
+
+ if no_prox:
+ torch._foreach_mul_(params, 1 - lr * weight_decay)
+ torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+ torch._foreach_addcdiv_(params,
+ exp_avg_diffs,
+ denom,
+ value=-step_size_diff)
+ else:
+ torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+ torch._foreach_addcdiv_(params,
+ exp_avg_diffs,
+ denom,
+ value=-step_size_diff)
+ torch._foreach_div_(params, 1 + lr * weight_decay)
+ torch._foreach_zero_(neg_pre_grads)
+ torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
\ No newline at end of file
diff --git a/preprocess_image.py b/preprocess_image.py
new file mode 100644
index 0000000..13b1813
--- /dev/null
+++ b/preprocess_image.py
@@ -0,0 +1,251 @@
+import os
+import sys
+import cv2
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms
+from PIL import Image
+
+from easydict import EasyDict as edict
+
+class BackgroundRemoval():
+ def __init__(self, device='cuda'):
+
+ from carvekit.api.high import HiInterface
+ self.interface = HiInterface(
+ object_type="object", # Can be "object" or "hairs-like".
+ batch_size_seg=5,
+ batch_size_matting=1,
+ device=device,
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
+ matting_mask_size=2048,
+ trimap_prob_threshold=231,
+ trimap_dilation=30,
+ trimap_erosion_iters=5,
+ fp16=True,
+ )
+
+ @torch.no_grad()
+ def __call__(self, image):
+ # image: [H, W, 3] array in [0, 255].
+ image = Image.fromarray(image)
+ image = self.interface([image])[0]
+ image = np.array(image)
+ return image
+
+
+def get_rgba(image, alpha_matting=False):
+ try:
+ from rembg import remove
+ except ImportError:
+ print('Please install rembg with "pip install rembg"')
+ sys.exit()
+ return remove(image, alpha_matting=alpha_matting)
+
+
+class BLIP2():
+ def __init__(self, device='cuda'):
+ self.device = device
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
+ self.processor = AutoProcessor.from_pretrained(
+ "Salesforce/blip2-opt-2.7b")
+ self.model = Blip2ForConditionalGeneration.from_pretrained(
+ "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
+
+ @torch.no_grad()
+ def __call__(self, image):
+ image = Image.fromarray(image)
+ inputs = self.processor(image, return_tensors="pt").to(
+ self.device, torch.float16)
+
+ generated_ids = self.model.generate(**inputs, max_new_tokens=20)
+ generated_text = self.processor.batch_decode(
+ generated_ids, skip_special_tokens=True)[0].strip()
+
+ return generated_text
+
+
+class DPT():
+ def __init__(self, task='depth', device='cuda'):
+
+ self.task = task
+ self.device = device
+
+ from dpt import DPTDepthModel
+
+ if task == 'depth':
+ path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'
+ self.model = DPTDepthModel(backbone='vitb_rn50_384')
+ self.aug = transforms.Compose([
+ transforms.Resize((384, 384)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=0.5, std=0.5)
+ ])
+
+ else: # normal
+ path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'
+ self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
+ self.aug = transforms.Compose([
+ transforms.Resize((384, 384)),
+ transforms.ToTensor()
+ ])
+
+ # load model
+ checkpoint = torch.load(path, map_location='cpu')
+ if 'state_dict' in checkpoint:
+ state_dict = {}
+ for k, v in checkpoint['state_dict'].items():
+ state_dict[k[6:]] = v
+ else:
+ state_dict = checkpoint
+ self.model.load_state_dict(state_dict)
+ self.model.eval().to(device)
+
+
+ @torch.no_grad()
+ def __call__(self, image):
+ # image: np.ndarray, uint8, [H, W, 3]
+ H, W = image.shape[:2]
+ image = Image.fromarray(image)
+
+ image = self.aug(image).unsqueeze(0).to(self.device)
+
+ if self.task == 'depth':
+ depth = self.model(image).clamp(0, 1)
+ depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
+ depth = depth.squeeze(1).cpu().numpy()
+ return depth
+ else:
+ normal = self.model(image).clamp(0, 1)
+ normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
+ normal = normal.cpu().numpy()
+ return normal
+
+
+# from munch import DefaultMunch
+from midas.model_loader import default_models, load_model
+
+depth_config={
+ "input_path": None,
+ "output_path": None,
+ "model_weights": "pretrained/midas/dpt_beit_large_512.pt",
+ "model_type": "dpt_beit_large_512",
+ "side": False,
+ "optimize": False,
+ "height": None,
+ "square": False,
+ "device":0,
+ "grayscale": False
+}
+
+
+class DepthEstimator:
+ def __init__(self,**kwargs):
+ # update coming args
+ for key, value in kwargs.items():
+ depth_config[key]=value
+
+ # self.config=DefaultMunch.fromDict(depth_config)
+ self.config = edict(depth_config)
+
+ # select device
+ self.device = torch.device(self.config.device)
+ model, transform, net_w, net_h = load_model(f"cuda:{self.config.device}", self.config.model_weights, self.config.model_type,
+ self.config.optimize, self.config.height, self.config.square)
+ self.model, self.transform, self.net_w, self.net_h=model, transform, net_w, net_h
+ self.first_execution = True
+
+ @torch.no_grad()
+ def process(self,image,target_size):
+ sample = torch.from_numpy(image).to(self.device).unsqueeze(0)
+
+
+ if self.first_execution:
+ height, width = sample.shape[2:]
+ print(f" Input resized to {width}x{height} before entering the encoder")
+ self.first_execution = False
+
+ prediction = self.model.forward(sample)
+ prediction = (
+ torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=target_size[::-1],
+ mode="bicubic",
+ align_corners=False,
+ )
+ .squeeze()
+ .cpu()
+ .numpy()
+ )
+ return prediction
+
+ @torch.no_grad()
+ def get_monocular_depth(self,rgb, output_path=None):
+ original_image_rgb=rgb
+ image = self.transform({"image": original_image_rgb})["image"]
+
+ prediction = self.process(image, original_image_rgb.shape[1::-1])
+ return prediction
+
+
+
+def process_single_image(image_path, depth_estimator, normal_estimator=None):
+ out_dir = os.path.dirname(image_path)
+ rgba_path = os.path.join(out_dir, 'rgba.png')
+ depth_path = os.path.join(out_dir, 'depth.png')
+ # out_normal = os.path.join(out_dir, 'normal.png')
+
+ if os.path.exists(rgba_path):
+ print(f'[INFO] loading rgba image {rgba_path}...')
+ rgba = cv2.cvtColor(cv2.imread(rgba_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
+ image = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
+ else:
+ print(f'[INFO] loading image {image_path}...')
+ image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+ if image.shape[-1] == 4:
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
+ else:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ print(f'[INFO] background removal...')
+ rgba = BackgroundRemoval()(image) # [H, W, 4]
+ cv2.imwrite(rgba_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
+ # rgba = get_rgba(image) # [H, W, 4]
+ # cv2.imwrite(rgba_path.replace('rgba', 'rgba2'), cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
+
+ # Predict depth using Midas
+ mask = rgba[..., -1] > 0
+ depth = depth_estimator.get_monocular_depth(image/255)
+ depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)
+ depth[~mask] = 0
+ depth = (depth * 255).astype(np.uint8)
+
+ # print(f'[INFO] normal estimation...')
+ # normal = normal_estimator(image)[0]
+ # normal = (normal.clip(0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
+ # normal[~mask] = 0
+
+ cv2.imwrite(depth_path, depth)
+ # cv2.imwrite(out_normal, cv2.cvtColor(normal, cv2.COLOR_RGB2BGR))
+ if not os.path.exists(rgba_path):
+ cv2.imwrite(rgba_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
+
+if __name__ == '__main__':
+ import glob
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--path', default=None, type=str, nargs='*', help="path to image (png, jpeg, etc.)")
+ parser.add_argument('--folder', default=None, type=str, help="path to image (png, jpeg, etc.)")
+ opt = parser.parse_args()
+
+ depth_estimator = DepthEstimator()
+ # normal_estimator = DPT(task='normal')
+
+ paths = opt.path if opt.path is not None else glob.glob(os.path.join(opt.folder, '*/rgba.png'))
+ for path in paths:
+ process_single_image(path, depth_estimator,
+ # normal_estimator
+ )
\ No newline at end of file
diff --git a/pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml b/pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml
new file mode 100755
index 0000000..2448cd4
--- /dev/null
+++ b/pretrained/zero123/sd-objaverse-finetune-c_concat-256.yaml
@@ -0,0 +1,117 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "image_target"
+ cond_stage_key: "image_cond"
+ image_size: 32
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: hybrid
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 100 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
+
+
+# data:
+# target: ldm.data.simple.ObjaverseDataModuleFromConfig
+# params:
+# root_dir: 'views_whole_sphere'
+# batch_size: 192
+# num_workers: 16
+# total_view: 4
+# train:
+# validation: False
+# image_transforms:
+# size: 256
+
+# validation:
+# validation: True
+# image_transforms:
+# size: 256
+
+
+# lightning:
+# find_unused_parameters: false
+# metrics_over_trainsteps_checkpoint: True
+# modelcheckpoint:
+# params:
+# every_n_train_steps: 5000
+# callbacks:
+# image_logger:
+# target: main.ImageLogger
+# params:
+# batch_frequency: 500
+# max_images: 32
+# increase_log_steps: False
+# log_first_step: True
+# log_images_kwargs:
+# use_ema_scope: False
+# inpaint: False
+# plot_progressive_rows: False
+# plot_diffusion_rows: False
+# N: 32
+# unconditional_scale: 3.0
+# unconditional_label: [""]
+
+# trainer:
+# benchmark: True
+# val_check_interval: 5000000 # really sorry
+# num_sanity_val_steps: 0
+# accumulate_grad_batches: 1
diff --git a/raymarching/__init__.py b/raymarching/__init__.py
new file mode 100644
index 0000000..26d3cc6
--- /dev/null
+++ b/raymarching/__init__.py
@@ -0,0 +1 @@
+from .raymarching import *
\ No newline at end of file
diff --git a/raymarching/backend.py b/raymarching/backend.py
new file mode 100644
index 0000000..7cc0d76
--- /dev/null
+++ b/raymarching/backend.py
@@ -0,0 +1,41 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+_backend = load(name='_raymarching',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'raymarching.cu',
+ 'bindings.cpp',
+ ]],
+ )
+
+__all__ = ['_backend']
\ No newline at end of file
diff --git a/raymarching/raymarching.py b/raymarching/raymarching.py
new file mode 100644
index 0000000..760d730
--- /dev/null
+++ b/raymarching/raymarching.py
@@ -0,0 +1,398 @@
+import numpy as np
+import time
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+# lazy building:
+# `import raymarching` will not immediately build the extension, only if you actually call any functions.
+
+BACKEND = None
+
+def get_backend():
+ global BACKEND
+
+ if BACKEND is None:
+ try:
+ import _raymarching as _backend
+ except ImportError:
+ from .backend import _backend
+
+ BACKEND = _backend
+
+ return BACKEND
+
+# ----------------------------------------
+# utils
+# ----------------------------------------
+
+class _near_far_from_aabb(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
+ ''' near_far_from_aabb, CUDA implementation
+ Calculate rays' intersection time (near and far) with aabb
+ Args:
+ rays_o: float, [N, 3]
+ rays_d: float, [N, 3]
+ aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
+ min_near: float, scalar
+ Returns:
+ nears: float, [N]
+ fars: float, [N]
+ '''
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # num rays
+
+ nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
+ fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)
+
+ get_backend().near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)
+
+ return nears, fars
+
+near_far_from_aabb = _near_far_from_aabb.apply
+
+
+class _sph_from_ray(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, rays_o, rays_d, radius):
+ ''' sph_from_ray, CUDA implementation
+ get spherical coordinate on the background sphere from rays.
+ Assume rays_o are inside the Sphere(radius).
+ Args:
+ rays_o: [N, 3]
+ rays_d: [N, 3]
+ radius: scalar, float
+ Return:
+ coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)
+ '''
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+
+ rays_o = rays_o.contiguous().view(-1, 3)
+ rays_d = rays_d.contiguous().view(-1, 3)
+
+ N = rays_o.shape[0] # num rays
+
+ coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)
+
+ get_backend().sph_from_ray(rays_o, rays_d, radius, N, coords)
+
+ return coords
+
+sph_from_ray = _sph_from_ray.apply
+
+
+class _morton3D(Function):
+ @staticmethod
+ def forward(ctx, coords):
+ ''' morton3D, CUDA implementation
+ Args:
+ coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)
+ TODO: check if the coord range is valid! (current 128 is safe)
+ Returns:
+ indices: [N], int32, in [0, 128^3)
+
+ '''
+ if not coords.is_cuda: coords = coords.cuda()
+
+ N = coords.shape[0]
+
+ indices = torch.empty(N, dtype=torch.int32, device=coords.device)
+
+ get_backend().morton3D(coords.int(), N, indices)
+
+ return indices
+
+morton3D = _morton3D.apply
+
+class _morton3D_invert(Function):
+ @staticmethod
+ def forward(ctx, indices):
+ ''' morton3D_invert, CUDA implementation
+ Args:
+ indices: [N], int32, in [0, 128^3)
+ Returns:
+ coords: [N, 3], int32, in [0, 128)
+
+ '''
+ if not indices.is_cuda: indices = indices.cuda()
+
+ N = indices.shape[0]
+
+ coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)
+
+ get_backend().morton3D_invert(indices.int(), N, coords)
+
+ return coords
+
+morton3D_invert = _morton3D_invert.apply
+
+
+class _packbits(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, grid, thresh, bitfield=None):
+ ''' packbits, CUDA implementation
+ Pack up the density grid into a bit field to accelerate ray marching.
+ Args:
+ grid: float, [C, H * H * H], assume H % 2 == 0
+ thresh: float, threshold
+ Returns:
+ bitfield: uint8, [C, H * H * H / 8]
+ '''
+ if not grid.is_cuda: grid = grid.cuda()
+ grid = grid.contiguous()
+
+ C = grid.shape[0]
+ H3 = grid.shape[1]
+ N = C * H3 // 8
+
+ if bitfield is None:
+ bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)
+
+ get_backend().packbits(grid, N, thresh, bitfield)
+
+ return bitfield
+
+packbits = _packbits.apply
+
+
+class _flatten_rays(Function):
+ @staticmethod
+ def forward(ctx, rays, M):
+ ''' flatten rays
+ Args:
+ rays: [N, 2], all rays' (point_offset, point_count),
+ M: scalar, int, count of points (we cannot get this info from rays unfortunately...)
+ Returns:
+ res: [M], flattened ray index.
+ '''
+ if not rays.is_cuda: rays = rays.cuda()
+ rays = rays.contiguous()
+
+ N = rays.shape[0]
+
+ res = torch.zeros(M, dtype=torch.int, device=rays.device)
+
+ get_backend().flatten_rays(rays, N, M, res)
+
+ return res
+
+flatten_rays = _flatten_rays.apply
+
+# ----------------------------------------
+# train functions
+# ----------------------------------------
+
+class _march_rays_train(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
+ ''' march rays to generate points (forward only)
+ Args:
+ rays_o/d: float, [N, 3]
+ bound: float, scalar
+ density_bitfield: uint8: [CHHH // 8]
+ C: int
+ H: int
+ nears/fars: float, [N]
+ step_counter: int32, (2), used to count the actual number of generated points.
+ mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)
+ perturb: bool
+ align: int, pad output so its size is dividable by align, set to -1 to disable.
+ force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
+ Returns:
+ xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
+ dirs: float, [M, 3], all generated points' view dirs.
+ ts: float, [M, 2], all generated points' ts.
+ rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to rays[i, 0]
+ '''
+
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+ if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
+
+ rays_o = rays_o.float().contiguous().view(-1, 3)
+ rays_d = rays_d.float().contiguous().view(-1, 3)
+ density_bitfield = density_bitfield.contiguous()
+
+ N = rays_o.shape[0] # num rays
+
+ step_counter = torch.zeros(1, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
+
+ if perturb:
+ noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
+ else:
+ noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
+
+ # first pass: write rays, get total number of points M to render
+ rays = torch.empty(N, 2, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
+ get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, None, None, None, rays, step_counter, noises)
+
+ # allocate based on M
+ M = step_counter.item()
+ # print(M, N)
+ # print(rays[:, 0].max())
+
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
+
+ # second pass: write outputs
+ get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, xyzs, dirs, ts, rays, step_counter, noises)
+
+ return xyzs, dirs, ts, rays
+
+march_rays_train = _march_rays_train.apply
+
+
+class _composite_rays_train(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False):
+ ''' composite rays' rgbs, according to the ray marching formula.
+ Args:
+ rgbs: float, [M, 3]
+ sigmas: float, [M,]
+ ts: float, [M, 2]
+ rays: int32, [N, 3]
+ Returns:
+ weights: float, [M]
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N, ], the Depth
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+
+ sigmas = sigmas.float().contiguous()
+ rgbs = rgbs.float().contiguous()
+
+ M = sigmas.shape[0]
+ N = rays.shape[0]
+
+ weights = torch.zeros(M, dtype=sigmas.dtype, device=sigmas.device) # may leave unmodified, so init with 0
+ weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+
+ depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+ image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
+
+ get_backend().composite_rays_train_forward(sigmas, rgbs, ts, rays, M, N, T_thresh, binarize, weights, weights_sum, depth, image)
+
+ ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image)
+ ctx.dims = [M, N, T_thresh, binarize]
+
+ return weights, weights_sum, depth, image
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image):
+
+ grad_weights = grad_weights.contiguous()
+ grad_weights_sum = grad_weights_sum.contiguous()
+ grad_depth = grad_depth.contiguous()
+ grad_image = grad_image.contiguous()
+
+ sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors
+ M, N, T_thresh, binarize = ctx.dims
+
+ grad_sigmas = torch.zeros_like(sigmas)
+ grad_rgbs = torch.zeros_like(rgbs)
+
+ get_backend().composite_rays_train_backward(grad_weights, grad_weights_sum, grad_depth, grad_image, sigmas, rgbs, ts, rays, weights_sum, depth, image, M, N, T_thresh, binarize, grad_sigmas, grad_rgbs)
+
+ return grad_sigmas, grad_rgbs, None, None, None, None
+
+
+composite_rays_train = _composite_rays_train.apply
+
+# ----------------------------------------
+# infer functions
+# ----------------------------------------
+
+class _march_rays(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
+ ''' march rays to generate points (forward only, for inference)
+ Args:
+ n_alive: int, number of alive rays
+ n_step: int, how many steps we march
+ rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
+ rays_t: float, [N], the alive rays' time, we only use the first n_alive.
+ rays_o/d: float, [N, 3]
+ bound: float, scalar
+ density_bitfield: uint8: [CHHH // 8]
+ C: int
+ H: int
+ nears/fars: float, [N]
+ align: int, pad output so its size is dividable by align, set to -1 to disable.
+ perturb: bool/int, int > 0 is used as the random seed.
+ dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
+ max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
+ Returns:
+ xyzs: float, [n_alive * n_step, 3], all generated points' coords
+ dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
+ ts: float, [n_alive * n_step, 2], all generated points' ts
+ '''
+
+ if not rays_o.is_cuda: rays_o = rays_o.cuda()
+ if not rays_d.is_cuda: rays_d = rays_d.cuda()
+
+ rays_o = rays_o.float().contiguous().view(-1, 3)
+ rays_d = rays_d.float().contiguous().view(-1, 3)
+
+ M = n_alive * n_step
+
+ xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+ ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth
+
+ if perturb:
+ # torch.manual_seed(perturb) # test_gui uses spp index as seed
+ noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
+ else:
+ noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)
+
+ get_backend().march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, ts, noises)
+
+ return xyzs, dirs, ts
+
+march_rays = _march_rays.apply
+
+
+class _composite_rays(Function):
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
+ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh=1e-2, binarize=False):
+ ''' composite rays' rgbs, according to the ray marching formula. (for inference)
+ Args:
+ n_alive: int, number of alive rays
+ n_step: int, how many steps we march
+ rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
+ rays_t: float, [N], the alive rays' time
+ sigmas: float, [n_alive * n_step,]
+ rgbs: float, [n_alive * n_step, 3]
+ ts: float, [n_alive * n_step, 2]
+ In-place Outputs:
+ weights_sum: float, [N,], the alpha channel
+ depth: float, [N,], the depth value
+ image: float, [N, 3], the RGB channel (after multiplying alpha!)
+ '''
+ sigmas = sigmas.float().contiguous()
+ rgbs = rgbs.float().contiguous()
+ get_backend().composite_rays(n_alive, n_step, T_thresh, binarize, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image)
+ return tuple()
+
+
+composite_rays = _composite_rays.apply
\ No newline at end of file
diff --git a/raymarching/setup.py b/raymarching/setup.py
new file mode 100644
index 0000000..4d32fa7
--- /dev/null
+++ b/raymarching/setup.py
@@ -0,0 +1,63 @@
+import os
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+ '-O3', '-std=c++14',
+ '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+ c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+ c_flags = ['/O2', '/std:c++17']
+
+ # find cl.exe
+ def find_cl_path():
+ import glob
+ for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
+ for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+ paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
+ if paths:
+ return paths[0]
+
+ # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0:
+ cl_path = find_cl_path()
+ if cl_path is None:
+ raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
+ os.environ["PATH"] += ";" + cl_path
+
+'''
+Usage:
+
+python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)
+
+python setup.py install # build extensions and install (copy) to PATH.
+pip install . # ditto but better (e.g., dependency & metadata handling)
+
+python setup.py develop # build extensions and install (symbolic) to PATH.
+pip install -e . # ditto but better (e.g., dependency & metadata handling)
+
+'''
+setup(
+ name='raymarching', # package name, import this to use python API
+ ext_modules=[
+ CUDAExtension(
+ name='_raymarching', # extension name, import this to use CUDA API
+ sources=[os.path.join(_src_path, 'src', f) for f in [
+ 'raymarching.cu',
+ 'bindings.cpp',
+ ]],
+ extra_compile_args={
+ 'cxx': c_flags,
+ 'nvcc': nvcc_flags,
+ }
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension,
+ }
+)
\ No newline at end of file
diff --git a/raymarching/src/bindings.cpp b/raymarching/src/bindings.cpp
new file mode 100644
index 0000000..eb8f122
--- /dev/null
+++ b/raymarching/src/bindings.cpp
@@ -0,0 +1,20 @@
+#include
+
+#include "raymarching.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ // utils
+ m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)");
+ m.def("packbits", &packbits, "packbits (CUDA)");
+ m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
+ m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
+ m.def("morton3D", &morton3D, "morton3D (CUDA)");
+ m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
+ // train
+ m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
+ m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
+ m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
+ // infer
+ m.def("march_rays", &march_rays, "march rays (CUDA)");
+ m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
+}
\ No newline at end of file
diff --git a/raymarching/src/raymarching.cu b/raymarching/src/raymarching.cu
new file mode 100644
index 0000000..0292f1c
--- /dev/null
+++ b/raymarching/src/raymarching.cu
@@ -0,0 +1,934 @@
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
+inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
+inline constexpr __device__ float PI() { return 3.141592653589793f; }
+inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
+
+
+template
+inline __host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+inline __host__ __device__ float signf(const float x) {
+ return copysignf(1.0, x);
+}
+
+inline __host__ __device__ float clamp(const float x, const float min, const float max) {
+ return fminf(max, fmaxf(min, x));
+}
+
+inline __host__ __device__ void swapf(float& a, float& b) {
+ float c = a; a = b; b = c;
+}
+
+inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
+ const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
+ int exponent;
+ frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
+}
+
+inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
+ const float mx = dt * H * 0.5;
+ int exponent;
+ frexpf(mx, &exponent);
+ return fminf(max_cascade - 1, fmaxf(0, exponent));
+}
+
+inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
+{
+ v = (v * 0x00010001u) & 0xFF0000FFu;
+ v = (v * 0x00000101u) & 0x0F00F00Fu;
+ v = (v * 0x00000011u) & 0xC30C30C3u;
+ v = (v * 0x00000005u) & 0x49249249u;
+ return v;
+}
+
+inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
+{
+ uint32_t xx = __expand_bits(x);
+ uint32_t yy = __expand_bits(y);
+ uint32_t zz = __expand_bits(z);
+ return xx | (yy << 1) | (zz << 2);
+}
+
+inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
+{
+ x = x & 0x49249249;
+ x = (x | (x >> 2)) & 0xc30c30c3;
+ x = (x | (x >> 4)) & 0x0f00f00f;
+ x = (x | (x >> 8)) & 0xff0000ff;
+ x = (x | (x >> 16)) & 0x0000ffff;
+ return x;
+}
+
+
+////////////////////////////////////////////////////
+///////////// utils /////////////
+////////////////////////////////////////////////////
+
+// rays_o/d: [N, 3]
+// nears/fars: [N]
+// scalar_t should always be float in use.
+template
+__global__ void kernel_near_far_from_aabb(
+ const scalar_t * __restrict__ rays_o,
+ const scalar_t * __restrict__ rays_d,
+ const scalar_t * __restrict__ aabb,
+ const uint32_t N,
+ const float min_near,
+ scalar_t * nears, scalar_t * fars
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ rays_o += n * 3;
+ rays_d += n * 3;
+
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+
+ // get near far (assume cube scene)
+ float near = (aabb[0] - ox) * rdx;
+ float far = (aabb[3] - ox) * rdx;
+ if (near > far) swapf(near, far);
+
+ float near_y = (aabb[1] - oy) * rdy;
+ float far_y = (aabb[4] - oy) * rdy;
+ if (near_y > far_y) swapf(near_y, far_y);
+
+ if (near > far_y || near_y > far) {
+ nears[n] = fars[n] = std::numeric_limits::max();
+ return;
+ }
+
+ if (near_y > near) near = near_y;
+ if (far_y < far) far = far_y;
+
+ float near_z = (aabb[2] - oz) * rdz;
+ float far_z = (aabb[5] - oz) * rdz;
+ if (near_z > far_z) swapf(near_z, far_z);
+
+ if (near > far_z || near_z > far) {
+ nears[n] = fars[n] = std::numeric_limits::max();
+ return;
+ }
+
+ if (near_z > near) near = near_z;
+ if (far_z < far) far = far_z;
+
+ if (near < min_near) near = min_near;
+
+ nears[n] = near;
+ fars[n] = far;
+}
+
+
+void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "near_far_from_aabb", ([&] {
+ kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr());
+ }));
+}
+
+
+// rays_o/d: [N, 3]
+// radius: float
+// coords: [N, 2]
+template
+__global__ void kernel_sph_from_ray(
+ const scalar_t * __restrict__ rays_o,
+ const scalar_t * __restrict__ rays_d,
+ const float radius,
+ const uint32_t N,
+ scalar_t * coords
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ rays_o += n * 3;
+ rays_d += n * 3;
+ coords += n * 2;
+
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ // const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+
+ // solve t from || o + td || = radius
+ const float A = dx * dx + dy * dy + dz * dz;
+ const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
+ const float C = ox * ox + oy * oy + oz * oz - radius * radius;
+
+ const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
+
+ // solve theta, phi (assume y is the up axis)
+ const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
+ const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
+ const float phi = atan2(z, x); // [-PI, PI)
+
+ // normalize to [-1, 1]
+ coords[0] = 2 * theta * RPI() - 1;
+ coords[1] = phi * RPI();
+}
+
+
+void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "sph_from_ray", ([&] {
+ kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr());
+ }));
+}
+
+
+// coords: int32, [N, 3]
+// indices: int32, [N]
+__global__ void kernel_morton3D(
+ const int * __restrict__ coords,
+ const uint32_t N,
+ int * indices
+) {
+ // parallel
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ coords += n * 3;
+ indices[n] = __morton3D(coords[0], coords[1], coords[2]);
+}
+
+
+void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
+ static constexpr uint32_t N_THREAD = 128;
+ kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr());
+}
+
+
+// indices: int32, [N]
+// coords: int32, [N, 3]
+__global__ void kernel_morton3D_invert(
+ const int * __restrict__ indices,
+ const uint32_t N,
+ int * coords
+) {
+ // parallel
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ coords += n * 3;
+
+ const int ind = indices[n];
+
+ coords[0] = __morton3D_invert(ind >> 0);
+ coords[1] = __morton3D_invert(ind >> 1);
+ coords[2] = __morton3D_invert(ind >> 2);
+}
+
+
+void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
+ static constexpr uint32_t N_THREAD = 128;
+ kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr());
+}
+
+
+// grid: float, [C, H, H, H]
+// N: int, C * H * H * H / 8
+// density_thresh: float
+// bitfield: uint8, [N]
+template
+__global__ void kernel_packbits(
+ const scalar_t * __restrict__ grid,
+ const uint32_t N,
+ const float density_thresh,
+ uint8_t * bitfield
+) {
+ // parallel per byte
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ grid += n * 8;
+
+ uint8_t bits = 0;
+
+ #pragma unroll
+ for (uint8_t i = 0; i < 8; i++) {
+ bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
+ }
+
+ bitfield[n] = bits;
+}
+
+
+void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grid.scalar_type(), "packbits", ([&] {
+ kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr());
+ }));
+}
+
+
+__global__ void kernel_flatten_rays(
+ const int * __restrict__ rays,
+ const uint32_t N, const uint32_t M,
+ int * res
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t offset = rays[n * 2];
+ uint32_t num_steps = rays[n * 2 + 1];
+
+ // write to res
+ res += offset;
+ for (int i = 0; i < num_steps; i++) res[i] = n;
+}
+
+void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ kernel_flatten_rays<<>>(rays.data_ptr(), N, M, res.data_ptr());
+}
+
+////////////////////////////////////////////////////
+///////////// training /////////////
+////////////////////////////////////////////////////
+
+// rays_o/d: [N, 3]
+// grid: [CHHH / 8]
+// xyzs, dirs, ts: [M, 3], [M, 3], [M, 2]
+// dirs: [M, 3]
+// rays: [N, 3], idx, offset, num_steps
+template
+__global__ void kernel_march_rays_train(
+ const scalar_t * __restrict__ rays_o,
+ const scalar_t * __restrict__ rays_d,
+ const uint8_t * __restrict__ grid,
+ const float bound, const bool contract,
+ const float dt_gamma, const uint32_t max_steps,
+ const uint32_t N, const uint32_t C, const uint32_t H,
+ const scalar_t* __restrict__ nears,
+ const scalar_t* __restrict__ fars,
+ scalar_t * xyzs, scalar_t * dirs, scalar_t * ts,
+ int * rays,
+ int * counter,
+ const scalar_t* __restrict__ noises
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // is first pass running.
+ const bool first_pass = (xyzs == nullptr);
+
+ // locate
+ rays_o += n * 3;
+ rays_d += n * 3;
+ rays += n * 2;
+
+ uint32_t num_steps = max_steps;
+
+ if (!first_pass) {
+ uint32_t point_index = rays[0];
+ num_steps = rays[1];
+ xyzs += point_index * 3;
+ dirs += point_index * 3;
+ ts += point_index * 2;
+ }
+
+ // ray marching
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+ const float rH = 1 / (float)H;
+ const float H3 = H * H * H;
+
+ const float near = nears[n];
+ const float far = fars[n];
+ const float noise = noises[n];
+
+ const float dt_min = 2 * SQRT3() / max_steps;
+ const float dt_max = 2 * SQRT3() * bound / H;
+ // const float dt_max = 1e10f;
+
+ float t0 = near;
+ t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
+ float t = t0;
+ uint32_t step = 0;
+
+ //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
+
+ while (t < far && step < num_steps) {
+ // current point
+ const float x = clamp(ox + t * dx, -bound, bound);
+ const float y = clamp(oy + t * dy, -bound, bound);
+ const float z = clamp(oz + t * dz, -bound, bound);
+
+ float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+ // get mip level
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+ const float mip_bound = fminf(scalbnf(1.0f, level), bound);
+ const float mip_rbound = 1 / mip_bound;
+
+ // contraction
+ float cx = x, cy = y, cz = z;
+ const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
+ if (contract && mag > 1) {
+ // L-INF norm
+ const float Linf_scale = (2 - 1 / mag) / mag;
+ cx *= Linf_scale;
+ cy *= Linf_scale;
+ cz *= Linf_scale;
+ }
+
+ // convert to nearest grid position
+ const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+ const bool occ = grid[index / 8] & (1 << (index % 8));
+
+ // if occpuied, advance a small step, and write to output
+ //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, step);
+
+ if (occ) {
+ step++;
+ t += dt;
+ if (!first_pass) {
+ xyzs[0] = cx; // write contracted coordinates!
+ xyzs[1] = cy;
+ xyzs[2] = cz;
+ dirs[0] = dx;
+ dirs[1] = dy;
+ dirs[2] = dz;
+ ts[0] = t;
+ ts[1] = dt;
+ xyzs += 3;
+ dirs += 3;
+ ts += 2;
+ }
+ // contraction case: cannot apply voxel skipping.
+ } else if (contract && mag > 1) {
+ t += dt;
+ // else, skip a large step (basically skip a voxel grid)
+ } else {
+ // calc distance to next voxel
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;
+
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+ // step until next voxel
+ do {
+ dt = clamp(t * dt_gamma, dt_min, dt_max);
+ t += dt;
+ } while (t < tt);
+ }
+ }
+
+ //printf("[n=%d] step=%d, near=%f, far=%f, dt=%f, num_steps=%f\n", n, step, near, far, dt_min, (far - near) / dt_min);
+
+ // write rays
+ if (first_pass) {
+ uint32_t point_index = atomicAdd(counter, step);
+ rays[0] = point_index;
+ rays[1] = step;
+ }
+}
+
+void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional xyzs, at::optional dirs, at::optional ts, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "march_rays_train", ([&] {
+ kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, contract, dt_gamma, max_steps, N, C, H, nears.data_ptr(), fars.data_ptr(),
+ xyzs.has_value() ? xyzs.value().data_ptr() : nullptr,
+ dirs.has_value() ? dirs.value().data_ptr() : nullptr,
+ ts.has_value() ? ts.value().data_ptr() : nullptr,
+ rays.data_ptr(), counter.data_ptr(), noises.data_ptr());
+ }));
+}
+
+
+// sigmas: [M]
+// rgbs: [M, 3]
+// ts: [M, 2]
+// rays: [N, 2], offset, num_steps
+// weights: [M]
+// weights_sum: [N], final pixel alpha
+// depth: [N,]
+// image: [N, 3]
+template
+__global__ void kernel_composite_rays_train_forward(
+ const scalar_t * __restrict__ sigmas,
+ const scalar_t * __restrict__ rgbs,
+ const scalar_t * __restrict__ ts,
+ const int * __restrict__ rays,
+ const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,
+ scalar_t * weights,
+ scalar_t * weights_sum,
+ scalar_t * depth,
+ scalar_t * image
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t offset = rays[n * 2];
+ uint32_t num_steps = rays[n * 2 + 1];
+
+ // empty ray, or ray that exceed max step count.
+ if (num_steps == 0 || offset + num_steps > M) {
+ weights_sum[n] = 0;
+ depth[n] = 0;
+ image[n * 3] = 0;
+ image[n * 3 + 1] = 0;
+ image[n * 3 + 2] = 0;
+ return;
+ }
+
+ ts += offset * 2;
+ weights += offset;
+ sigmas += offset;
+ rgbs += offset * 3;
+
+ // accumulate
+ uint32_t step = 0;
+
+ float T = 1.0f;
+ float r = 0, g = 0, b = 0, ws = 0, d = 0;
+
+ while (step < num_steps) {
+
+ const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
+ const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
+ const float weight = alpha * T;
+
+ weights[0] = weight;
+
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+ ws += weight;
+ d += weight * ts[0];
+
+ T *= 1.0f - alpha;
+
+ // minimal remained transmittence
+ if (T < T_thresh) break;
+
+ //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+ // locate
+ weights++;
+ sigmas++;
+ rgbs += 3;
+ ts += 2;
+
+ step++;
+ }
+
+ //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+ // write
+ weights_sum[n] = ws; // weights_sum
+ depth[n] = d;
+ image[n * 3] = r;
+ image[n * 3 + 1] = g;
+ image[n * 3 + 2] = b;
+}
+
+
+void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
+ kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), rays.data_ptr(), M, N, T_thresh, binarize, weights.data_ptr(), weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr());
+ }));
+}
+
+
+// grad_weights: [M,]
+// grad_weights_sum: [N,]
+// grad_image: [N, 3]
+// grad_depth: [N,]
+// sigmas: [M]
+// rgbs: [M, 3]
+// ts: [M, 2]
+// rays: [N, 2], offset, num_steps
+// weights_sum: [N,], weights_sum here
+// image: [N, 3]
+// grad_sigmas: [M]
+// grad_rgbs: [M, 3]
+template
+__global__ void kernel_composite_rays_train_backward(
+ const scalar_t * __restrict__ grad_weights,
+ const scalar_t * __restrict__ grad_weights_sum,
+ const scalar_t * __restrict__ grad_depth,
+ const scalar_t * __restrict__ grad_image,
+ const scalar_t * __restrict__ sigmas,
+ const scalar_t * __restrict__ rgbs,
+ const scalar_t * __restrict__ ts,
+ const int * __restrict__ rays,
+ const scalar_t * __restrict__ weights_sum,
+ const scalar_t * __restrict__ depth,
+ const scalar_t * __restrict__ image,
+ const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,
+ scalar_t * grad_sigmas,
+ scalar_t * grad_rgbs
+) {
+ // parallel per ray
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= N) return;
+
+ // locate
+ uint32_t offset = rays[n * 2];
+ uint32_t num_steps = rays[n * 2 + 1];
+
+ if (num_steps == 0 || offset + num_steps > M) return;
+
+ grad_weights += offset;
+ grad_weights_sum += n;
+ grad_depth += n;
+ grad_image += n * 3;
+ weights_sum += n;
+ depth += n;
+ image += n * 3;
+ sigmas += offset;
+ rgbs += offset * 3;
+ ts += offset * 2;
+ grad_sigmas += offset;
+ grad_rgbs += offset * 3;
+
+ // accumulate
+ uint32_t step = 0;
+
+ float T = 1.0f;
+ const float r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], d_final = depth[0];
+ float r = 0, g = 0, b = 0, ws = 0, d = 0;
+
+ while (step < num_steps) {
+
+ const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
+ const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
+ const float weight = alpha * T;
+
+ r += weight * rgbs[0];
+ g += weight * rgbs[1];
+ b += weight * rgbs[2];
+ ws += weight;
+ d += weight * ts[0];
+
+ T *= 1.0f - alpha;
+
+ // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
+ // write grad_rgbs
+ grad_rgbs[0] = grad_image[0] * weight;
+ grad_rgbs[1] = grad_image[1] * weight;
+ grad_rgbs[2] = grad_image[2] * weight;
+
+ // write grad_sigmas
+ grad_sigmas[0] = ts[1] * (
+ grad_image[0] * (T * rgbs[0] - (r_final - r)) +
+ grad_image[1] * (T * rgbs[1] - (g_final - g)) +
+ grad_image[2] * (T * rgbs[2] - (b_final - b)) +
+ (grad_weights_sum[0] + grad_weights[0]) * (T - (ws_final - ws)) +
+ grad_depth[0] * (T * ts[0] - (d_final - d))
+ );
+
+ //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
+ // minimal remained transmittence
+ if (T < T_thresh) break;
+
+ // locate
+ sigmas++;
+ rgbs += 3;
+ ts += 2;
+ grad_weights++;
+ grad_sigmas++;
+ grad_rgbs += 3;
+
+ step++;
+ }
+}
+
+
+void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
+
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
+ kernel_composite_rays_train_backward<<>>(grad_weights.data_ptr(), grad_weights_sum.data_ptr(), grad_depth.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ts.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr(), M, N, T_thresh, binarize, grad_sigmas.data_ptr(), grad_rgbs.data_ptr());
+ }));
+}
+
+
+////////////////////////////////////////////////////
+///////////// infernce /////////////
+////////////////////////////////////////////////////
+
+template
+__global__ void kernel_march_rays(
+ const uint32_t n_alive,
+ const uint32_t n_step,
+ const int* __restrict__ rays_alive,
+ const scalar_t* __restrict__ rays_t,
+ const scalar_t* __restrict__ rays_o,
+ const scalar_t* __restrict__ rays_d,
+ const float bound, const bool contract,
+ const float dt_gamma, const uint32_t max_steps,
+ const uint32_t C, const uint32_t H,
+ const uint8_t * __restrict__ grid,
+ const scalar_t* __restrict__ nears,
+ const scalar_t* __restrict__ fars,
+ scalar_t* xyzs, scalar_t* dirs, scalar_t* ts,
+ const scalar_t* __restrict__ noises
+) {
+ const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+ if (n >= n_alive) return;
+
+ const int index = rays_alive[n]; // ray id
+ const float noise = noises[n];
+
+ // locate
+ rays_o += index * 3;
+ rays_d += index * 3;
+ xyzs += n * n_step * 3;
+ dirs += n * n_step * 3;
+ ts += n * n_step * 2;
+
+ const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+ const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+ const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+ const float rH = 1 / (float)H;
+ const float H3 = H * H * H;
+
+ const float near = nears[index], far = fars[index];
+
+ const float dt_min = 2 * SQRT3() / max_steps;
+ const float dt_max = 2 * SQRT3() * bound / H;
+ // const float dt_max = 1e10f;
+
+ // march for n_step steps, record points
+ float t = rays_t[index];
+ t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
+ uint32_t step = 0;
+
+ while (t < far && step < n_step) {
+ // current point
+ const float x = clamp(ox + t * dx, -bound, bound);
+ const float y = clamp(oy + t * dy, -bound, bound);
+ const float z = clamp(oz + t * dz, -bound, bound);
+
+ float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+ // get mip level
+ const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+ const float mip_bound = fminf(scalbnf(1, level), bound);
+ const float mip_rbound = 1 / mip_bound;
+
+ // contraction
+ float cx = x, cy = y, cz = z;
+ const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
+ if (contract && mag > 1) {
+ // L-INF norm
+ const float Linf_scale = (2 - 1 / mag) / mag;
+ cx *= Linf_scale;
+ cy *= Linf_scale;
+ cz *= Linf_scale;
+ }
+
+ // convert to nearest grid position
+ const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+ const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+ const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+ const bool occ = grid[index / 8] & (1 << (index % 8));
+
+ // if occpuied, advance a small step, and write to output
+ if (occ) {
+ // write step
+ xyzs[0] = cx;
+ xyzs[1] = cy;
+ xyzs[2] = cz;
+ dirs[0] = dx;
+ dirs[1] = dy;
+ dirs[2] = dz;
+ // calc dt
+ t += dt;
+ ts[0] = t;
+ ts[1] = dt;
+ // step
+ xyzs += 3;
+ dirs += 3;
+ ts += 2;
+ step++;
+
+ // contraction case
+ } else if (contract && mag > 1) {
+ t += dt;
+ // else, skip a large step (basically skip a voxel grid)
+ } else {
+ // calc distance to next voxel
+ const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;
+ const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;
+ const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;
+ const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+ // step until next voxel
+ do {
+ dt = clamp(t * dt_gamma, dt_min, dt_max);
+ t += dt;
+ } while (t < tt);
+ }
+ }
+}
+
+
+void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises) {
+ static constexpr uint32_t N_THREAD = 128;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ rays_o.scalar_type(), "march_rays", ([&] {
+ kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, contract, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), ts.data_ptr