first commit
This commit is contained in:
340
taichi_modules/ray_march.py
Normal file
340
taichi_modules/ray_march.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import vec3
|
||||
from torch.cuda.amp import custom_fwd
|
||||
|
||||
from .utils import __morton3D, calc_dt, mip_from_dt, mip_from_pos
|
||||
|
||||
|
||||
@ti.kernel
def raymarching_train(rays_o: ti.types.ndarray(ndim=2),
                      rays_d: ti.types.ndarray(ndim=2),
                      hits_t: ti.types.ndarray(ndim=2),
                      density_bitfield: ti.types.ndarray(ndim=1),
                      noise: ti.types.ndarray(ndim=1),
                      counter: ti.types.ndarray(ndim=1),
                      rays_a: ti.types.ndarray(ndim=2),
                      xyzs: ti.types.ndarray(ndim=2),
                      dirs: ti.types.ndarray(ndim=2),
                      deltas: ti.types.ndarray(ndim=1),
                      ts: ti.types.ndarray(ndim=1), cascades: int,
                      grid_size: int, scale: float, exp_step_factor: float,
                      max_samples: float):
    """March training rays through an occupancy grid and emit samples.

    Two-pass scheme, one parallel iteration per ray:
      pass 1 walks the ray and only *counts* samples landing in occupied
      grid cells; the count is then used to atomically reserve a
      contiguous slot range in the flat output buffers;
      pass 2 re-walks the identical trajectory and writes positions,
      directions, step sizes and distances into the reserved slots.

    Inputs (per ray r):
      rays_o / rays_d   -- origin and direction, (N, 3)
      hits_t            -- [t_near, t_far] interval; t_near < 0 marks a miss
      density_bitfield  -- packed 1-bit-per-cell occupancy, Morton order,
                           `cascades` mips of grid_size^3 cells each
      noise             -- per-ray jitter in [0, 1) for the first step
      counter           -- counter[0]: total samples, counter[1]: ray count
    Outputs:
      rays_a            -- per-ray (ray index, start slot, sample count)
      xyzs/dirs/deltas/ts -- flat sample buffers, filled in pass 2

    NOTE(review): `max_samples` is declared float but only compared against
    the integer sample counter -- confirm the intended type.
    """
    # ti.loop_config(block_dim=256)
    for r in noise:
        ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2])
        ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2])
        # Reciprocal direction for the ray/cell-boundary intersection below.
        d_inv = 1.0 / ray_d

        t1, t2 = hits_t[r, 0], hits_t[r, 1]

        grid_size3 = grid_size**3
        grid_size_inv = 1.0 / grid_size

        # t1 < 0 encodes "ray missed the scene bounds": skip the jitter and
        # leave t = t1 negative so the while-condition below rejects the ray.
        if t1 >= 0:
            # Perturb the first sample along the ray to decorrelate rays.
            dt = calc_dt(t1, exp_step_factor, grid_size, scale)
            t1 += dt * noise[r]

        t = t1
        N_samples = 0

        # ---- pass 1: count samples in occupied cells ----
        while (0 <= t) & (t < t2) & (N_samples < max_samples):
            xyz = ray_o + t * ray_d
            dt = calc_dt(t, exp_step_factor, grid_size, scale)
            # Pick the coarser of the position-based and step-size-based mips.
            mip = ti.max(mip_from_pos(xyz, cascades),
                         mip_from_dt(dt, grid_size, cascades))

            # mip_bound = 0.5
            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)
            mip_bound = scale
            mip_bound_inv = 1 / mip_bound

            # Map xyz from [-mip_bound, mip_bound] into grid coordinates.
            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,
                                 0.0, grid_size - 1.0)
            # nxyz = ti.ceil(nxyz)

            # Bitfield lookup: Morton-encoded cell index within the mip,
            # 8 cells packed per byte.
            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))
            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))
            # idx = __morton3D(ti.cast(nxyz, ti.uint32))
            # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32))

            if occ:
                t += dt
                N_samples += 1
            else:
                # t += dt
                # Empty-space skip: distance along the ray to the nearest
                # cell boundary in each axis, in world units.
                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *
                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv

                t_target = t + ti.max(0, txyz.min())
                # Advance in calc_dt-sized steps (not a single jump) so both
                # passes land on exactly the same sequence of t values.
                t += calc_dt(t, exp_step_factor, grid_size, scale)
                while t < t_target:
                    t += calc_dt(t, exp_step_factor, grid_size, scale)

        # Reserve a contiguous slot range and a compacted ray index.
        start_idx = ti.atomic_add(counter[0], N_samples)
        ray_count = ti.atomic_add(counter[1], 1)

        rays_a[ray_count, 0] = r
        rays_a[ray_count, 1] = start_idx
        rays_a[ray_count, 2] = N_samples

        # ---- pass 2: re-march the identical trajectory, writing samples ----
        t = t1
        samples = 0

        while (t < t2) & (samples < N_samples):
            xyz = ray_o + t * ray_d
            dt = calc_dt(t, exp_step_factor, grid_size, scale)
            mip = ti.max(mip_from_pos(xyz, cascades),
                         mip_from_dt(dt, grid_size, cascades))

            # mip_bound = 0.5
            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)
            mip_bound = scale
            mip_bound_inv = 1 / mip_bound

            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,
                                 0.0, grid_size - 1.0)
            # nxyz = ti.ceil(nxyz)

            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))
            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))
            # idx = __morton3D(ti.cast(nxyz, ti.uint32))
            # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32))

            if occ:
                # Write into the slot range reserved in pass 1.
                s = start_idx + samples
                xyzs[s, 0] = xyz[0]
                xyzs[s, 1] = xyz[1]
                xyzs[s, 2] = xyz[2]
                dirs[s, 0] = ray_d[0]
                dirs[s, 1] = ray_d[1]
                dirs[s, 2] = ray_d[2]
                ts[s] = t
                deltas[s] = dt
                t += dt
                samples += 1
            else:
                # t += dt
                # Same empty-space skip as pass 1 (must match exactly).
                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *
                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv

                t_target = t + ti.max(0, txyz.min())
                t += calc_dt(t, exp_step_factor, grid_size, scale)
                while t < t_target:
                    t += calc_dt(t, exp_step_factor, grid_size, scale)
