first commit
This commit is contained in:
5
taichi_modules/__init__.py
Normal file
5
taichi_modules/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .ray_march import RayMarcherTaichi, raymarching_test
|
||||
from .volume_train import VolumeRendererTaichi
|
||||
from .intersection import RayAABBIntersector
|
||||
from .volume_render_test import composite_test
|
||||
from .utils import packbits
|
||||
305
taichi_modules/hash_encoder.py
Normal file
305
taichi_modules/hash_encoder.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import numpy as np
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import uvec3
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec,
|
||||
ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec,
|
||||
torch2ti_vec, torch_type)
|
||||
|
||||
half2 = ti.types.vector(n=2, dtype=ti.f16)
|
||||
|
||||
|
||||
@ti.kernel
def random_initialize(data: ti.types.ndarray()):
    """Fill `data` in-place with uniform random values in [-1e-4, 1e-4).

    Used to initialize the hash-table embedding parameters with small noise.
    """
    for I in ti.grouped(data):
        # ti.random() is uniform in [0, 1); map to [-1, 1) then scale down.
        data[I] = (ti.random() * 2.0 - 1.0) * 1e-4
|
||||
|
||||
|
||||
@ti.kernel
def ti_copy(data1: ti.template(), data2: ti.template()):
    """Element-wise copy of field `data2` into field `data1` (same shape assumed)."""
    for I in ti.grouped(data1):
        data1[I] = data2[I]
|
||||
|
||||
|
||||
@ti.kernel
def ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()):
    """Element-wise copy of ndarray `data2` into ndarray `data1` (same shape assumed)."""
    for I in ti.grouped(data1):
        data1[I] = data2[I]
|
||||
|
||||
|
||||
@ti.kernel
def ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()):
    """Element-wise copy of ndarray `data2` into field `data1` (same shape assumed)."""
    for I in ti.grouped(data1):
        data1[I] = data2[I]
|
||||
|
||||
|
||||
@ti.func
def fast_hash(pos_grid_local):
    """Spatial hash of a 3-D integer grid coordinate (instant-ngp style).

    XORs each coordinate multiplied by a large prime; the first "prime" is 1
    so the x axis passes through unchanged.  Caller reduces the result modulo
    the hash-map size.
    """
    result = ti.uint32(0)
    # primes = uvec3(ti.uint32(1), ti.uint32(1958374283), ti.uint32(2654435761))
    primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))
    for i in ti.static(range(3)):
        result ^= ti.uint32(pos_grid_local[i]) * primes[i]
    return result
|
||||
|
||||
|
||||
@ti.func
def under_hash(pos_grid_local, resolution):
    """Dense (collision-free) index of a grid coordinate: x + y*R + z*R^2.

    Used when the level's full grid fits inside the hash map, so no hashing
    is necessary.
    """
    result = ti.uint32(0)
    stride = ti.uint32(1)
    for i in ti.static(range(3)):
        result += ti.uint32(pos_grid_local[i] * stride)
        stride *= resolution
    return result
|
||||
|
||||
|
||||
@ti.func
def grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size):
    """Map a grid coordinate to a slot in the level's hash table.

    indicator == 1 means the level's dense grid fits in `map_size`, so the
    collision-free linear index is used; otherwise fall back to the spatial
    hash.  The result is always reduced modulo `map_size`.
    """
    hash_result = ti.uint32(0)
    if indicator == 1:
        hash_result = under_hash(pos_grid_local, resolution)
    else:
        hash_result = fast_hash(pos_grid_local)

    return hash_result % map_size
|
||||
|
||||
|
||||
@ti.kernel
def hash_encode_kernel(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
        per_level_scale: ti.f32):
    """Multi-resolution hash encoding (instant-ngp): for each of B points and
    each of 16 levels, trilinearly interpolate a 2-feature embedding from the
    level's hash table into `xyzs_embedding[i, level*2 : level*2+2]`.

    `table` is laid out flat with 2 floats per hash slot; `offsets[level]` is
    in slots, hence the `* 2` scaling below.
    """
    # get hash table embedding
    ti.loop_config(block_dim=16)
    for i, level in ti.ndrange(B, 16):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])

        # Geometric progression of resolutions: base 16, ratio per_level_scale.
        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1

        # Flat float offset of this level's sub-table (2 floats per slot).
        offset = offsets[level] * 2

        pos = xyz * scale + 0.5
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
        pos -= pos_grid_uint  # fractional part -> trilinear weights

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        local_feature_0 = 0.0
        local_feature_1 = 0.0

        # Visit the 8 corners of the enclosing cell; bit d of idx selects
        # floor/ceil along axis d.
        for idx in ti.static(range(8)):
            w = 1.
            pos_grid_local = uvec3(0)

            for d in ti.static(range(3)):
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= pos[d]

            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
                                        map_size)
            index_table = offset + index * 2
            index_table_int = ti.cast(index_table, ti.int32)
            local_feature_0 += w * table[index_table_int]
            local_feature_1 += w * table[index_table_int + 1]

        xyzs_embedding[i, level * 2] = local_feature_0
        xyzs_embedding[i, level * 2 + 1] = local_feature_1
|
||||
|
||||
|
||||
@ti.kernel
def hash_encode_kernel_half2(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
        per_level_scale: ti.f16):
    """half2-packed variant of `hash_encode_kernel`.

    `table` stores one f16x2 vector per hash slot, so `offsets[level]` and
    `index` are used directly (no `* 2` scaling) and both features of a level
    are fetched/accumulated as a single 2-vector.
    """
    # get hash table embedding
    ti.loop_config(block_dim=32)
    for i, level in ti.ndrange(B, 16):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])

        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1

        # Offset is in half2 slots here, not floats.
        offset = offsets[level]

        pos = xyz * scale + 0.5
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
        pos -= pos_grid_uint  # fractional part -> trilinear weights

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        local_feature = half2(0.0)
        # 8 corners of the enclosing cell; bit d of idx selects floor/ceil
        # along axis d.
        for idx in ti.static(range(8)):
            w = ti.f32(1.0)
            pos_grid_local = uvec3(0)

            for d in ti.static(range(3)):
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= pos[d]

            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
                                        map_size)

            index_table = offset + index
            index_table_int = ti.cast(index_table, ti.int32)

            local_feature += w * table[index_table_int]
        xyzs_embedding[i, level] = local_feature
