first commit
This commit is contained in:
224
taichi_modules/utils.py
Normal file
224
taichi_modules/utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import uvec3
|
||||
|
||||
taichi_block_size = 128
|
||||
|
||||
data_type = ti.f32
|
||||
torch_type = torch.float32
|
||||
|
||||
MAX_SAMPLES = 1024
|
||||
NEAR_DISTANCE = 0.01
|
||||
SQRT3 = 1.7320508075688772
|
||||
SQRT3_MAX_SAMPLES = SQRT3 / 1024
|
||||
SQRT3_2 = 1.7320508075688772 * 2
|
||||
|
||||
|
||||
@ti.func
|
||||
def scalbn(x, exponent):
|
||||
return x * ti.math.pow(2, exponent)
|
||||
|
||||
|
||||
@ti.func
|
||||
def calc_dt(t, exp_step_factor, grid_size, scale):
|
||||
return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES,
|
||||
SQRT3_2 * scale / grid_size)
|
||||
|
||||
|
||||
@ti.func
|
||||
def frexp_bit(x):
|
||||
exponent = 0
|
||||
if x != 0.0:
|
||||
# frac = ti.abs(x)
|
||||
bits = ti.bit_cast(x, ti.u32)
|
||||
exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127
|
||||
# exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127
|
||||
bits &= ti.u32(0x7fffff)
|
||||
bits |= ti.u32(0x3f800000)
|
||||
frac = ti.bit_cast(bits, ti.f32)
|
||||
if frac < 0.5:
|
||||
exponent -= 1
|
||||
elif frac > 1.0:
|
||||
exponent += 1
|
||||
return exponent
|
||||
|
||||
|
||||
@ti.func
|
||||
def mip_from_pos(xyz, cascades):
|
||||
mx = ti.abs(xyz).max()
|
||||
# _, exponent = _frexp(mx)
|
||||
exponent = frexp_bit(ti.f32(mx)) + 1
|
||||
# frac, exponent = ti.frexp(ti.f32(mx))
|
||||
return ti.min(cascades - 1, ti.max(0, exponent))
|
||||
|
||||
|
||||
@ti.func
|
||||
def mip_from_dt(dt, grid_size, cascades):
|
||||
# _, exponent = _frexp(dt*grid_size)
|
||||
exponent = frexp_bit(ti.f32(dt * grid_size))
|
||||
# frac, exponent = ti.frexp(ti.f32(dt*grid_size))
|
||||
return ti.min(cascades - 1, ti.max(0, exponent))
|
||||
|
||||
|
||||
@ti.func
|
||||
def __expand_bits(v):
|
||||
v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF)
|
||||
v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F)
|
||||
v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3)
|
||||
v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249)
|
||||
return v
|
||||
|
||||
|
||||
@ti.func
|
||||
def __morton3D(xyz):
|
||||
xyz = __expand_bits(xyz)
|
||||
return xyz[0] | (xyz[1] << 1) | (xyz[2] << 2)
|
||||
|
||||
|
||||
@ti.func
|
||||
def __morton3D_invert(x):
|
||||
x = x & (0x49249249)
|
||||
x = (x | (x >> 2)) & ti.uint32(0xc30c30c3)
|
||||
x = (x | (x >> 4)) & ti.uint32(0x0f00f00f)
|
||||
x = (x | (x >> 8)) & ti.uint32(0xff0000ff)
|
||||
x = (x | (x >> 16)) & ti.uint32(0x0000ffff)
|
||||
return ti.int32(x)
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),
|
||||
coords: ti.types.ndarray(ndim=2)):
|
||||
for i in indices:
|
||||
ind = ti.uint32(indices[i])
|
||||
coords[i, 0] = __morton3D_invert(ind >> 0)
|
||||
coords[i, 1] = __morton3D_invert(ind >> 1)
|
||||
coords[i, 2] = __morton3D_invert(ind >> 2)
|
||||
|
||||
|
||||
def morton3D_invert(indices):
|
||||
coords = torch.zeros(indices.size(0),
|
||||
3,
|
||||
device=indices.device,
|
||||
dtype=torch.int32)
|
||||
morton3D_invert_kernel(indices.contiguous(), coords)
|
||||
ti.sync()
|
||||
return coords
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),
|
||||
indices: ti.types.ndarray(ndim=1)):
|
||||
for s in indices:
|
||||
xyz = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]])
|
||||
indices[s] = ti.cast(__morton3D(xyz), ti.int32)
|
||||
|
||||
|
||||
def morton3D(coords1):
|
||||
indices = torch.zeros(coords1.size(0),
|
||||
device=coords1.device,
|
||||
dtype=torch.int32)
|
||||
morton3D_kernel(coords1.contiguous(), indices)
|
||||
ti.sync()
|
||||
return indices
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def packbits(density_grid: ti.types.ndarray(ndim=1),
|
||||
density_threshold: float,
|
||||
density_bitfield: ti.types.ndarray(ndim=1)):
|
||||
|
||||
for n in density_bitfield:
|
||||
bits = ti.uint8(0)
|
||||
|
||||
for i in ti.static(range(8)):
|
||||
bits |= (ti.uint8(1) << i) if (
|
||||
density_grid[8 * n + i] > density_threshold) else ti.uint8(0)
|
||||
|
||||
density_bitfield[n] = bits
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def torch2ti(field: ti.template(), data: ti.types.ndarray()):
|
||||
for I in ti.grouped(data):
|
||||
field[I] = data[I]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti2torch(field: ti.template(), data: ti.types.ndarray()):
|
||||
for I in ti.grouped(data):
|
||||
data[I] = field[I]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
|
||||
for I in ti.grouped(grad):
|
||||
grad[I] = field.grad[I]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
|
||||
for I in ti.grouped(grad):
|
||||
field.grad[I] = grad[I]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):
|
||||
for I in range(data.shape[0] // 2):
|
||||
field[I] = ti.Vector([data[I * 2], data[I * 2 + 1]])
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):
|
||||
for i, j in ti.ndrange(data.shape[0], data.shape[1] // 2):
|
||||
data[i, j * 2] = field[i, j][0]
|
||||
data[i, j * 2 + 1] = field[i, j][1]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
|
||||
for I in range(grad.shape[0] // 2):
|
||||
grad[I * 2] = field.grad[I][0]
|
||||
grad[I * 2 + 1] = field.grad[I][1]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
|
||||
for i, j in ti.ndrange(grad.shape[0], grad.shape[1] // 2):
|
||||
field.grad[i, j][0] = grad[i, j * 2]
|
||||
field.grad[i, j][1] = grad[i, j * 2 + 1]
|
||||
|
||||
|
||||
def extract_model_state_dict(ckpt_path,
|
||||
model_name='model',
|
||||
prefixes_to_ignore=[]):
|
||||
checkpoint = torch.load(ckpt_path, map_location='cpu')
|
||||
checkpoint_ = {}
|
||||
if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint
|
||||
checkpoint = checkpoint['state_dict']
|
||||
for k, v in checkpoint.items():
|
||||
if not k.startswith(model_name):
|
||||
continue
|
||||
k = k[len(model_name) + 1:]
|
||||
for prefix in prefixes_to_ignore:
|
||||
if k.startswith(prefix):
|
||||
break
|
||||
else:
|
||||
checkpoint_[k] = v
|
||||
return checkpoint_
|
||||
|
||||
|
||||
def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
|
||||
if not ckpt_path:
|
||||
return
|
||||
model_dict = model.state_dict()
|
||||
checkpoint_ = extract_model_state_dict(ckpt_path, model_name,
|
||||
prefixes_to_ignore)
|
||||
model_dict.update(checkpoint_)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
def depth2img(depth):
|
||||
depth = (depth - depth.min()) / (depth.max() - depth.min())
|
||||
depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
|
||||
cv2.COLORMAP_TURBO)
|
||||
|
||||
return depth_img
|
||||
Reference in New Issue
Block a user