|
||||
|
||||
|
||||
@ti.kernel
def raymarching_train_backword(segments: ti.types.ndarray(ndim=2),
                               ts: ti.types.ndarray(ndim=1),
                               dL_drays_o: ti.types.ndarray(ndim=2),
                               dL_drays_d: ti.types.ndarray(ndim=2),
                               dL_dxyzs: ti.types.ndarray(ndim=2),
                               dL_ddirs: ti.types.ndarray(ndim=2)):
    """Backward pass for ray marching: gradients w.r.t. ray origins/directions.

    Intended chain rule for xyz = o + t * d:
        dL/do = dL/dxyz
        dL/dd = dL/dxyz * t + dL/ddir

    NOTE(review): this kernel is only referenced from the commented-out
    ``backward`` inside ``RayMarcherTaichi.__init__`` and appears to be dead
    code. The name also carries a typo ("backword") -- kept, since it is the
    public name.
    NOTE(review): ``segments`` is declared ndim=2 but is both iterated with a
    single loop variable and read with a single index (``segments[s]``), and
    the ndim=2 gradient arrays are likewise accessed with one index -- this
    looks inconsistent with Taichi ndarray indexing; verify before reviving.
    """
    for s in segments:
        # Map each compacted segment/ray slot to its sample index.
        index = segments[s]
        dxyz = dL_dxyzs[index]
        ddir = dL_ddirs[index]

        dL_drays_o[s] = dxyz
        dL_drays_d[s] = dxyz * ts[index] + ddir
