update scripts and command

This commit is contained in:
Guocheng Qian
2023-08-15 13:34:36 +00:00
parent 6bfbbbf6a2
commit 5821d7bf8b
5 changed files with 118 additions and 17 deletions

View File

@@ -204,6 +204,7 @@ def process_single_image(image_path, depth_estimator, normal_estimator=None):
print(f'[INFO] loading rgba image {rgba_path}...')
rgba = cv2.cvtColor(cv2.imread(rgba_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
image = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
else:
print(f'[INFO] loading image {image_path}...')
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
@@ -213,9 +214,6 @@ def process_single_image(image_path, depth_estimator, normal_estimator=None):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
print(f'[INFO] background removal...')
rgba = BackgroundRemoval()(image) # [H, W, 4]
cv2.imwrite(rgba_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
# rgba = get_rgba(image) # [H, W, 4]
# cv2.imwrite(rgba_path.replace('rgba', 'rgba2'), cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
# Predict depth using Midas
mask = rgba[..., -1] > 0
@@ -228,9 +226,31 @@ def process_single_image(image_path, depth_estimator, normal_estimator=None):
# normal = normal_estimator(image)[0]
# normal = (normal.clip(0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
# normal[~mask] = 0
height, width, _ = image.shape
# Determine the padding needed to make the image square
if height > width:
left_padding = (height - width) // 2
right_padding = height - width - left_padding
padding = ((0, 0), (left_padding, right_padding), (0, 0))
padding2d = ((0, 0), (left_padding, right_padding))
elif width > height:
top_padding = (width - height) // 2
bottom_padding = width - height - top_padding
padding = ((top_padding, bottom_padding), (0, 0), (0, 0))
padding2d = ((top_padding, bottom_padding), (0, 0))
else:
padding = ((0, 0), (0, 0), (0, 0))
padding2d = ((0, 0), (0, 0))
# Apply padding to the image
image = np.pad(image, padding, mode='constant', constant_values=0)
rgba = np.pad(rgba, padding, mode='constant', constant_values=0)
depth = np.pad(depth, padding2d, mode='constant', constant_values=0)
cv2.imwrite(depth_path, depth)
# cv2.imwrite(out_normal, cv2.cvtColor(normal, cv2.COLOR_RGB2BGR))
# breakpoint()
if not os.path.exists(rgba_path):
cv2.imwrite(rgba_path, cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
@@ -238,14 +258,22 @@ if __name__ == '__main__':
import glob
parser = argparse.ArgumentParser()
parser.add_argument('--path', default=None, type=str, nargs='*', help="path to image (png, jpeg, etc.)")
parser.add_argument('--folder', default=None, type=str, help="path to image (png, jpeg, etc.)")
parser.add_argument('--folder', default=None, type=str, help="path to a folder of image (png, jpeg, etc.)")
parser.add_argument('--imagepattern', default="image.png", type=str, help="image name pattern")
parser.add_argument('--exclude', default='', type=str, nargs='*', help="path to image (png, jpeg, etc.) to exclude")
opt = parser.parse_args()
depth_estimator = DepthEstimator()
# normal_estimator = DPT(task='normal')
paths = opt.path if opt.path is not None else glob.glob(os.path.join(opt.folder, '*/rgba.png'))
if opt.path is not None:
paths = opt.path
else:
paths = glob.glob(os.path.join(opt.folder, f'*/{opt.imagepattern}'))
for exclude_path in opt.exclude:
if exclude_path in paths:
del paths[exclude_path]
for path in paths:
process_single_image(path, depth_estimator,
# normal_estimator
)
)