# Image preprocessing utilities: background removal (carvekit/rembg),
# captioning (BLIP-2), and monocular depth/normal estimation (MiDaS, Omnidata DPT).
import os
|
|
import sys
|
|
import cv2
|
|
import argparse
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
|
|
from easydict import EasyDict as edict
|
|
|
|
class BackgroundRemoval():
    """Salient-object background removal backed by carvekit's ``HiInterface``.

    Calling an instance with an RGB uint8 array returns the matted image
    (RGBA) with background pixels made transparent.
    """

    def __init__(self, device='cuda'):
        # Imported lazily so carvekit is only required when this class is used.
        from carvekit.api.high import HiInterface

        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )

    @torch.no_grad()
    def __call__(self, image):
        # image: [H, W, 3] array in [0, 255].
        pil_image = Image.fromarray(image)
        matted = self.interface([pil_image])[0]
        return np.array(matted)
|
|
|
|
|
|
def get_rgba(image, alpha_matting=False):
    """Remove the background of *image* using rembg.

    Args:
        image: input in any format ``rembg.remove`` accepts (PIL image,
            numpy array, or raw bytes).
        alpha_matting: forwarded to ``rembg.remove`` for softer edges.

    Returns:
        The background-removed (RGBA) image, same container type as input.

    Exits the process with a non-zero status if rembg is not installed.
    """
    try:
        from rembg import remove
    except ImportError:
        print('Please install rembg with "pip install rembg"')
        # Bug fix: bare sys.exit() exits with status 0 (success) even though
        # we failed — report the failure to the shell instead.
        sys.exit(1)
    return remove(image, alpha_matting=alpha_matting)
|
|
|
|
|
|
class BLIP2():
    """Image captioning via Salesforce BLIP-2 (OPT-2.7b), run in fp16."""

    def __init__(self, device='cuda'):
        self.device = device
        # Imported lazily so transformers is only needed when captioning.
        from transformers import AutoProcessor, Blip2ForConditionalGeneration
        self.processor = AutoProcessor.from_pretrained(
            "Salesforce/blip2-opt-2.7b")
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)

    @torch.no_grad()
    def __call__(self, image):
        # image: uint8 RGB array -> short caption string.
        pil_image = Image.fromarray(image)
        batch = self.processor(pil_image, return_tensors="pt")
        batch = batch.to(self.device, torch.float16)

        ids = self.model.generate(**batch, max_new_tokens=20)
        texts = self.processor.batch_decode(ids, skip_special_tokens=True)
        return texts[0].strip()
|
|
|
|
|
|
class DPT():
    """Omnidata DPT wrapper for monocular depth or surface-normal estimation.

    Depending on ``task`` ('depth' or anything else for normals) this loads
    the matching Omnidata checkpoint from ``pretrained/omnidata/`` and runs
    inference via ``__call__`` on a uint8 RGB array.
    """

    def __init__(self, task='depth', device='cuda'):
        self.task = task
        self.device = device

        from dpt import DPTDepthModel

        # Pick checkpoint, model head, and preprocessing per task.
        if task == 'depth':
            path = 'pretrained/omnidata/omnidata_dpt_depth_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384')
            steps = [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        else:  # normal
            path = 'pretrained/omnidata/omnidata_dpt_normal_v2.ckpt'
            self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3)
            steps = [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
            ]
        self.aug = transforms.Compose(steps)

        # Load weights. Lightning-style checkpoints nest them under
        # 'state_dict' with a "model." prefix (6 chars) that must be stripped.
        checkpoint = torch.load(path, map_location='cpu')
        if 'state_dict' in checkpoint:
            state_dict = {key[6:]: value
                          for key, value in checkpoint['state_dict'].items()}
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval().to(device)

    @torch.no_grad()
    def __call__(self, image):
        # image: np.ndarray, uint8, [H, W, 3]
        H, W = image.shape[:2]
        tensor = self.aug(Image.fromarray(image)).unsqueeze(0).to(self.device)

        pred = self.model(tensor).clamp(0, 1)
        if self.task == 'depth':
            # Depth comes out as [B, h, w]; add a channel dim to interpolate.
            pred = F.interpolate(pred.unsqueeze(1), size=(H, W),
                                 mode='bicubic', align_corners=False)
            return pred.squeeze(1).cpu().numpy()
        # Normals are already [B, 3, h, w].
        pred = F.interpolate(pred, size=(H, W),
                             mode='bicubic', align_corners=False)
        return pred.cpu().numpy()