|
||||
|
||||
|
||||
class RayMarcherTaichi(torch.nn.Module):
    """Torch module wrapping the ``raymarching_train`` Taichi kernel.

    Output buffers (``rays_a``, ``xyzs``, ``dirs``, ``deltas``, ``ts``) are
    pre-allocated once as registered buffers sized for ``batch_size`` rays
    with up to 1024 samples each, so the kernel writes in place instead of
    allocating per call.

    NOTE(review): no ``backward`` is defined on the inner autograd Function
    (it is commented out), so gradients do not flow to ``rays_o``/``rays_d``
    through this module -- confirm that is intended.
    """

    def __init__(self, batch_size=8192):
        """Allocate fixed-size sample buffers for up to ``batch_size`` rays."""
        super(RayMarcherTaichi, self).__init__()

        # Per-ray bookkeeping: (ray index, start slot, sample count).
        self.register_buffer('rays_a',
                             torch.zeros(batch_size, 3, dtype=torch.int32))
        # Flat sample buffers; capacity = batch_size rays * 1024 samples.
        self.register_buffer(
            'xyzs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32))
        self.register_buffer(
            'dirs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32))
        self.register_buffer(
            'deltas', torch.zeros(batch_size * 1024, dtype=torch.float32))
        self.register_buffer(
            'ts', torch.zeros(batch_size * 1024, dtype=torch.float32))

        # self.register_buffer('dL_drays_o', torch.zeros(batch_size, dtype=torch.float32))
        # self.register_buffer('dL_drays_d', torch.zeros(batch_size, dtype=torch.float32))

        class _module_function(torch.autograd.Function):
            # NOTE: `self` used inside `forward` below is NOT a method
            # receiver -- it is the RayMarcherTaichi instance captured from
            # the enclosing __init__ closure, which is why this Function
            # class is defined here rather than at module level.

            @staticmethod
            @custom_fwd(cast_inputs=torch.float32)
            def forward(ctx, rays_o, rays_d, hits_t, density_bitfield,
                        cascades, scale, exp_step_factor, grid_size,
                        max_samples):
                """Run the training ray-march kernel and compact its output."""
                # noise to perturb the first sample of each ray
                noise = torch.rand_like(rays_o[:, 0])
                # counter[0] = total samples written, counter[1] = rays seen.
                counter = torch.zeros(2,
                                      device=rays_o.device,
                                      dtype=torch.int32)

                raymarching_train(\
                    rays_o, rays_d,
                    hits_t.contiguous(),
                    density_bitfield, noise, counter,
                    self.rays_a.contiguous(),
                    self.xyzs.contiguous(),
                    self.dirs.contiguous(),
                    self.deltas.contiguous(),
                    self.ts.contiguous(),
                    cascades, grid_size, scale,
                    exp_step_factor, max_samples)

                # ti.sync()

                total_samples = counter[0]  # total samples for all rays
                # remove redundant output
                # NOTE(review): total_samples is a 0-dim int32 tensor used as
                # a slice bound (implicit .item() sync) and is also returned
                # as a tensor, not a Python int.
                xyzs = self.xyzs[:total_samples]
                dirs = self.dirs[:total_samples]
                deltas = self.deltas[:total_samples]
                ts = self.ts[:total_samples]

                return self.rays_a, xyzs, dirs, deltas, ts, total_samples

            # @staticmethod
            # @custom_bwd
            # def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, dL_ddeltas, dL_dts,
            #              dL_dtotal_samples):
            #     rays_a, ts = ctx.saved_tensors
            #     # rays_a = rays_a.contiguous()
            #     ts = ts.contiguous()
            #     segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1] + rays_a[-1:, 2]])
            #     dL_drays_o = torch.zeros_like(rays_a[:, 0])
            #     dL_drays_d = torch.zeros_like(rays_a[:, 0])
            #     raymarching_train_backword(segments.contiguous(), ts, dL_drays_o,
            #                                dL_drays_d, dL_dxyzs, dL_ddirs)
            #     # ti.sync()
            #     # dL_drays_o = segment_csr(dL_dxyzs, segments)
            #     # dL_drays_d = \
            #     #     segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments)

            #     return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None

        self._module_function = _module_function

    def forward(self, rays_o, rays_d, hits_t, density_bitfield, cascades,
                scale, exp_step_factor, grid_size, max_samples):
        """March rays; returns (rays_a, xyzs, dirs, deltas, ts, total_samples)."""
        return self._module_function.apply(rays_o, rays_d, hits_t,
                                           density_bitfield, cascades, scale,
                                           exp_step_factor, grid_size,
                                           max_samples)
|
||||
|
||||
|
||||
@ti.kernel
def raymarching_test_kernel(
        rays_o: ti.types.ndarray(ndim=2),
        rays_d: ti.types.ndarray(ndim=2),
        hits_t: ti.types.ndarray(ndim=2),
        alive_indices: ti.types.ndarray(ndim=1),
        density_bitfield: ti.types.ndarray(ndim=1),
        cascades: int,
        grid_size: int,
        scale: float,
        exp_step_factor: float,
        N_samples: int,
        max_samples: int,
        xyzs: ti.types.ndarray(ndim=2),
        dirs: ti.types.ndarray(ndim=2),
        deltas: ti.types.ndarray(ndim=1),
        ts: ti.types.ndarray(ndim=1),
        N_eff_samples: ti.types.ndarray(ndim=1),
):
    """March still-alive rays at inference, up to N_samples each.

    Single pass (unlike the two-pass training kernel): each alive ray writes
    at most ``N_samples`` occupied-cell samples into its own row of the
    per-ray output buffers and records the count in ``N_eff_samples``.
    ``hits_t[r, 0]`` is advanced in place so the next call resumes where
    this one stopped.

    NOTE(review): ``max_samples`` is accepted but never read here; the cap
    is ``N_samples``.
    """
    for n in alive_indices:
        # n indexes the compacted alive set; r is the original ray index.
        r = alive_indices[n]
        grid_size3 = grid_size**3
        grid_size_inv = 1.0 / grid_size

        ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2])
        ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2])
        # Reciprocal direction for the ray/cell-boundary intersection below.
        d_inv = 1.0 / ray_d

        # Resume from the last marched position; t < 0 marks a dead ray.
        t = hits_t[r, 0]
        t2 = hits_t[r, 1]

        s = 0

        while (0 <= t) & (t < t2) & (s < N_samples):
            xyz = ray_o + t * ray_d
            dt = calc_dt(t, exp_step_factor, grid_size, scale)
            # Pick the coarser of the position-based and step-size-based mips.
            mip = ti.max(mip_from_pos(xyz, cascades),
                         mip_from_dt(dt, grid_size, cascades))

            # mip_bound = 0.5
            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)
            mip_bound = scale
            mip_bound_inv = 1 / mip_bound

            # Map xyz from [-mip_bound, mip_bound] into grid coordinates.
            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,
                                 0.0, grid_size - 1.0)
            # nxyz = ti.ceil(nxyz)

            # Bitfield lookup: Morton-encoded cell index, 8 cells per byte.
            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))
            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))

            if occ:
                xyzs[n, s, 0] = xyz[0]
                xyzs[n, s, 1] = xyz[1]
                xyzs[n, s, 2] = xyz[2]
                dirs[n, s, 0] = ray_d[0]
                dirs[n, s, 1] = ray_d[1]
                dirs[n, s, 2] = ray_d[2]
                ts[n, s] = t
                deltas[n, s] = dt
                t += dt
                # Persist progress so the next kernel call resumes here.
                hits_t[r, 0] = t
                s += 1

            else:
                # Empty-space skip: distance along the ray to the nearest
                # cell boundary in each axis, in world units.
                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *
                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv

                t_target = t + ti.max(0, txyz.min())
                t += calc_dt(t, exp_step_factor, grid_size, scale)
                while t < t_target:
                    t += calc_dt(t, exp_step_factor, grid_size, scale)

        N_eff_samples[n] = s