|
||||
|
||||
|
||||
class HashEncoderTaichi(torch.nn.Module):
    """Multi-resolution hash positional encoder (instant-ngp) backed by
    Taichi kernels, exposed to PyTorch through a custom autograd.Function.

    Parameters are held in `self.hash_table` (a torch Parameter); forward
    copies them into Taichi fields, runs the encode kernel, and copies the
    result back into a persistent output buffer.

    NOTE(review): the inner autograd.Function closes over `self` and writes
    into shared per-module buffers, so concurrent forward calls on the same
    module instance would race — presumably single-stream use is assumed;
    confirm with callers.
    """

    def __init__(self,
                 b=1.3195079565048218,
                 batch_size=8192,
                 data_type=data_type,
                 half2_opt=False):
        """
        Args:
            b: per-level geometric growth factor of the grid resolution.
            batch_size: ray batch size; buffers are sized batch_size * 1024
                samples (clamped below to 2048).
            data_type: taichi element type for the non-half2 path.
            half2_opt: pack the 2 features per level into f16x2 vectors.
        """
        super(HashEncoderTaichi, self).__init__()

        self.per_level_scale = b
        if batch_size < 2048:
            batch_size = 2048

        # per_level_scale = 1.3195079565048218
        print("per_level_scale: ", b)
        # Per-level metadata fields consumed by the encode kernels.
        self.offsets = ti.field(ti.i32, shape=(16, ))
        self.hash_map_sizes_field = ti.field(ti.uint32, shape=(16, ))
        self.hash_map_indicator = ti.field(ti.i32, shape=(16, ))
        base_res = 16
        max_params = 2**19  # cap on hash-map slots per level (T in the paper)
        offset_ = 0
        hash_map_sizes = []
        for i in range(16):
            # Resolution of level i: ceil(base * b^i - 1) + 1.
            resolution = int(
                np.ceil(base_res * np.exp(i * np.log(self.per_level_scale)) -
                        1.0)) + 1
            params_in_level = resolution**3
            # Round the dense size up to a multiple of 8, then cap.
            params_in_level = int(resolution**
                                  3) if params_in_level % 8 == 0 else int(
                                      (params_in_level + 8 - 1) / 8) * 8
            params_in_level = min(max_params, params_in_level)
            self.offsets[i] = offset_
            hash_map_sizes.append(params_in_level)
            # indicator==1: dense grid fits, use collision-free indexing.
            self.hash_map_indicator[
                i] = 1 if resolution**3 <= params_in_level else 0
            offset_ += params_in_level
        print("offset_: ", offset_)
        size = np.uint32(np.array(hash_map_sizes))
        self.hash_map_sizes_field.from_numpy(size)

        # 2 features per slot -> total float count is twice the slot count.
        self.total_hash_size = offset_ * 2
        print("total_hash_size: ", self.total_hash_size)

        self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size,
                                                         dtype=torch_type),
                                             requires_grad=True)
        random_initialize(self.hash_table)

        if half2_opt:
            # Pack pairs of f16 features into one vector slot.
            assert self.total_hash_size % 2 == 0
            self.parameter_fields = half2.field(shape=(self.total_hash_size //
                                                       2, ),
                                                needs_grad=True)
            self.output_fields = half2.field(shape=(batch_size * 1024, 16),
                                             needs_grad=True)

            self.torch2ti = torch2ti_vec
            self.ti2torch = ti2torch_vec
            self.ti2torch_grad = ti2torch_grad_vec
            self.torch2ti_grad = torch2ti_grad_vec

            self._hash_encode_kernel = hash_encode_kernel_half2
        else:
            self.parameter_fields = ti.field(data_type,
                                             shape=(self.total_hash_size, ),
                                             needs_grad=True)
            self.output_fields = ti.field(dtype=data_type,
                                          shape=(batch_size * 1024, 32),
                                          needs_grad=True)
            self.torch2ti = torch2ti
            self.ti2torch = ti2torch
            self.ti2torch_grad = ti2torch_grad
            self.torch2ti_grad = torch2ti_grad

            self._hash_encode_kernel = hash_encode_kernel

        self.input_fields = ti.field(dtype=data_type,
                                     shape=(batch_size * 1024, 3),
                                     needs_grad=True)
        self.output_dim = 32  # the output dim: num levels (16) x level num (2)
        # Persistent torch-side buffers for the parameter gradient and the
        # encoded output (sliced per batch in forward).
        self.register_buffer(
            'hash_grad', torch.zeros(self.total_hash_size, dtype=torch_type))
        self.register_buffer(
            'output_embedding',
            torch.zeros(batch_size * 1024, 32, dtype=torch_type))

        class _module_function(torch.autograd.Function):
            # Bridges torch autograd to the taichi kernel; closes over `self`.

            @staticmethod
            @custom_fwd(cast_inputs=torch_type)
            def forward(ctx, input_pos, params):
                # Slice the persistent output buffer to this batch's size.
                output_embedding = self.output_embedding[:input_pos.
                                                         shape[0]].contiguous(
                                                         )
                torch2ti(self.input_fields, input_pos.contiguous())
                self.torch2ti(self.parameter_fields, params.contiguous())

                self._hash_encode_kernel(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator,
                    self.hash_map_sizes_field,
                    self.offsets,
                    input_pos.shape[0],
                    self.per_level_scale,
                )
                self.ti2torch(self.output_fields, output_embedding)

                return output_embedding

            @staticmethod
            @custom_bwd
            def backward(ctx, doutput):
                # Clear taichi-side parameter grads, push upstream grads in,
                # run the kernel's autodiff, and pull the parameter grad out.
                self.zero_grad()

                self.torch2ti_grad(self.output_fields, doutput.contiguous())
                self._hash_encode_kernel.grad(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator,
                    self.hash_map_sizes_field,
                    self.offsets,
                    doutput.shape[0],
                    self.per_level_scale,
                )
                self.ti2torch_grad(self.parameter_fields,
                                   self.hash_grad.contiguous())
                # No gradient w.r.t. input positions; grad w.r.t. hash table.
                return None, self.hash_grad

        self._module_function = _module_function

    def zero_grad(self):
        """Zero the taichi-side gradient of the hash-table parameters."""
        self.parameter_fields.grad.fill(0.)

    def forward(self, positions, bound=1):
        """Encode `positions` in [-bound, bound]^3 -> (N, 32) embedding.

        Positions are first normalized to [0, 1]^3, which the kernels expect.
        """
        positions = (positions + bound) / (2 * bound)
        return self._module_function.apply(positions, self.hash_table)
|
||||
68
taichi_modules/intersection.py
Normal file
68
taichi_modules/intersection.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import vec3
|
||||
from torch.cuda.amp import custom_fwd
|
||||
|
||||
from .utils import NEAR_DISTANCE
|
||||
|
||||
|
||||
@ti.kernel
def simple_ray_aabb_intersec_taichi_forward(
        hits_t: ti.types.ndarray(ndim=3),
        rays_o: ti.types.ndarray(ndim=2),
        rays_d: ti.types.ndarray(ndim=2),
        centers: ti.types.ndarray(ndim=2),
        half_sizes: ti.types.ndarray(ndim=2)):
    """Slab-method intersection of each ray with a single AABB.

    Args:
        hits_t: (N_rays, 1, 2) output; [r, 0, 0]/[r, 0, 1] receive entry/exit
            t when the ray hits the box (entry clamped to NEAR_DISTANCE),
            otherwise left untouched (caller pre-fills with -1).
        rays_o: (N_rays, 3) ray origins.
        rays_d: (N_rays, 3) ray directions.
        centers: (1, 3) box center (only row 0 is read).
        half_sizes: (1, 3) box half extents (only row 0 is read).

    Fixes vs. original:
    - `hits_t` was annotated ndim=2 but indexed with three subscripts and the
      caller (RayAABBIntersector) passes an (N, 1, 2) tensor -> ndim=3.
    - the z half-extent read `half_sizes[0, 1]` instead of `half_sizes[0, 2]`.
    """
    for r in ti.ndrange(hits_t.shape[0]):
        ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]])
        ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]])
        inv_d = 1.0 / ray_d

        center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]])
        half_size = vec3(
            [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 2]])

        # Per-axis entry/exit parameters of the two slab planes.
        t_min = (center - half_size - ray_o) * inv_d
        t_max = (center + half_size - ray_o) * inv_d

        _t1 = ti.min(t_min, t_max)
        _t2 = ti.max(t_min, t_max)
        t1 = _t1.max()  # latest entry
        t2 = _t2.min()  # earliest exit

        if t2 > 0.0:
            hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE)
            hits_t[r, 0, 1] = t2