|
|
|
|
|
|
# from munch import DefaultMunch
from midas.model_loader import default_models, load_model

# Default MiDaS depth-estimation settings; individual keys can be overridden
# via DepthEstimator(**kwargs). NOTE(review): the keys appear to mirror the
# MiDaS run.py CLI options — confirm against the vendored midas package.
depth_config={
    "input_path": None,   # not referenced in this file
    "output_path": None,  # not referenced in this file
    "model_weights": "pretrained/midas/dpt_beit_large_512.pt",
    "model_type": "dpt_beit_large_512",
    "side": False,
    "optimize": False,    # forwarded to load_model
    "height": None,       # forwarded to load_model
    "square": False,      # forwarded to load_model
    "device":0,           # CUDA device index, used as f"cuda:{device}"
    "grayscale": False
}
|
|
|
|
|
|
class DepthEstimator:
    """Monocular depth estimation with a MiDaS model.

    Wraps ``midas.model_loader.load_model`` and exposes:
      * ``process(image, target_size)``  – raw forward pass + resize
      * ``get_monocular_depth(rgb)``     – transform + inference pipeline

    Keyword arguments override entries of the module-level ``depth_config``
    (e.g. ``model_type``, ``model_weights``, ``device``, ``optimize``).
    """

    def __init__(self, **kwargs):
        # Bug fix: kwargs used to be written into the shared module-level
        # depth_config dict, leaking overrides to every later instance.
        # Work on a private copy instead.
        config = dict(depth_config)
        config.update(kwargs)

        # self.config=DefaultMunch.fromDict(depth_config)
        self.config = edict(config)

        # select device (config.device is a CUDA device index)
        self.device = torch.device(self.config.device)
        model, transform, net_w, net_h = load_model(
            f"cuda:{self.config.device}", self.config.model_weights,
            self.config.model_type, self.config.optimize,
            self.config.height, self.config.square)
        self.model, self.transform, self.net_w, self.net_h = model, transform, net_w, net_h
        # Used to log the encoder input resolution once, on the first call.
        self.first_execution = True

    @torch.no_grad()
    def process(self, image, target_size):
        """Run the network on a preprocessed image and resize the prediction.

        Args:
            image: float ndarray [C, H, W] as produced by the MiDaS transform.
            target_size: (width, height) of the desired output map.

        Returns:
            np.ndarray prediction resized to (height, width).
        """
        sample = torch.from_numpy(image).to(self.device).unsqueeze(0)

        if self.first_execution:
            height, width = sample.shape[2:]
            print(f" Input resized to {width}x{height} before entering the encoder")
            self.first_execution = False

        prediction = self.model.forward(sample)
        # F.interpolate takes (height, width); target_size is (width, height).
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=target_size[::-1],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )
        return prediction

    @torch.no_grad()
    def get_monocular_depth(self, rgb, output_path=None):
        """Estimate a relative depth map for an RGB image.

        Args:
            rgb: float ndarray [H, W, 3], values expected in [0, 1].
            output_path: unused; kept for interface compatibility.

        Returns:
            np.ndarray depth prediction of shape [H, W].
        """
        original_image_rgb = rgb
        image = self.transform({"image": original_image_rgb})["image"]

        # shape[1::-1] == (width, height) of the original image.
        prediction = self.process(image, original_image_rgb.shape[1::-1])
        return prediction
