81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
|
from shap_e.models.transmitter.base import Transmitter
|
|
from shap_e.models.query import Query
|
|
from shap_e.models.nerstf.renderer import NeRSTFRenderer
|
|
from shap_e.util.collections import AttrDict
|
|
from shap_e.diffusion.sample import sample_latents
|
|
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
|
from shap_e.models.download import load_model, load_config
|
|
from shap_e.util.image_util import load_image
|
|
from shap_e.models.nn.meta import subdict
|
|
import torch
|
|
import gc
|
|
|
|
|
|
# Fixed 3x3 matrices applied to query positions as ``pos @ matrix`` so the
# same reconstructed shape can be sampled from several orientations.
# Each entry is built as a float32 tensor (on the default device; callers
# move it with ``.to(...)``).
# NOTE(review): entries 2 and 4 have determinant -1, i.e. they mirror the
# shape rather than purely rotating it -- confirm this is intended.
camera_to_shapes = [
    torch.tensor(rows, dtype=torch.float32)
    for rows in (
        ((1, 0, 0), (0, 1, 0), (0, 0, 1)),    # identity
        ((1, 0, 0), (0, 0, 1), (0, -1, 0)),   # to bird view
        ((0, 1, 0), (0, 0, 1), (-1, 0, 0)),   # to rotated bird view
        ((0, -1, 0), (0, 0, 1), (-1, 0, 0)),  # to rotated bird view
        ((-1, 0, 0), (0, 0, 1), (0, -1, 0)),  # to bird view (x flipped)
    )
]
|
|
|
|
|
|
def get_density(
    render,
    query: Query,
    params: Dict[str, torch.Tensor],
    options: AttrDict[str, Any],
) -> torch.Tensor:
    """Evaluate the renderer's NeRSTF field at ``query`` and return its density.

    Args:
        render: renderer object exposing a ``nerstf`` attribute (in this file,
            the ``renderer`` of a loaded transmitter); ``nerstf`` must not be
            ``None``.
        query: points at which the field is evaluated.
        params: parameter dict; the sub-dictionary returned by
            ``subdict(params, "nerstf")`` is what gets forwarded.
        options: render options, passed through unchanged.

    Returns:
        The ``density`` attribute of the NeRSTF output.
    """
    field = render.nerstf
    assert field is not None
    field_params = subdict(params, "nerstf")
    output = field(query, params=field_params, options=options)
    return output.density
|
|
|
|
|
|
def _sample_image_latent(image_path, shape_guidance, device):
    """Sample one Shap-E latent conditioned on the image at ``image_path``.

    The ``image300M`` model and the diffusion process live only in this
    helper's frame, so they become unreachable as soon as it returns and the
    caller can reclaim their memory before running field queries.
    """
    model = load_model('image300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    latents = sample_latents(
        batch_size=1,
        model=model,
        diffusion=diffusion,
        guidance_scale=shape_guidance,
        model_kwargs=dict(images=[load_image(image_path)]),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    # batch_size=1: keep the single sampled latent.
    return latents[0]


@torch.no_grad()
def get_shape_from_image(image_path, pos,
                         rpst_type='sdf',  # or 'density'
                         get_color=True,
                         shape_guidance=3, device='cuda'):
    """Reconstruct a shape from one image and query it under several orientations.

    A latent is sampled from the image-conditioned ``image300M`` model and
    decoded into field parameters by the ``transmitter`` model. The field is
    then evaluated at ``pos`` once for each matrix in ``camera_to_shapes``.

    Args:
        image_path: image path passed to ``load_image``.
        pos: tensor of query positions, right-multiplied by each 3x3 matrix
            (``pos @ matrix``); presumably shape (..., 3) -- TODO confirm
            against callers.
        rpst_type: ``'sdf'`` queries the signed distance; any other value
            queries the NeRSTF density.
        get_color: if true, also query the texture at the same points.
        shape_guidance: guidance scale used when sampling the latent.
        device: device the models are loaded onto.

    Returns:
        ``(rpsts, colors)``: two lists with one entry per matrix in
        ``camera_to_shapes``. ``rpsts`` holds squeezed SDF/density tensors;
        ``colors`` holds texture outputs, or ``None`` entries when
        ``get_color`` is false.
    """
    # Load the transmitter before sampling. Model construction may draw from
    # the global RNG, so changing this order could change seeded results.
    xm = load_model('transmitter', device=device)
    latent = _sample_image_latent(image_path, shape_guidance, device)

    # The diffusion model is no longer referenced; reclaim its memory before
    # the per-view field queries, which can be large for dense ``pos``.
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialized

    decoder = xm.encoder if isinstance(xm, Transmitter) else xm
    params = decoder.bottleneck_to_params(latent[None])

    rpsts, colors = [], []
    for camera_to_shape in camera_to_shapes:
        query = Query(
            position=pos @ camera_to_shape.to(pos.device),
            direction=None,
        )

        if rpst_type == 'sdf':
            rpst = xm.renderer.get_signed_distance(query, params, AttrDict())
        else:
            # Any value other than 'sdf' falls back to NeRSTF density.
            rpst = get_density(xm.renderer, query, params, AttrDict())
        rpsts.append(rpst.squeeze())

        colors.append(
            xm.renderer.get_texture(query, params, AttrDict()) if get_color else None
        )

    return rpsts, colors