|
||||
|
||||
|
||||
class RayAABBIntersector(torch.autograd.Function):
    """
    Computes the intersections of rays and axis-aligned voxels.

    Inputs:
        rays_o: (N_rays, 3) ray origins
        rays_d: (N_rays, 3) ray directions
        centers: (N_voxels, 3) voxel centers
        half_sizes: (N_voxels, 3) voxel half sizes
        max_hits: maximum number of intersected voxels to keep for one ray
                  (for a cubic scene, this is at most 3*N_voxels^(1/3)-2)

    Outputs:
        hits_cnt: (N_rays) number of hits for each ray
        (followings are from near to far)
        hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit)
        hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit)

    NOTE(review): this simplified implementation only tests the first voxel
    (the kernel reads row 0 of centers/half_sizes) and returns None for
    hits_cnt and hits_voxel_idx; `max_hits` is accepted but unused — the
    docstring above describes the full interface it stands in for.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
        # Pre-fill with -1 so "no hit" rays are marked without kernel writes.
        hits_t = (torch.zeros(
            rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) -
                  1).contiguous()

        simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center,
                                                half_size)

        return None, hits_t, None
|
||||
340
taichi_modules/ray_march.py
Normal file
340
taichi_modules/ray_march.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import vec3
|
||||
from torch.cuda.amp import custom_fwd
|
||||
|
||||
from .utils import __morton3D, calc_dt, mip_from_dt, mip_from_pos
|
||||
|
||||
|
||||
@ti.kernel
def raymarching_train(rays_o: ti.types.ndarray(ndim=2),
                      rays_d: ti.types.ndarray(ndim=2),
                      hits_t: ti.types.ndarray(ndim=2),
                      density_bitfield: ti.types.ndarray(ndim=1),
                      noise: ti.types.ndarray(ndim=1),
                      counter: ti.types.ndarray(ndim=1),
                      rays_a: ti.types.ndarray(ndim=2),
                      xyzs: ti.types.ndarray(ndim=2),
                      dirs: ti.types.ndarray(ndim=2),
                      deltas: ti.types.ndarray(ndim=1),
                      ts: ti.types.ndarray(ndim=1), cascades: int,
                      grid_size: int, scale: float, exp_step_factor: float,
                      max_samples: float):
    """Occupancy-grid ray marching for training (two passes per ray).

    Pass 1 walks the ray to count samples in occupied cells (N_samples) and
    atomically reserves [start_idx, start_idx + N_samples) in the flat sample
    buffers via `counter`. Pass 2 repeats the identical walk and writes the
    sample positions/directions/step sizes/depths.

    Outputs (written in-place):
        counter: [0] total samples so far, [1] number of rays processed.
        rays_a:  per-ray (ray id, start index, sample count).
        xyzs, dirs, deltas, ts: flat per-sample buffers.
    """
    # ti.loop_config(block_dim=256)
    for r in noise:
        ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2])
        ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2])
        d_inv = 1.0 / ray_d

        t1, t2 = hits_t[r, 0], hits_t[r, 1]

        grid_size3 = grid_size**3
        grid_size_inv = 1.0 / grid_size

        if t1 >= 0:
            # Perturb the first sample along the ray to decorrelate rays.
            dt = calc_dt(t1, exp_step_factor, grid_size, scale)
            t1 += dt * noise[r]

        # ---- pass 1: count samples in occupied cells ----
        t = t1
        N_samples = 0

        while (0 <= t) & (t < t2) & (N_samples < max_samples):
            xyz = ray_o + t * ray_d
            dt = calc_dt(t, exp_step_factor, grid_size, scale)
            mip = ti.max(mip_from_pos(xyz, cascades),
                         mip_from_dt(dt, grid_size, cascades))

            # mip_bound = 0.5
            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)
            mip_bound = scale
            mip_bound_inv = 1 / mip_bound

            # Normalized grid coordinate of the current point.
            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,
                                 0.0, grid_size - 1.0)
            # nxyz = ti.ceil(nxyz)

            # Occupancy lookup: morton-coded bit in the packed bitfield.
            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))
            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))
            # idx = __morton3D(ti.cast(nxyz, ti.uint32))
            # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32))

            if occ:
                t += dt
                N_samples += 1
            else:
                # t += dt
                # Skip to the boundary of the current empty cell (DDA-style).
                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *
                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv

                t_target = t + ti.max(0, txyz.min())
                t += calc_dt(t, exp_step_factor, grid_size, scale)
                while t < t_target:
                    t += calc_dt(t, exp_step_factor, grid_size, scale)

        # Reserve this ray's slice of the flat sample buffers.
        start_idx = ti.atomic_add(counter[0], N_samples)
        ray_count = ti.atomic_add(counter[1], 1)

        rays_a[ray_count, 0] = r
        rays_a[ray_count, 1] = start_idx
        rays_a[ray_count, 2] = N_samples

        # ---- pass 2: identical walk, now writing the samples ----
        t = t1
        samples = 0

        while (t < t2) & (samples < N_samples):
            xyz = ray_o + t * ray_d
            dt = calc_dt(t, exp_step_factor, grid_size, scale)
            mip = ti.max(mip_from_pos(xyz, cascades),
                         mip_from_dt(dt, grid_size, cascades))

            # mip_bound = 0.5
            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)
            mip_bound = scale
            mip_bound_inv = 1 / mip_bound

            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,
                                 0.0, grid_size - 1.0)
            # nxyz = ti.ceil(nxyz)

            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))
            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))
            # idx = __morton3D(ti.cast(nxyz, ti.uint32))
            # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32))

            if occ:
                s = start_idx + samples
                xyzs[s, 0] = xyz[0]
                xyzs[s, 1] = xyz[1]
                xyzs[s, 2] = xyz[2]
                dirs[s, 0] = ray_d[0]
                dirs[s, 1] = ray_d[1]
                dirs[s, 2] = ray_d[2]
                ts[s] = t
                deltas[s] = dt
                t += dt
                samples += 1
            else:
                # t += dt
                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *
                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv

                t_target = t + ti.max(0, txyz.min())
                t += calc_dt(t, exp_step_factor, grid_size, scale)
                while t < t_target:
                    t += calc_dt(t, exp_step_factor, grid_size, scale)
|
||||
|
||||
|
||||
@ti.kernel
def raymarching_train_backword(segments: ti.types.ndarray(ndim=2),
                               ts: ti.types.ndarray(ndim=1),
                               dL_drays_o: ti.types.ndarray(ndim=2),
                               dL_drays_d: ti.types.ndarray(ndim=2),
                               dL_dxyzs: ti.types.ndarray(ndim=2),
                               dL_ddirs: ti.types.ndarray(ndim=2)):
    """Backward pass for ray marching: scatter per-sample gradients back to
    per-ray origin/direction gradients (xyz = o + t*d, so dL/dd picks up a
    factor of t).

    NOTE(review): only referenced from commented-out code in
    RayMarcherTaichi; the 2-D arrays are indexed with a single subscript
    (`segments[s]`, `dL_dxyzs[index]`, ...) which does not match the ndim=2
    annotations — this looks unfinished, confirm before enabling.
    """
    for s in segments:
        index = segments[s]
        dxyz = dL_dxyzs[index]
        ddir = dL_ddirs[index]

        dL_drays_o[s] = dxyz
        dL_drays_d[s] = dxyz * ts[index] + ddir
|
||||
|
||||
|
||||
class RayMarcherTaichi(torch.nn.Module):
    """Training-time ray marcher: wraps `raymarching_train` in a torch module
    with persistent per-batch output buffers.

    Buffers are sized for batch_size rays x up to 1024 samples each; forward
    returns views truncated to the actual number of generated samples.
    """

    def __init__(self, batch_size=8192):
        super(RayMarcherTaichi, self).__init__()

        # Persistent output buffers filled in-place by the kernel.
        self.register_buffer('rays_a',
                             torch.zeros(batch_size, 3, dtype=torch.int32))
        self.register_buffer(
            'xyzs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32))
        self.register_buffer(
            'dirs', torch.zeros(batch_size * 1024, 3, dtype=torch.float32))
        self.register_buffer(
            'deltas', torch.zeros(batch_size * 1024, dtype=torch.float32))
        self.register_buffer(
            'ts', torch.zeros(batch_size * 1024, dtype=torch.float32))

        # self.register_buffer('dL_drays_o', torch.zeros(batch_size, dtype=torch.float32))
        # self.register_buffer('dL_drays_d', torch.zeros(batch_size, dtype=torch.float32))

        class _module_function(torch.autograd.Function):
            # Forward-only bridge to the taichi kernel (no backward defined:
            # sampling is treated as non-differentiable here).

            @staticmethod
            @custom_fwd(cast_inputs=torch.float32)
            def forward(ctx, rays_o, rays_d, hits_t, density_bitfield,
                        cascades, scale, exp_step_factor, grid_size,
                        max_samples):
                # noise to perturb the first sample of each ray
                noise = torch.rand_like(rays_o[:, 0])
                # counter[0] = total samples, counter[1] = rays processed.
                counter = torch.zeros(2,
                                      device=rays_o.device,
                                      dtype=torch.int32)

                raymarching_train(\
                    rays_o, rays_d,
                    hits_t.contiguous(),
                    density_bitfield, noise, counter,
                    self.rays_a.contiguous(),
                    self.xyzs.contiguous(),
                    self.dirs.contiguous(),
                    self.deltas.contiguous(),
                    self.ts.contiguous(),
                    cascades, grid_size, scale,
                    exp_step_factor, max_samples)

                # ti.sync()

                total_samples = counter[0]  # total samples for all rays
                # remove redundant output
                xyzs = self.xyzs[:total_samples]
                dirs = self.dirs[:total_samples]
                deltas = self.deltas[:total_samples]
                ts = self.ts[:total_samples]

                return self.rays_a, xyzs, dirs, deltas, ts, total_samples

            # @staticmethod
            # @custom_bwd
            # def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs, dL_ddeltas, dL_dts,
            #              dL_dtotal_samples):
            #     rays_a, ts = ctx.saved_tensors
            #     # rays_a = rays_a.contiguous()
            #     ts = ts.contiguous()
            #     segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1] + rays_a[-1:, 2]])
            #     dL_drays_o = torch.zeros_like(rays_a[:, 0])
            #     dL_drays_d = torch.zeros_like(rays_a[:, 0])
            #     raymarching_train_backword(segments.contiguous(), ts, dL_drays_o,
            #                                dL_drays_d, dL_dxyzs, dL_ddirs)
            #     # ti.sync()
            #     # dL_drays_o = segment_csr(dL_dxyzs, segments)
            #     # dL_drays_d = \
            #     #     segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments)

            #     return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None

        self._module_function = _module_function

    def forward(self, rays_o, rays_d, hits_t, density_bitfield, cascades,
                scale, exp_step_factor, grid_size, max_samples):
        """March rays; returns (rays_a, xyzs, dirs, deltas, ts, total_samples)."""
        return self._module_function.apply(rays_o, rays_d, hits_t,
                                           density_bitfield, cascades, scale,
                                           exp_step_factor, grid_size,
                                           max_samples)
|
||||
|
||||
|
||||
@ti.kernel
def raymarching_test_kernel(
        rays_o: ti.types.ndarray(ndim=2),
        rays_d: ti.types.ndarray(ndim=2),
        hits_t: ti.types.ndarray(ndim=2),
        alive_indices: ti.types.ndarray(ndim=1),
        density_bitfield: ti.types.ndarray(ndim=1),
        cascades: int,
        grid_size: int,
        scale: float,
        exp_step_factor: float,
        N_samples: int,
        max_samples: int,
        xyzs: ti.types.ndarray(ndim=2),
        dirs: ti.types.ndarray(ndim=2),
        deltas: ti.types.ndarray(ndim=1),
        ts: ti.types.ndarray(ndim=1),
        N_eff_samples: ti.types.ndarray(ndim=1),
    ):
    """Inference-time ray marching: generate up to N_samples samples per
    still-alive ray, advancing `hits_t[r, 0]` in-place so the next round
    resumes where this one stopped.

    NOTE(review): `max_samples` is accepted but unused in this kernel —
    presumably kept for signature parity with the training kernel.
    Outputs `xyzs`/`dirs`/`ts`/`deltas` are indexed as (n_alive, N_samples,
    ...) despite the declared ndims; the wrapper passes 3-D/2-D tensors.
    """
    for n in alive_indices:
        r = alive_indices[n]
        grid_size3 = grid_size**3
        grid_size_inv = 1.0 / grid_size

        ray_o = vec3(rays_o[r, 0], rays_o[r, 1], rays_o[r, 2])
        ray_d = vec3(rays_d[r, 0], rays_d[r, 1], rays_d[r, 2])
        d_inv = 1.0 / ray_d

        t = hits_t[r, 0]
        t2 = hits_t[r, 1]

        s = 0

        while (0 <= t) & (t < t2) & (s < N_samples):
            xyz = ray_o + t * ray_d
            dt = calc_dt(t, exp_step_factor, grid_size, scale)
            mip = ti.max(mip_from_pos(xyz, cascades),
                         mip_from_dt(dt, grid_size, cascades))

            # mip_bound = 0.5
            # mip_bound = ti.min(ti.pow(2., mip - 1), scale)
            mip_bound = scale
            mip_bound_inv = 1 / mip_bound

            # Normalized grid coordinate of the current point.
            nxyz = ti.math.clamp(0.5 * (xyz * mip_bound_inv + 1) * grid_size,
                                 0.0, grid_size - 1.0)
            # nxyz = ti.ceil(nxyz)

            # Occupancy lookup: morton-coded bit in the packed bitfield.
            idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32))
            occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8))

            if occ:
                xyzs[n, s, 0] = xyz[0]
                xyzs[n, s, 1] = xyz[1]
                xyzs[n, s, 2] = xyz[2]
                dirs[n, s, 0] = ray_d[0]
                dirs[n, s, 1] = ray_d[1]
                dirs[n, s, 2] = ray_d[2]
                ts[n, s] = t
                deltas[n, s] = dt
                t += dt
                # Persist progress so the next round continues from here.
                hits_t[r, 0] = t
                s += 1

            else:
                # Skip to the boundary of the current empty cell (DDA-style).
                txyz = (((nxyz + 0.5 + 0.5 * ti.math.sign(ray_d)) *
                         grid_size_inv * 2 - 1) * mip_bound - xyz) * d_inv

                t_target = t + ti.max(0, txyz.min())
                t += calc_dt(t, exp_step_factor, grid_size, scale)
                while t < t_target:
                    t += calc_dt(t, exp_step_factor, grid_size, scale)

        N_eff_samples[n] = s
|
||||
|
||||
|
||||
def raymarching_test(rays_o, rays_d, hits_t, alive_indices, density_bitfield,
                     cascades, scale, exp_step_factor, grid_size, max_samples,
                     N_samples):
    """Sample up to N_samples points along each still-alive ray (inference).

    Allocates fresh per-round output tensors, dispatches the taichi kernel,
    and returns (xyzs, dirs, deltas, ts, N_eff_samples) where N_eff_samples
    holds the number of samples actually produced per alive ray.
    """
    n_alive = alive_indices.size(0)
    dev, dt = rays_o.device, rays_o.dtype

    # Output buffers, one row per alive ray.
    xyzs = torch.zeros(n_alive, N_samples, 3, device=dev, dtype=dt)
    dirs = torch.zeros(n_alive, N_samples, 3, device=dev, dtype=dt)
    deltas = torch.zeros(n_alive, N_samples, device=dev, dtype=dt)
    ts = torch.zeros(n_alive, N_samples, device=dev, dtype=dt)
    N_eff_samples = torch.zeros(n_alive, device=dev, dtype=torch.int32)

    raymarching_test_kernel(rays_o, rays_d, hits_t, alive_indices,
                            density_bitfield, cascades, grid_size, scale,
                            exp_step_factor, N_samples, max_samples, xyzs,
                            dirs, deltas, ts, N_eff_samples)

    # ti.sync()

    return xyzs, dirs, deltas, ts, N_eff_samples
|
||||
224
taichi_modules/utils.py
Normal file
224
taichi_modules/utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import uvec3
|
||||
|
||||
taichi_block_size = 128
|
||||
|
||||
data_type = ti.f32
|
||||
torch_type = torch.float32
|
||||
|
||||
MAX_SAMPLES = 1024
|
||||
NEAR_DISTANCE = 0.01
|
||||
SQRT3 = 1.7320508075688772
|
||||
SQRT3_MAX_SAMPLES = SQRT3 / 1024
|
||||
SQRT3_2 = 1.7320508075688772 * 2
|
||||
|
||||
|
||||
@ti.func
def scalbn(x, exponent):
    """Return x * 2**exponent (C scalbn equivalent)."""
    return x * ti.math.pow(2, exponent)
|
||||
|
||||
|
||||
@ti.func
def calc_dt(t, exp_step_factor, grid_size, scale):
    """Exponentially growing march step at depth t, clamped to
    [SQRT3/1024, 2*SQRT3*scale/grid_size] (min/max step of the occupancy grid).
    """
    return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES,
                         SQRT3_2 * scale / grid_size)
|
||||
|
||||
|
||||
@ti.func
def frexp_bit(x):
    """Return the binary exponent of f32 `x` as in C frexp (mantissa in
    [0.5, 1)), computed by IEEE-754 bit manipulation; 0 for x == 0.
    """
    exponent = 0
    if x != 0.0:
        # frac = ti.abs(x)
        bits = ti.bit_cast(x, ti.u32)
        # Biased exponent field (bits 23..30) minus the f32 bias of 127.
        exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127
        # exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127
        # Rebuild the mantissa with exponent forced to 0 -> frac in [1, 2).
        bits &= ti.u32(0x7fffff)
        bits |= ti.u32(0x3f800000)
        frac = ti.bit_cast(bits, ti.f32)
        # Adjust toward frexp's [0.5, 1) mantissa convention.
        if frac < 0.5:
            exponent -= 1
        elif frac > 1.0:
            exponent += 1
    return exponent
|
||||
|
||||
|
||||
@ti.func
def mip_from_pos(xyz, cascades):
    """Select the occupancy-grid cascade for a position: based on the binary
    exponent of its largest |coordinate|, clamped to [0, cascades-1].
    """
    mx = ti.abs(xyz).max()
    # _, exponent = _frexp(mx)
    exponent = frexp_bit(ti.f32(mx)) + 1
    # frac, exponent = ti.frexp(ti.f32(mx))
    return ti.min(cascades - 1, ti.max(0, exponent))
|
||||
|
||||
|
||||
@ti.func
def mip_from_dt(dt, grid_size, cascades):
    """Select the cascade whose cell size matches the step dt (binary exponent
    of dt * grid_size), clamped to [0, cascades-1].
    """
    # _, exponent = _frexp(dt*grid_size)
    exponent = frexp_bit(ti.f32(dt * grid_size))
    # frac, exponent = ti.frexp(ti.f32(dt*grid_size))
    return ti.min(cascades - 1, ti.max(0, exponent))
|
||||
|
||||
|
||||
@ti.func
def __expand_bits(v):
    """Spread the low 8 bits of each component so two zero bits separate
    consecutive bits (standard Morton-code bit-expansion magic numbers).
    """
    v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF)
    v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F)
    v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3)
    v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249)
    return v
|
||||
|
||||
|
||||
@ti.func
def __morton3D(xyz):
    """Interleave the bits of an integer xyz coordinate into a Morton code."""
    xyz = __expand_bits(xyz)
    return xyz[0] | (xyz[1] << 1) | (xyz[2] << 2)
|
||||
|
||||
|
||||
@ti.func
def __morton3D_invert(x):
    """Compact every third bit of a Morton code back into one coordinate
    (inverse of __expand_bits); caller pre-shifts by 0/1/2 to pick the axis.
    """
    x = x & (0x49249249)
    x = (x | (x >> 2)) & ti.uint32(0xc30c30c3)
    x = (x | (x >> 4)) & ti.uint32(0x0f00f00f)
    x = (x | (x >> 8)) & ti.uint32(0xff0000ff)
    x = (x | (x >> 16)) & ti.uint32(0x0000ffff)
    return ti.int32(x)
|
||||
|
||||
|
||||
@ti.kernel
def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1),
                           coords: ti.types.ndarray(ndim=2)):
    """Decode each Morton code in `indices` into an (N, 3) integer coordinate."""
    for i in indices:
        ind = ti.uint32(indices[i])
        coords[i, 0] = __morton3D_invert(ind >> 0)
        coords[i, 1] = __morton3D_invert(ind >> 1)
        coords[i, 2] = __morton3D_invert(ind >> 2)
|
||||
|
||||
|
||||
def morton3D_invert(indices):
    """Decode a batch of Morton codes into int32 xyz grid coordinates.

    Args:
        indices: (N,) tensor of Morton codes.

    Returns:
        (N, 3) int32 tensor of coordinates, on the same device as `indices`.
    """
    out = torch.zeros(indices.size(0),
                      3,
                      device=indices.device,
                      dtype=torch.int32)
    morton3D_invert_kernel(indices.contiguous(), out)
    ti.sync()  # make sure the kernel has finished writing before returning
    return out
|
||||
|
||||
|
||||
@ti.kernel
def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2),
                    indices: ti.types.ndarray(ndim=1)):
    """Encode each (x, y, z) row of `xyzs` into an int32 Morton code."""
    for s in indices:
        xyz = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]])
        indices[s] = ti.cast(__morton3D(xyz), ti.int32)
|
||||
|
||||
|
||||
def morton3D(coords1):
    """Encode a batch of integer xyz coordinates as int32 Morton codes.

    Args:
        coords1: (N, 3) tensor of grid coordinates.

    Returns:
        (N,) int32 tensor of Morton codes, on the same device as `coords1`.
    """
    codes = torch.zeros(coords1.size(0),
                        device=coords1.device,
                        dtype=torch.int32)
    morton3D_kernel(coords1.contiguous(), codes)
    ti.sync()  # make sure the kernel has finished writing before returning
    return codes
|
||||
|
||||
|
||||
@ti.kernel
def packbits(density_grid: ti.types.ndarray(ndim=1),
             density_threshold: float,
             density_bitfield: ti.types.ndarray(ndim=1)):
    """Pack the density grid into a bitfield: bit i of byte n is set iff
    density_grid[8*n + i] > density_threshold.
    """
    for n in density_bitfield:
        bits = ti.uint8(0)

        for i in ti.static(range(8)):
            bits |= (ti.uint8(1) << i) if (
                density_grid[8 * n + i] > density_threshold) else ti.uint8(0)

        density_bitfield[n] = bits
|
||||
|
||||
|
||||
@ti.kernel
def torch2ti(field: ti.template(), data: ti.types.ndarray()):
    """Copy a torch tensor into a taichi field of the same shape."""
    for I in ti.grouped(data):
        field[I] = data[I]
|
||||
|
||||
|
||||
@ti.kernel
def ti2torch(field: ti.template(), data: ti.types.ndarray()):
    """Copy a taichi field into a torch tensor of the same shape."""
    for I in ti.grouped(data):
        data[I] = field[I]
|
||||
|
||||
|
||||
@ti.kernel
def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()):
    """Copy a taichi field's gradient into a torch tensor of the same shape."""
    for I in ti.grouped(grad):
        grad[I] = field.grad[I]