|
|
|
|
|
|
|
|
def process_single_image(image_path, depth_estimator, normal_estimator=None):
    """Prepare one image for reconstruction: RGBA matte + depth, square-padded.

    Writes next to *image_path*:
      * ``rgba.png``  – background-removed RGBA (only if not already present)
      * ``depth.png`` – uint8 depth map, normalized over foreground pixels

    Args:
        image_path: path to the input image.
        depth_estimator: DepthEstimator used for monocular depth.
        normal_estimator: unused; kept for interface compatibility
            (normal estimation is currently commented out).
    """
    out_dir = os.path.dirname(image_path)
    rgba_path = os.path.join(out_dir, 'rgba.png')
    depth_path = os.path.join(out_dir, 'depth.png')
    # out_normal = os.path.join(out_dir, 'normal.png')

    if os.path.exists(rgba_path):
        # Reuse a previously computed matte instead of re-running carvekit.
        print(f'[INFO] loading rgba image {rgba_path}...')
        rgba = cv2.cvtColor(cv2.imread(rgba_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
        image = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
    else:
        print(f'[INFO] loading image {image_path}...')
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image.shape[-1] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        print(f'[INFO] background removal...')
        rgba = BackgroundRemoval()(image)  # [H, W, 4]

    # Predict depth using Midas, then min-max normalize over the foreground
    # (alpha > 0) pixels only; background is forced to zero.
    mask = rgba[..., -1] > 0
    depth = depth_estimator.get_monocular_depth(image / 255)
    if mask.any():
        depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9)
    else:
        # Bug fix: an all-transparent matte used to crash on .min() of an
        # empty selection; emit an all-zero depth map instead.
        print(f'[WARN] empty foreground mask for {image_path}')
    depth[~mask] = 0
    depth = (depth * 255).astype(np.uint8)

    # print(f'[INFO] normal estimation...')
    # normal = normal_estimator(image)[0]
    # normal = (normal.clip(0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
    # normal[~mask] = 0

    height, width, _ = image.shape
    # Determine the padding needed to make the image square (centered).
    if height > width:
        left_padding = (height - width) // 2
        right_padding = height - width - left_padding
        padding = ((0, 0), (left_padding, right_padding), (0, 0))
        padding2d = ((0, 0), (left_padding, right_padding))
    elif width > height:
        top_padding = (width - height) // 2
        bottom_padding = width - height - top_padding
        padding = ((top_padding, bottom_padding), (0, 0), (0, 0))
        padding2d = ((top_padding, bottom_padding), (0, 0))
    else:
        padding = ((0, 0), (0, 0), (0, 0))
        padding2d = ((0, 0), (0, 0))

    # Apply padding to every output consistently.
    image = np.pad(image, padding, mode='constant', constant_values=0)
    rgba = np.pad(rgba, padding, mode='constant', constant_values=0)
    depth = np.pad(depth, padding2d, mode='constant', constant_values=0)

    cv2.imwrite(depth_path, depth)
    # cv2.imwrite(out_normal, cv2.cvtColor(normal, cv2.COLOR_RGB2BGR))
    if not os.path.exists(rgba_path):
        cv2.imwrite(rgba_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
|
|
|
|
if __name__ == '__main__':
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument('--path', default=None, type=str, nargs='*', help="path to image (png, jpeg, etc.)")
    parser.add_argument('--folder', default=None, type=str, help="path to a folder of image (png, jpeg, etc.)")
    parser.add_argument('--imagepattern', default="image.png", type=str, help="image name pattern")
    # Bug fix: default was '' (a string); use an empty list so iteration and
    # membership behave uniformly whether or not the flag is supplied.
    parser.add_argument('--exclude', default=[], type=str, nargs='*', help="path to image (png, jpeg, etc.) to exclude")
    opt = parser.parse_args()

    depth_estimator = DepthEstimator()
    # normal_estimator = DPT(task='normal')

    if opt.path is not None:
        paths = opt.path
    else:
        # Bug fix: with neither flag, os.path.join(None, ...) raised a
        # confusing TypeError; fail with a clear usage message instead.
        if opt.folder is None:
            parser.error('either --path or --folder must be given')
        paths = glob.glob(os.path.join(opt.folder, f'*/{opt.imagepattern}'))

    # Bug fix: `del paths[exclude_path]` indexed the list with a string and
    # raised TypeError; remove matching entries by value instead.
    for exclude_path in opt.exclude:
        if exclude_path in paths:
            paths.remove(exclude_path)

    for path in paths:
        process_single_image(path, depth_estimator,
                             # normal_estimator
                             )
|