# Magic123/taichi_modules/utils.py
# Snapshot of commit 13e18567fa ("first commit") by Guocheng Qian,
# 2023-08-02 19:51:43 -07:00 (original listing: 224 lines, 6.2 KiB, Python).

import taichi as ti
import torch
from taichi.math import uvec3
# Block size for Taichi kernel launches. It is not referenced in the code
# shown here; presumably consumed via ti.loop_config(block_dim=...) elsewhere
# — TODO confirm.
taichi_block_size = 128
# Paired float types: 32-bit floats on both the Taichi and the PyTorch side.
data_type, torch_type = ti.f32, torch.float32
# Ray-marching constants.
MAX_SAMPLES = 1024  # presumably the per-ray sample cap used by callers — TODO confirm
NEAR_DISTANCE = 0.01  # not referenced in the code shown here
SQRT3 = 1.7320508075688772  # sqrt(3)
# Derived from the constants above instead of repeating their literals, so a
# change to SQRT3 or MAX_SAMPLES cannot leave these out of sync. calc_dt
# clamps the step size to [SQRT3_MAX_SAMPLES, SQRT3_2 * scale / grid_size].
SQRT3_MAX_SAMPLES = SQRT3 / MAX_SAMPLES
SQRT3_2 = SQRT3 * 2
@ti.func
def scalbn(x, exponent):
    """Return ``x * 2 ** exponent`` (like C ``scalbn``).

    The base is the float ``2.0`` rather than the integer ``2``. With an
    integer base, ``ti.math.pow`` performs integer pow, which Taichi rejects
    (debug mode) or leaves undefined for negative exponents. A float base is
    exact for every representable power of two, so non-negative exponents give
    the same result as before.
    """
    return x * ti.math.pow(2.0, exponent)
@ti.func
def calc_dt(t, exp_step_factor, grid_size, scale):
    """Return the ray-marching step size at distance ``t``.

    The step is proportional to ``t`` (``t * exp_step_factor``) and clamped to
    ``[SQRT3_MAX_SAMPLES, SQRT3_2 * scale / grid_size]``.
    """
    max_dt = SQRT3_2 * scale / grid_size
    return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES, max_dt)
@ti.func
def frexp_bit(x):
    """Return ``ceil(log2(|x|))`` for a 32-bit float ``x``, read from its bits.

    Returns 0 when ``x == 0``. For normal, finite ``x`` the unbiased IEEE-754
    exponent field equals ``floor(log2(|x|))``. One is added whenever the
    mantissa is non-zero, i.e. when ``|x|`` is not an exact power of two.

    Note: C ``frexp`` yields ``floor(log2(|x|)) + 1``. This function therefore
    agrees with it except when ``|x|`` is an exact power of two, where it
    returns one less. Subnormal, infinite and NaN inputs are not special-cased.

    Args:
        x: ``ti.f32`` value. The callers in this file cast with ``ti.f32(...)``,
            and ``ti.bit_cast`` to ``ti.u32`` needs a 32-bit operand.

    Returns:
        The exponent as ``ti.i32``.
    """
    exponent = 0
    if x != 0.0:
        bits = ti.bit_cast(x, ti.u32)
        # Biased exponent field (bits 23..30) minus the f32 bias of 127. The
        # mask drops the sign bit, so negative inputs behave like |x|.
        exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127
        # A non-zero mantissa means 1.m > 1.0: |x| lies strictly between two
        # powers of two, so round the exponent up.
        if (bits & ti.u32(0x007fffff)) != ti.u32(0):
            exponent += 1
    return exponent
@ti.func
def mip_from_pos(xyz, cascades):
    """Return the cascade level for position ``xyz``.

    Takes the L-infinity norm of ``xyz`` (largest absolute component) and
    adds one to its ``frexp_bit`` exponent. The result is clamped to
    ``[0, cascades - 1]``.
    """
    linf = ti.f32(ti.abs(xyz).max())
    level = frexp_bit(linf) + 1
    return ti.min(cascades - 1, ti.max(0, level))
@ti.func
def mip_from_dt(dt, grid_size, cascades):
    """Return the cascade level for step size ``dt``.

    Uses the ``frexp_bit`` exponent of ``dt * grid_size``, clamped to
    ``[0, cascades - 1]``.
    """
    scaled_dt = ti.f32(dt * grid_size)
    level = frexp_bit(scaled_dt)
    return ti.min(cascades - 1, ti.max(0, level))
@ti.func
def __expand_bits(v):
    """Spread the low bits of ``v`` so that two zero bits separate each one.

    This is the multiply-and-mask bit-spreading step used to build 3-D Morton
    codes: bit ``k`` of the input lands on bit ``3 * k`` of the result. The
    final mask 0x49249249 keeps only every third bit. Inputs are expected to
    fit in 10 bits (< 1024); higher bits are not spread correctly.
    ``__morton3D`` passes a ``uvec3``, so the arithmetic is applied per
    component.
    """
    # Each multiply by (1 + 2**s) adds a copy of v shifted left by s. The mask
    # then keeps only the bits at their new, wider-spaced positions.
    v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF)  # s = 16
    v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F)  # s = 8
    v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3)  # s = 4
    v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249)  # s = 2
    return v
