import os
import sys
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from easydict import EasyDict as edict


class BackgroundRemoval:
    """Foreground extraction backed by carvekit's ``HiInterface``."""

    def __init__(self, device='cuda'):
        # Imported lazily so carvekit is only required when this class is used.
        from carvekit.api.high import HiInterface
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )

    @torch.no_grad()
    def __call__(self, image):
        """Run background removal on a single image.

        Args:
            image: [H, W, 3] array in [0, 255].

        Returns:
            ``np.array`` of the PIL image returned by carvekit. The caller in
            this file treats it as an [H, W, 4] RGBA array.
        """
        image = Image.fromarray(image)
        image = self.interface([image])[0]
        image = np.array(image)
        return image


def get_rgba(image, alpha_matting=False):
    """Remove the background of ``image`` using rembg's ``remove``.

    Exits the process with status 1 if rembg is not installed.
    """
    try:
        from rembg import remove
    except ImportError:
        print('Please install rembg with "pip install rembg"')
        # Non-zero status so calling scripts notice the missing dependency.
        sys.exit(1)
    return remove(image, alpha_matting=alpha_matting)


class BLIP2:
    """Image captioning with Salesforce/blip2-opt-2.7b in float16."""

    def __init__(self, device='cuda'):
        self.device = device
        from transformers import AutoProcessor, Blip2ForConditionalGeneration
        self.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip2-opt-2.7b")
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)

    @torch.no_grad()
    def __call__(self, image):
        """Return a caption (at most 20 new tokens) for an image array."""
        image = Image.fromarray(image)
        inputs = self.processor(image, return_tensors="pt").to(
            self.device, torch.float16)

        generated_ids = self.model.generate(**inputs, max_new_tokens=20)
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True)[0].strip()

        return generated_text