|
||||
|
||||
|
||||
@ti.kernel
def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()):
    """Copy a torch gradient tensor into a taichi field's gradient."""
    for I in ti.grouped(grad):
        field.grad[I] = grad[I]
|
||||
|
||||
|
||||
@ti.kernel
def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()):
    """Copy a flat torch tensor into a 2-vector field, packing consecutive
    element pairs (for the half2-optimized parameter layout).
    """
    for I in range(data.shape[0] // 2):
        field[I] = ti.Vector([data[I * 2], data[I * 2 + 1]])
|
||||
|
||||
|
||||
@ti.kernel
def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()):
    """Copy a 2-D field of 2-vectors into a torch tensor, unpacking each
    vector into two consecutive columns.
    """
    for i, j in ti.ndrange(data.shape[0], data.shape[1] // 2):
        data[i, j * 2] = field[i, j][0]
        data[i, j * 2 + 1] = field[i, j][1]
|
||||
|
||||
|
||||
@ti.kernel
def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
    """Copy a 1-D 2-vector field's gradient into a flat torch tensor,
    unpacking each vector into two consecutive elements.
    """
    for I in range(grad.shape[0] // 2):
        grad[I * 2] = field.grad[I][0]
        grad[I * 2 + 1] = field.grad[I][1]
|
||||
|
||||
|
||||
@ti.kernel
def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()):
    """Copy a torch gradient tensor into a 2-D 2-vector field's gradient,
    packing consecutive column pairs into each vector.
    """
    for i, j in ti.ndrange(grad.shape[0], grad.shape[1] // 2):
        field.grad[i, j][0] = grad[i, j * 2]
        field.grad[i, j][1] = grad[i, j * 2 + 1]
|
||||
|
||||
|
||||
def extract_model_state_dict(ckpt_path,
                             model_name='model',
                             prefixes_to_ignore=()):
    """Load a checkpoint and return the sub-state-dict of one model.

    Keys starting with "<model_name>." are kept with that prefix stripped;
    keys whose stripped form starts with any entry of `prefixes_to_ignore`
    are dropped.  pytorch-lightning checkpoints (with a 'state_dict' wrapper)
    are unwrapped automatically.

    Args:
        ckpt_path: path to the checkpoint file.
        model_name: prefix identifying the model inside the checkpoint.
        prefixes_to_ignore: iterable of key prefixes (after stripping) to skip.
            (Default changed from a mutable `[]` to `()` — mutable default
            arguments are shared across calls.)

    Returns:
        dict mapping stripped key -> tensor.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in checkpoint:  # if it's a pytorch-lightning checkpoint
        checkpoint = checkpoint['state_dict']
    checkpoint_ = {}
    for k, v in checkpoint.items():
        if not k.startswith(model_name):
            continue
        k = k[len(model_name) + 1:]  # strip "<model_name>."
        for prefix in prefixes_to_ignore:
            if k.startswith(prefix):
                break
        else:  # no ignored prefix matched
            checkpoint_[k] = v
    return checkpoint_
|
||||
|
||||
|
||||
def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=()):
    """Load the weights for `model_name` from a checkpoint into `model`.

    No-op when `ckpt_path` is falsy.  Missing keys keep the model's current
    values (the checkpoint sub-dict is merged over `model.state_dict()`).

    Args:
        model: torch.nn.Module to update in place.
        ckpt_path: checkpoint path; '' or None skips loading.
        model_name: prefix identifying the model inside the checkpoint.
        prefixes_to_ignore: key prefixes to skip.  (Default changed from a
            mutable `[]` to `()` — mutable default arguments are shared
            across calls.)
    """
    if not ckpt_path:
        return
    model_dict = model.state_dict()
    checkpoint_ = extract_model_state_dict(ckpt_path, model_name,
                                           prefixes_to_ignore)
    model_dict.update(checkpoint_)
    model.load_state_dict(model_dict)
|
||||
|
||||
def depth2img(depth):
    """Min-max normalize a depth map and colorize it with the TURBO colormap
    for visualization.

    NOTE(review): `cv2` and `np` are used here but do not appear in this
    module's visible import block (only taichi/torch/uvec3) — confirm
    `import cv2` and `import numpy as np` exist at file top, otherwise this
    raises NameError at call time.
    """
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
                                  cv2.COLORMAP_TURBO)

    return depth_img
|
||||
48
taichi_modules/volume_render_test.py
Normal file
48
taichi_modules/volume_render_test.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import taichi as ti
|
||||
|
||||
|
||||
@ti.kernel
def composite_test(
        sigmas: ti.types.ndarray(ndim=2), rgbs: ti.types.ndarray(ndim=3),
        deltas: ti.types.ndarray(ndim=2), ts: ti.types.ndarray(ndim=2),
        hits_t: ti.types.ndarray(ndim=2),
        alive_indices: ti.types.ndarray(ndim=1), T_threshold: float,
        N_eff_samples: ti.types.ndarray(ndim=1),
        opacity: ti.types.ndarray(ndim=1),
        depth: ti.types.ndarray(ndim=1), rgb: ti.types.ndarray(ndim=2)):
    """Alpha-composite one round of inference samples into the per-ray
    accumulators rgb/depth/opacity (all updated additively across rounds).

    A ray is marked dead (alive_indices[n] = -1) when it produced no samples
    this round or its transmittance T drops to `T_threshold` or below.

    Fix vs. original: depth and opacity were accumulated directly into
    global memory inside the sample loop while `depth_temp`/`opacity_temp`
    stayed 0 (the final `+= *_temp` lines were dead no-ops).  Now all five
    channels accumulate in locals and are written once per ray, matching
    the rgb channels — mathematically identical result.
    """
    for n in alive_indices:
        samples = N_eff_samples[n]
        if samples == 0:
            # No samples produced this round: the ray is finished.
            alive_indices[n] = -1
        else:
            r = alive_indices[n]

            # Resume transmittance from the opacity accumulated so far.
            T = 1 - opacity[r]

            rgb_temp_0 = 0.0
            rgb_temp_1 = 0.0
            rgb_temp_2 = 0.0
            depth_temp = 0.0
            opacity_temp = 0.0

            for s in range(samples):
                a = 1.0 - ti.exp(-sigmas[n, s] * deltas[n, s])
                w = a * T  # contribution weight of this sample

                rgb_temp_0 += w * rgbs[n, s, 0]
                rgb_temp_1 += w * rgbs[n, s, 1]
                rgb_temp_2 += w * rgbs[n, s, 2]
                depth_temp += w * ts[n, s]
                opacity_temp += w
                T *= 1.0 - a

                if T <= T_threshold:
                    # Ray saturated: stop compositing and mark it dead.
                    alive_indices[n] = -1
                    break

            rgb[r, 0] += rgb_temp_0
            rgb[r, 1] += rgb_temp_1
            rgb[r, 2] += rgb_temp_2
            depth[r] += depth_temp
            opacity[r] += opacity_temp
|
||||
239
taichi_modules/volume_train.py
Normal file
239
taichi_modules/volume_train.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import taichi as ti
|
||||
import torch
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from .utils import (data_type, ti2torch, ti2torch_grad, torch2ti,
|
||||
torch2ti_grad, torch_type)
|
||||
|
||||
|
||||
@ti.kernel
def composite_train_fw_array(
        sigmas: ti.types.ndarray(),
        rgbs: ti.types.ndarray(),
        deltas: ti.types.ndarray(),
        ts: ti.types.ndarray(),
        rays_a: ti.types.ndarray(),
        T_threshold: float,
        total_samples: ti.types.ndarray(),
        opacity: ti.types.ndarray(),
        depth: ti.types.ndarray(),
        rgb: ti.types.ndarray(),
        ws: ti.types.ndarray(),
):
    """Forward volume rendering over flattened samples (ndarray variant).

    Each ray ``n`` owns the contiguous sample slice
    ``[rays_a[n, 1], rays_a[n, 1] + rays_a[n, 2])`` of the flat
    ``sigmas``/``rgbs``/``deltas``/``ts`` arrays and composites it into the
    per-ray outputs at index ``rays_a[n, 0]``. The per-sample weight is also
    stored in ``ws`` for the backward pass.

    NOTE(review): the early-exit on ``T_threshold`` is commented out below,
    so ``T_threshold`` is currently unused and every ray always consumes all
    of its ``N_samples`` — presumably disabled to keep the loop trip count
    fixed; confirm against the templated ``composite_train_fw`` twin, which
    does apply the threshold.
    """

    for n in opacity:
        # Per-ray metadata: output slot, first sample index, sample count.
        ray_idx = rays_a[n, 0]
        start_idx = rays_a[n, 1]
        N_samples = rays_a[n, 2]

        # T is the running transmittance along the ray.
        T = 1.0
        samples = 0
        while samples < N_samples:
            s = start_idx + samples
            a = 1.0 - ti.exp(-sigmas[s] * deltas[s])
            w = a * T  # contribution weight of this sample

            rgb[ray_idx, 0] += w * rgbs[s, 0]
            rgb[ray_idx, 1] += w * rgbs[s, 1]
            rgb[ray_idx, 2] += w * rgbs[s, 2]
            depth[ray_idx] += w * ts[s]
            opacity[ray_idx] += w
            ws[s] = w  # saved for the backward pass
            T *= 1.0 - a

            # Early termination intentionally disabled (see docstring):
            # if T<T_threshold:
            #     break
            samples += 1

        total_samples[ray_idx] = samples
|
||||
|
||||
|
||||
@ti.kernel
def composite_train_fw(sigmas: ti.template(), rgbs: ti.template(),
                       deltas: ti.template(), ts: ti.template(),
                       rays_a: ti.template(), T_threshold: float,
                       T: ti.template(), total_samples: ti.template(),
                       opacity: ti.template(), depth: ti.template(),
                       rgb: ti.template(), ws: ti.template()):
    """Forward volume rendering over Taichi fields, differentiable via autodiff.

    Same compositing math as ``composite_train_fw_array``, but the running
    transmittance is materialized per sample in the ``T`` field
    (``T[s]`` = transmittance entering sample ``s``) instead of a loop-local
    scalar. Storing each intermediate ``T`` keeps every loop iteration's
    inputs addressable, which is what lets ``composite_train_fw.grad``
    differentiate through the sequential attenuation.

    Rays whose transmittance has dropped to ``T_threshold`` or below stop
    contributing: the remaining samples just propagate ``T = 0`` forward.

    NOTE(review): ``T[s + 1]`` is written for every sample, so the last
    sample of the last ray writes one slot past its slice — safe only while
    the caller sizes ``T`` with headroom (here ``batch_size * 1024``);
    confirm when changing field shapes.
    """

    ti.loop_config(block_dim=256)  # GPU block size for the per-ray loop
    for n in opacity:
        # Per-ray metadata: output slot, first sample index, sample count.
        ray_idx = ti.i32(rays_a[n, 0])
        start_idx = ti.i32(rays_a[n, 1])
        N_samples = ti.i32(rays_a[n, 2])

        # Reset this ray's accumulators (fields persist across launches).
        rgb[ray_idx, 0] = 0.0
        rgb[ray_idx, 1] = 0.0
        rgb[ray_idx, 2] = 0.0
        depth[ray_idx] = 0.0
        opacity[ray_idx] = 0.0
        total_samples[ray_idx] = 0

        # Full transmittance entering the ray's first sample.
        T[start_idx] = 1.0
        for sample_ in range(N_samples):
            s = start_idx + sample_
            T_ = T[s]  # transmittance entering sample s
            if T_ > T_threshold:
                a = 1.0 - ti.exp(-sigmas[s] * deltas[s])
                w = a * T_  # contribution weight of this sample
                rgb[ray_idx, 0] += w * rgbs[s, 0]
                rgb[ray_idx, 1] += w * rgbs[s, 1]
                rgb[ray_idx, 2] += w * rgbs[s, 2]
                depth[ray_idx] += w * ts[s]
                opacity[ray_idx] += w
                ws[s] = w  # saved for the backward pass
                # Attenuate and thread the transmittance to the next sample.
                T[s + 1] = T_ * (1.0 - a)
                total_samples[ray_idx] += 1
            else:
                # Ray already saturated: propagate zero transmittance so the
                # remaining samples contribute nothing.
                T[s + 1] = 0.0
|
||||
|
||||
|
||||
@ti.kernel
def check_value(
        fields: ti.template(),
        array: ti.types.ndarray(),
        checker: ti.types.ndarray(),
):
    """Mark entries where the Taichi field and the ndarray hold equal values.

    Writes 1 into ``checker`` at every index where ``fields`` and ``array``
    agree exactly; other entries are left untouched.
    """
    for idx in ti.grouped(array):
        values_match = fields[idx] == array[idx]
        if values_match:
            checker[idx] = 1
|
||||
|
||||
|
||||
class VolumeRendererTaichi(torch.nn.Module):
    """Differentiable volume renderer bridging PyTorch autograd and Taichi.

    Preallocates Taichi fields (sized for up to ``batch_size * 1024``
    samples) and matching torch buffers, then wraps the
    ``composite_train_fw`` kernel and its Taichi autodiff gradient in a
    ``torch.autograd.Function``. The inner Function's static methods close
    over ``self`` to reach the preallocated fields/buffers — they are
    genuinely static w.r.t. torch, but instance-bound through the closure.
    """

    def __init__(self, batch_size=8192, data_type=data_type):
        """Allocate Taichi fields and torch staging buffers.

        Args:
            batch_size: maximum number of rays per forward call; sample-level
                storage is sized at ``batch_size * 1024``.
            data_type: Taichi scalar dtype for the float fields (module-level
                default from ``.utils``).
        """
        super(VolumeRendererTaichi, self).__init__()
        # samples level: one entry per (ray, sample) in the flattened layout.
        self.sigmas_fields = ti.field(dtype=data_type,
                                      shape=(batch_size * 1024, ),
                                      needs_grad=True)
        self.rgbs_fields = ti.field(dtype=data_type,
                                    shape=(batch_size * 1024, 3),
                                    needs_grad=True)
        self.deltas_fields = ti.field(dtype=data_type,
                                      shape=(batch_size * 1024, ),
                                      needs_grad=True)
        self.ts_fields = ti.field(dtype=data_type,
                                  shape=(batch_size * 1024, ),
                                  needs_grad=True)
        self.ws_fields = ti.field(dtype=data_type,
                                  shape=(batch_size * 1024, ),
                                  needs_grad=True)
        # Per-sample running transmittance, materialized for autodiff
        # (see composite_train_fw).
        self.T = ti.field(dtype=data_type,
                          shape=(batch_size * 1024),
                          needs_grad=True)

        # rays level: one entry per ray.
        self.rays_a_fields = ti.field(dtype=ti.i64, shape=(batch_size, 3))
        self.total_samples_fields = ti.field(dtype=ti.i64,
                                             shape=(batch_size, ))
        self.opacity_fields = ti.field(dtype=data_type,
                                       shape=(batch_size, ),
                                       needs_grad=True)
        self.depth_fields = ti.field(dtype=data_type,
                                     shape=(batch_size, ),
                                     needs_grad=True)
        self.rgb_fields = ti.field(dtype=data_type,
                                   shape=(batch_size, 3),
                                   needs_grad=True)

        # preallocate tensor: torch-side staging buffers, registered so they
        # follow the module across .to()/.cuda() and state_dict handling.
        self.register_buffer('total_samples',
                             torch.zeros(batch_size, dtype=torch.int64))
        self.register_buffer('rgb', torch.zeros(batch_size,
                                                3,
                                                dtype=torch_type))
        self.register_buffer('opacity',
                             torch.zeros(batch_size, dtype=torch_type))
        self.register_buffer('depth', torch.zeros(batch_size,
                                                  dtype=torch_type))
        self.register_buffer('ws',
                             torch.zeros(batch_size * 1024, dtype=torch_type))

        # Gradient staging buffers for the backward pass.
        self.register_buffer('sigma_grad',
                             torch.zeros(batch_size * 1024, dtype=torch_type))
        self.register_buffer(
            'rgb_grad', torch.zeros(batch_size * 1024, 3, dtype=torch_type))

        class _module_function(torch.autograd.Function):
            """Autograd bridge; closes over `self` for field/buffer access."""

            @staticmethod
            @custom_fwd(cast_inputs=torch_type)
            def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
                # If no output gradient is provided, no need to
                # automatically materialize it as torch.zeros.

                # Stash what backward needs (only non-tensor metadata).
                ctx.T_threshold = T_threshold
                ctx.samples_size = sigmas.shape[0]

                # NOTE(review): `ws` is a view of the preallocated buffer and
                # is never filled from `self.ws_fields` after the kernel runs
                # (unlike total_samples/opacity/depth/rgb below), so the
                # returned weights look stale/zero — confirm whether a
                # ti2torch(self.ws_fields, ws) copy is missing here.
                ws = self.ws[:sigmas.shape[0]]

                # Stage torch inputs into the Taichi fields.
                torch2ti(self.sigmas_fields, sigmas.contiguous())
                torch2ti(self.rgbs_fields, rgbs.contiguous())
                torch2ti(self.deltas_fields, deltas.contiguous())
                torch2ti(self.ts_fields, ts.contiguous())
                torch2ti(self.rays_a_fields, rays_a.contiguous())
                composite_train_fw(self.sigmas_fields, self.rgbs_fields,
                                   self.deltas_fields, self.ts_fields,
                                   self.rays_a_fields, T_threshold, self.T,
                                   self.total_samples_fields,
                                   self.opacity_fields, self.depth_fields,
                                   self.rgb_fields, self.ws_fields)
                # Copy kernel outputs back into the torch buffers.
                ti2torch(self.total_samples_fields, self.total_samples)
                ti2torch(self.opacity_fields, self.opacity)
                ti2torch(self.depth_fields, self.depth)
                ti2torch(self.rgb_fields, self.rgb)

                return self.total_samples.sum(
                ), self.opacity, self.depth, self.rgb, ws

            @staticmethod
            @custom_bwd
            def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth,
                         dL_drgb, dL_dws):

                T_threshold = ctx.T_threshold
                samples_size = ctx.samples_size

                # Views into the preallocated gradient buffers, trimmed to
                # this call's sample count.
                sigma_grad = self.sigma_grad[:samples_size].contiguous()
                rgb_grad = self.rgb_grad[:samples_size].contiguous()

                # Clear accumulated Taichi grads before re-running autodiff.
                self.zero_grad()

                # Seed the output-side Taichi grads with the torch cotangents.
                torch2ti_grad(self.opacity_fields, dL_dopacity.contiguous())
                torch2ti_grad(self.depth_fields, dL_ddepth.contiguous())
                torch2ti_grad(self.rgb_fields, dL_drgb.contiguous())
                torch2ti_grad(self.ws_fields, dL_dws.contiguous())
                # Taichi autodiff of the forward kernel fills input grads.
                composite_train_fw.grad(self.sigmas_fields, self.rgbs_fields,
                                        self.deltas_fields, self.ts_fields,
                                        self.rays_a_fields, T_threshold,
                                        self.T, self.total_samples_fields,
                                        self.opacity_fields, self.depth_fields,
                                        self.rgb_fields, self.ws_fields)
                ti2torch_grad(self.sigmas_fields, sigma_grad)
                ti2torch_grad(self.rgbs_fields, rgb_grad)

                # Gradients only for sigmas and rgbs; the remaining inputs
                # (deltas, ts, rays_a, T_threshold) are treated as constants.
                return sigma_grad, rgb_grad, None, None, None, None

        self._module_function = _module_function

    def zero_grad(self):
        """Zero the Taichi grads that accumulate across backward calls.

        Only sigmas/rgbs/T grads are cleared; opacity/depth/rgb/ws grads are
        fully overwritten by torch2ti_grad in backward before use.
        """
        self.sigmas_fields.grad.fill(0.)
        self.rgbs_fields.grad.fill(0.)
        self.T.grad.fill(0.)

    def forward(self, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
        """Render; returns (total_samples_sum, opacity, depth, rgb, ws)."""
        return self._module_function.apply(sigmas, rgbs, deltas, ts, rays_a,
                                           T_threshold)
|
||||
Reference in New Issue
Block a user