@ti.func
def __morton3D(xyz):
    """Interleave the three components of a ``uvec3`` into one Morton code.

    After spreading, component 0 occupies bits 0, 3, 6, ..., component 1
    occupies bits 1, 4, 7, ..., and component 2 occupies bits 2, 5, 8, ...
    """
    spread = __expand_bits(xyz)
    code = spread[0]
    code |= spread[1] << 1
    code |= spread[2] << 2
    return code
@ti.func
def __morton3D_invert(x):
    """Gather every third bit of ``x`` into contiguous low bits.

    This undoes ``__expand_bits``. It keeps the bits selected by 0x49249249
    (bits 0, 3, 6, ...) and compacts them step by step. The kernel below
    passes the Morton code shifted right by 0, 1 or 2 to recover each of the
    three coordinates.

    Returns:
        The compacted value as ``ti.int32``.
    """
    # Typed mask, matching every other mask in this file, so the whole
    # pipeline stays in unsigned 32-bit arithmetic.
    x = x & ti.uint32(0x49249249)
    x = (x | (x >> 2)) & ti.uint32(0xc30c30c3)
    x = (x | (x >> 4)) & ti.uint32(0x0f00f00f)
    x = (x | (x >> 8)) & ti.uint32(0xff0000ff)
    x = (x | (x >> 16)) & ti.uint32(0x0000ffff)
    return ti.int32(x)
@ti.kernel
def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),
                           coords: ti.types.ndarray(ndim=2)):
    """Decode each Morton code ``indices[i]`` into ``coords[i, 0..2]``."""
    for i in indices:
        code = ti.uint32(indices[i])
        for axis in ti.static(range(3)):
            # Coordinate ``axis`` was interleaved at bit offset ``axis``.
            coords[i, axis] = __morton3D_invert(code >> axis)
def morton3D_invert(indices):
    """Decode a 1-D tensor of Morton codes into an ``(N, 3)`` int32 tensor.

    The output is allocated on ``indices.device`` and filled by
    ``morton3D_invert_kernel``. ``ti.sync()`` is called before returning.
    """
    num_codes = indices.size(0)
    coords = torch.zeros((num_codes, 3),
                         dtype=torch.int32,
                         device=indices.device)
    morton3D_invert_kernel(indices.contiguous(), coords)
    ti.sync()
    return coords
@ti.kernel
def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),
                    indices: ti.types.ndarray(ndim=1)):
    """Encode ``xyzs[s, 0..2]`` as a Morton code stored in ``indices[s]``."""
    for s in indices:
        point = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]])
        code = __morton3D(point)
        indices[s] = ti.cast(code, ti.int32)
def morton3D(coords1):
    """Encode the first three columns of 2-D ``coords1`` into int32 Morton codes.

    Returns a 1-D int32 tensor of length ``coords1.size(0)`` on the same
    device. ``ti.sync()`` is called before returning.
    """
    codes = torch.zeros(coords1.size(0),
                        dtype=torch.int32,
                        device=coords1.device)
    morton3D_kernel(coords1.contiguous(), codes)
    ti.sync()
    return codes
@ti.kernel
def packbits(density_grid: ti.types.ndarray(ndim=1),
             density_threshold: float,
             density_bitfield: ti.types.ndarray(ndim=1)):
    """Pack ``density_grid > density_threshold`` into bytes, eight cells each.

    Bit ``i`` of ``density_bitfield[n]`` (least-significant bit first) is set
    iff ``density_grid[8 * n + i] > density_threshold``.
    """
    for n in density_bitfield:
        byte = ti.uint8(0)
        for i in ti.static(range(8)):
            if density_grid[8 * n + i] > density_threshold:
                byte |= ti.uint8(1) << i
        density_bitfield[n] = byte
@ti.kernel
def torch2ti(field: ti.template(), data: ti.types.ndarray()):
    """Copy every element of ndarray ``data`` into ``field`` at the same index."""
    for idx in ti.grouped(data):
        field[idx] = data[idx]
@ti.kernel
def ti2torch(field: ti.template(), data: ti.types.ndarray()):
    """Fill ndarray ``data`` from ``field``, index by index over ``data``'s shape."""
    for idx in ti.grouped(data):
        data[idx] = field[idx]
@ti.kernel
def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
    """Copy ``field.grad`` into ndarray ``grad``, index by index over ``grad``'s shape."""
    for idx in ti.grouped(grad):
        grad[idx] = field.grad[idx]
@ti.kernel
def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
    """Write ndarray ``grad`` into ``field.grad`` at the same indices."""
    for idx in ti.grouped(grad):
        field.grad[idx] = grad[idx]