class DPT:
    """Omnidata DPT predictor for depth (``task='depth'``) or, otherwise, normals."""

    def __init__(self, task='depth', device='cuda'):
        self.task = task
        self.device = device

        from dpt import DPTDepthModel

        if task == 'depth':
            path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384')
            self.aug = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.5, std=0.5)
            ])

        else:  # normal
            path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
            self.aug = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor()
            ])

        # load model
        checkpoint = torch.load(path, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = {}
            for k, v in checkpoint['state_dict'].items():
                # Drops a 6-character key prefix -- presumably "model." from a
                # Lightning-style checkpoint; confirm against the checkpoint.
                state_dict[k[6:]] = v
        else:
            state_dict = checkpoint

        self.model.load_state_dict(state_dict)
        self.model.eval().to(device)

    @torch.no_grad()
    def __call__(self, image):
        """Predict depth or normals for an image.

        Args:
            image: np.ndarray, uint8, [H, W, 3].

        Returns:
            NumPy array resized (bicubic) back to H x W, clamped to [0, 1]
            before resizing. Presumably [1, H, W] for depth and [1, 3, H, W]
            for normals -- confirm against DPTDepthModel's output shape.
        """
        H, W = image.shape[:2]
        image = Image.fromarray(image)

        image = self.aug(image).unsqueeze(0).to(self.device)

        if self.task == 'depth':
            depth = self.model(image).clamp(0, 1)
            depth = F.interpolate(depth.unsqueeze(1), size=(H, W), mode='bicubic', align_corners=False)
            depth = depth.squeeze(1).cpu().numpy()
            return depth
        else:
            normal = self.model(image).clamp(0, 1)
            normal = F.interpolate(normal, size=(H, W), mode='bicubic', align_corners=False)
            normal = normal.cpu().numpy()
            return normal


from midas.model_loader import default_models, load_model

# Default options for DepthEstimator. Keyword arguments passed to
# DepthEstimator override a copy of this dict; the dict itself is not mutated.
depth_config = {
    "input_path": None,
    "output_path": None,
    "model_weights": "pretrained/midas/dpt_beit_large_512.pt",
    "model_type": "dpt_beit_large_512",
    "side": False,
    "optimize": False,
    "height": None,
    "square": False,
    "device": 0,
    "grayscale": False
}


class DepthEstimator:
    """Monocular depth prediction with a MiDaS model (see ``depth_config``)."""

    def __init__(self, **kwargs):
        # Merge overrides into a copy so one instance's kwargs cannot leak
        # into the module-level defaults seen by later instances.
        config = dict(depth_config)
        config.update(kwargs)
        self.config = edict(config)

        # An int device is a CUDA index (the default, 0 -> "cuda:0"). Any other
        # value such as "cpu" or "cuda:1" is passed through unchanged.
        device = self.config.device
        device_name = f"cuda:{device}" if isinstance(device, int) else str(device)
        self.device = torch.device(device_name)

        model, transform, net_w, net_h = load_model(
            device_name, self.config.model_weights, self.config.model_type,
            self.config.optimize, self.config.height, self.config.square)
        self.model, self.transform, self.net_w, self.net_h = model, transform, net_w, net_h
        self.first_execution = True

    @torch.no_grad()
    def process(self, image, target_size):
        """Run the network and bicubically resize the prediction.

        Args:
            image: network input produced by ``self.transform`` (a NumPy array
                accepted by ``torch.from_numpy``).
            target_size: (width, height) of the output; reversed to
                (height, width) for ``F.interpolate``.

        Returns:
            Squeezed NumPy prediction of size ``target_size``.
        """
        sample = torch.from_numpy(image).to(self.device).unsqueeze(0)

        if self.first_execution:
            height, width = sample.shape[2:]
            print(f" Input resized to {width}x{height} before entering the encoder")
            self.first_execution = False

        prediction = self.model.forward(sample)
        prediction = (
            F.interpolate(
                prediction.unsqueeze(1),
                size=target_size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )
        return prediction

    @torch.no_grad()
    def get_monocular_depth(self, rgb, output_path=None):
        """Predict depth for an [H, W, 3] RGB image at its own resolution.

        Args:
            rgb: RGB image. The caller in this file passes values scaled to
                [0, 1].
            output_path: unused; kept for backward compatibility.

        Returns:
            [H, W] NumPy array holding the raw model prediction.
        """
        original_image_rgb = rgb
        image = self.transform({"image": original_image_rgb})["image"]
        # shape[1::-1] is (W, H), the (width, height) order process() expects.
        prediction = self.process(image, original_image_rgb.shape[1::-1])
        return prediction


# Shared BackgroundRemoval instance, created on first use so the carvekit
# models are loaded once per process rather than once per image.
_background_remover = None


def _get_background_remover():
    """Return the shared BackgroundRemoval instance, creating it if needed."""
    global _background_remover
    if _background_remover is None:
        _background_remover = BackgroundRemoval()
    return _background_remover


def _load_rgb(image_path):
    """Read an image file and convert it to a 3-channel RGB array.

    Handles grayscale (2-D), BGR and BGRA inputs.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f'could not read image {image_path}')
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[-1] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _square_padding(height, width):
    """Return ((top, bottom), (left, right)) zero padding that makes H x W square.

    The extra row/column from an odd difference goes to the bottom/right.
    """
    size = max(height, width)
    top = (size - height) // 2
    left = (size - width) // 2
    return (top, size - height - top), (left, size - width - left)


def process_single_image(image_path, depth_estimator, normal_estimator=None):
    """Write ``rgba.png`` and ``depth.png`` into the directory of ``image_path``.

    If ``rgba.png`` already exists there, it is used as the source image and
    its alpha channel as the foreground mask. Otherwise the background is
    removed with BackgroundRemoval and the padded result is saved as
    ``rgba.png``.

    Depth is min-max normalised to [0, 1] over the foreground, set to 0 on the
    background, stored as uint8 and zero-padded to a square.

    Args:
        image_path: path to the input image.
        depth_estimator: object exposing ``get_monocular_depth(rgb)``, e.g.
            DepthEstimator.
        normal_estimator: currently unused (normal estimation is disabled).

    Raises:
        FileNotFoundError: if the input image or an existing ``rgba.png``
            cannot be read.
    """
    out_dir = os.path.dirname(image_path)
    rgba_path = os.path.join(out_dir, 'rgba.png')
    depth_path = os.path.join(out_dir, 'depth.png')
    # out_normal = os.path.join(out_dir, 'normal.png')

    if os.path.exists(rgba_path):
        print(f'[INFO] loading rgba image {rgba_path}...')
        bgra = cv2.imread(rgba_path, cv2.IMREAD_UNCHANGED)
        if bgra is None:
            raise FileNotFoundError(f'could not read image {rgba_path}')
        rgba = cv2.cvtColor(bgra, cv2.COLOR_BGRA2RGBA)
        image = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
    else:
        print(f'[INFO] loading image {image_path}...')
        image = _load_rgb(image_path)
        print('[INFO] background removal...')
        rgba = _get_background_remover()(image)  # [H, W, 4]

    # Predict depth using MiDaS on the image scaled to [0, 1].
    mask = rgba[..., -1] > 0
    depth = depth_estimator.get_monocular_depth(image / 255)
    foreground = depth[mask]
    # An empty mask (fully transparent alpha) would make min()/max() raise.
    if foreground.size > 0:
        lo, hi = foreground.min(), foreground.max()
        depth[mask] = (foreground - lo) / (hi - lo + 1e-9)
    depth[~mask] = 0
    depth = (depth * 255).astype(np.uint8)

    # Normal estimation is disabled; to re-enable it:
    # normal = normal_estimator(image)[0]
    # normal = (normal.clip(0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
    # normal[~mask] = 0

    # Zero-pad the outputs so they are square, keeping the content centred.
    pad_rows, pad_cols = _square_padding(*image.shape[:2])
    rgba = np.pad(rgba, (pad_rows, pad_cols, (0, 0)), mode='constant', constant_values=0)
    depth = np.pad(depth, (pad_rows, pad_cols), mode='constant', constant_values=0)

    cv2.imwrite(depth_path, depth)
    # cv2.imwrite(out_normal, cv2.cvtColor(normal, cv2.COLOR_RGB2BGR))

    if not os.path.exists(rgba_path):
        cv2.imwrite(rgba_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))


def main():
    """Command-line entry point: preprocess images given by path or folder."""
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument('--path', default=None, type=str, nargs='*', help="path to image (png, jpeg, etc.)")
    parser.add_argument('--folder', default=None, type=str, help="path to a folder of image (png, jpeg, etc.)")
    parser.add_argument('--imagepattern', default="image.png", type=str, help="image name pattern")
    parser.add_argument('--exclude', default=[], type=str, nargs='*', help="path to image (png, jpeg, etc.) to exclude")
    opt = parser.parse_args()

    # Fail fast with a usage message instead of crashing in os.path.join(None, ...).
    if opt.path is None and opt.folder is None:
        parser.error('one of --path or --folder is required')

    depth_estimator = DepthEstimator()
    # normal_estimator = DPT(task='normal')

    if opt.path is not None:
        paths = opt.path
    else:
        # Sorted so the processing order is deterministic across runs.
        paths = sorted(glob.glob(os.path.join(opt.folder, f'*/{opt.imagepattern}')))

    # paths is a list of strings, so exclusions are filtered by value.
    excluded = set(opt.exclude)
    paths = [p for p in paths if p not in excluded]

    for path in paths:
        process_single_image(path, depth_estimator)


if __name__ == '__main__':
    main()