|
||||
|
||||
|
||||
def raymarching_test(rays_o, rays_d, hits_t, alive_indices, density_bitfield,
                     cascades, scale, exp_step_factor, grid_size, max_samples,
                     N_samples):
    """Allocate per-ray output buffers and run the inference march kernel.

    Returns ``(xyzs, dirs, deltas, ts, N_eff_samples)`` where the first four
    are zero-initialized ``(N_rays, N_samples, ...)`` buffers filled by the
    kernel and ``N_eff_samples[n]`` is the number of valid samples written
    for alive ray ``n``.
    """
    N_rays = alive_indices.size(0)

    # All sample buffers share the ray tensors' device and dtype.
    opts = {'device': rays_o.device, 'dtype': rays_o.dtype}
    xyzs = torch.zeros(N_rays, N_samples, 3, **opts)
    dirs = torch.zeros(N_rays, N_samples, 3, **opts)
    deltas = torch.zeros(N_rays, N_samples, **opts)
    ts = torch.zeros(N_rays, N_samples, **opts)
    # Per-ray valid-sample counts are integral.
    N_eff_samples = torch.zeros(N_rays, device=rays_o.device,
                                dtype=torch.int32)

    raymarching_test_kernel(rays_o, rays_d, hits_t, alive_indices,
                            density_bitfield, cascades, grid_size, scale,
                            exp_step_factor, N_samples, max_samples, xyzs,
                            dirs, deltas, ts, N_eff_samples)

    # ti.sync()

    return xyzs, dirs, deltas, ts, N_eff_samples
|
||||
Reference in New Issue
Block a user