Files
Magic123/guidance/shape_utils.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

81 lines
2.8 KiB
Python

from typing import Any, Dict, Optional, Sequence, Tuple, Union
from shap_e.models.transmitter.base import Transmitter
from shap_e.models.query import Query
from shap_e.models.nerstf.renderer import NeRSTFRenderer
from shap_e.util.collections import AttrDict
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.image_util import load_image
from shap_e.models.nn.meta import subdict
import torch
import gc
# Candidate 3x3 float32 matrices; query positions are right-multiplied by each
# one (``pos @ matrix``) before the shape is evaluated.
camera_to_shapes = [
    torch.tensor(rows, dtype=torch.float32)
    for rows in (
        ((1, 0, 0), (0, 1, 0), (0, 0, 1)),    # identity (axes unchanged)
        ((1, 0, 0), (0, 0, 1), (0, -1, 0)),   # to bird view
        ((0, 1, 0), (0, 0, 1), (-1, 0, 0)),   # to rotated bird view
        ((0, -1, 0), (0, 0, 1), (-1, 0, 0)),  # to rotated bird view
        ((-1, 0, 0), (0, 0, 1), (0, -1, 0)),  # to bird view
    )
]
def get_density(
    render,
    query: Query,
    params: Dict[str, torch.Tensor],
    options: AttrDict[str, Any],
) -> torch.Tensor:
    """Return the ``density`` field produced by ``render.nerstf`` for ``query``.

    Args:
        render: Renderer object exposing a non-``None`` ``nerstf`` callable.
        query: Query forwarded unchanged to ``render.nerstf``.
        params: Parameter dict; only the ``"nerstf"`` sub-dictionary (as
            selected by ``subdict``) is passed on.
        options: Options forwarded unchanged to ``render.nerstf``.

    Returns:
        The ``density`` attribute of the ``render.nerstf`` output.
    """
    assert render.nerstf is not None
    nerstf_model = render.nerstf
    nerstf_params = subdict(params, "nerstf")
    output = nerstf_model(query, params=nerstf_params, options=options)
    return output.density
@torch.no_grad()
def get_shape_from_image(image_path, pos,
                         rpst_type='sdf',  # or 'density'
                         get_color=True,
                         shape_guidance=3, device='cuda'):
    """Sample a shap_e latent from an image and query the decoded shape at ``pos``.

    The 'transmitter' and 'image300M' models are loaded on every call. One
    latent is sampled conditioned on the image at ``image_path`` and converted
    to renderer parameters. The shape is then evaluated once per matrix in
    ``camera_to_shapes`` at the positions ``pos @ matrix``.

    Args:
        image_path: Path handed to ``load_image`` for the conditioning image.
        pos: Tensor of query positions; its last dimension must be 3 so it can
            be multiplied with the 3x3 matrices.
        rpst_type: ``'sdf'`` queries the signed distance; any other value
            queries the density.
        get_color: When true, the texture is also queried at each position.
        shape_guidance: Guidance scale passed to the latent sampler.
        device: Device on which both models are loaded.

    Returns:
        ``(rpsts, colors)``: two lists with one entry per matrix in
        ``camera_to_shapes``. ``rpsts`` holds the squeezed signed-distance or
        density outputs; ``colors`` holds the texture outputs, or ``None``
        entries when ``get_color`` is false.
    """
    xm = load_model('transmitter', device=device)
    image_model = load_model('image300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))

    # Sampler settings gathered in one place; a single latent is drawn.
    sampler_kwargs = dict(
        batch_size=1,
        model=image_model,
        diffusion=diffusion,
        guidance_scale=shape_guidance,
        model_kwargs=dict(images=[load_image(image_path)]),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    latent = sample_latents(**sampler_kwargs)[0]

    # A Transmitter wraps its encoder; otherwise the model itself is used.
    encoder = xm.encoder if isinstance(xm, Transmitter) else xm
    params = encoder.bottleneck_to_params(latent[None])

    want_sdf = rpst_type == 'sdf'
    rpsts, colors = [], []
    for transform in camera_to_shapes:
        transformed_pos = pos @ transform.to(pos.device)
        query = Query(position=transformed_pos, direction=None)

        if want_sdf:
            field = xm.renderer.get_signed_distance(query, params, AttrDict())
        else:
            field = get_density(xm.renderer, query, params, AttrDict())
        rpsts.append(field.squeeze())

        colors.append(
            xm.renderer.get_texture(query, params, AttrDict()) if get_color else None
        )
    return rpsts, colors