@ti.kernel
def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):
    """Pack flat 1-D ``data`` pairwise into a 1-D field of 2-vectors.

    ``field[k]`` receives ``(data[2k], data[2k + 1])``. A trailing odd
    element of ``data`` is ignored.
    """
    for idx in range(data.shape[0] // 2):
        base = idx * 2
        field[idx] = ti.Vector([data[base], data[base + 1]])
@ti.kernel
def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):
    """Unpack a 2-D field of vectors into 2-D ``data``, two columns per vector.

    ``data[i, 2j]`` and ``data[i, 2j + 1]`` receive components 0 and 1 of
    ``field[i, j]``.
    """
    for row, col in ti.ndrange(data.shape[0], data.shape[1] // 2):
        base = 2 * col
        pair = field[row, col]
        data[row, base] = pair[0]
        data[row, base + 1] = pair[1]
@ti.kernel
def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
    """Flatten the gradient of a 1-D vector field into 1-D ``grad``.

    ``grad[2k]`` and ``grad[2k + 1]`` receive components 0 and 1 of
    ``field.grad[k]``.
    """
    for idx in range(grad.shape[0] // 2):
        base = 2 * idx
        g = field.grad[idx]
        grad[base] = g[0]
        grad[base + 1] = g[1]
@ti.kernel
def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
    """Write 2-D ``grad`` into the gradient of a 2-D vector field.

    Components 0 and 1 of ``field.grad[i, j]`` receive ``grad[i, 2j]`` and
    ``grad[i, 2j + 1]``.
    """
    for row, col in ti.ndrange(grad.shape[0], grad.shape[1] // 2):
        base = 2 * col
        field.grad[row, col][0] = grad[row, base]
        field.grad[row, col][1] = grad[row, base + 1]
def extract_model_state_dict(ckpt_path,
                             model_name='model',
                             prefixes_to_ignore=()):
    """Load a checkpoint and return the entries belonging to ``model_name``.

    If the loaded object has a ``'state_dict'`` entry (PyTorch Lightning
    layout), that entry is used. Keys of the form ``"<model_name>.<rest>"``
    are kept with the ``"<model_name>."`` prefix stripped. A stripped key that
    starts with any of ``prefixes_to_ignore`` is dropped.

    Args:
        ckpt_path: Path handed to ``torch.load`` (tensors mapped to CPU).
        model_name: Name the model's keys are nested under.
        prefixes_to_ignore: Iterable of (stripped) key prefixes to skip.
            Defaults to an immutable empty tuple rather than a shared list.

    Returns:
        A new dict mapping stripped keys to the stored values.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in checkpoint:  # if it's a pytorch-lightning checkpoint
        checkpoint = checkpoint['state_dict']
    # Require the separating dot. Otherwise e.g. 'model2.weight' would match
    # model_name='model' and be mangled into the key '.weight'.
    key_prefix = model_name + '.'
    ignored = tuple(prefixes_to_ignore)  # str.startswith(()) is always False
    checkpoint_ = {}
    for k, v in checkpoint.items():
        if not k.startswith(key_prefix):
            continue
        k = k[len(key_prefix):]
        if k.startswith(ignored):
            continue
        checkpoint_[k] = v
    return checkpoint_
def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=()):
    """Overlay weights from a checkpoint onto ``model`` in place.

    Does nothing (returns ``None``) when ``ckpt_path`` is falsy, e.g. ``None``
    or ``''``. Otherwise the entries returned by ``extract_model_state_dict``
    replace the matching entries of ``model.state_dict()``, and the merged dict
    is loaded with ``model.load_state_dict``. Entries absent from the
    checkpoint keep their current values. Strict mode is the default, so
    checkpoint keys the model does not have raise an error.

    Args:
        model: Module whose ``state_dict`` is updated.
        ckpt_path: Checkpoint path; falsy means "skip loading".
        model_name: Key prefix the model's weights are stored under.
        prefixes_to_ignore: Iterable of stripped key prefixes to skip.
            Defaults to an immutable empty tuple rather than a shared list.
    """
    if not ckpt_path:
        return
    model_dict = model.state_dict()
    model_dict.update(
        extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore))
    model.load_state_dict(model_dict)
def depth2img(depth):
    """Colorize a depth map with OpenCV's TURBO colormap.

    The depth is min-max normalized to [0, 1], scaled to 0..255, cast to
    uint8 and passed to ``cv2.applyColorMap``. A constant map, where max ==
    min, would divide 0 by 0 and produce NaN. It is instead rendered as all
    zeros, i.e. the colormap's first color.

    Args:
        depth: NumPy array of depth values (``.astype`` is used). Presumably
            a 2-D (H, W) map — TODO confirm against callers.

    Returns:
        The uint8 color image from ``cv2.applyColorMap`` (BGR, per OpenCV).
    """
    # cv2 and numpy are not among this file's top-level imports, so import
    # them here. Without this, every call raises NameError, and modules that
    # never call depth2img do not need OpenCV.
    import cv2
    import numpy as np

    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = (depth - depth_min) / depth_range
    else:
        depth = np.zeros_like(depth, dtype=np.float32)
    depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
                                  cv2.COLORMAP_TURBO)
    return depth_img