306 lines
11 KiB
Python
306 lines
11 KiB
Python
import numpy as np
|
|
import taichi as ti
|
|
import torch
|
|
from taichi.math import uvec3
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
|
from .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec,
|
|
ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec,
|
|
torch2ti_vec, torch_type)
|
|
|
|
# A packed pair of fp16 values: one hash-table entry (2 features) per element.
half2 = ti.types.vector(n=2, dtype=ti.f16)
|
|
|
|
|
|
@ti.kernel
def random_initialize(data: ti.types.ndarray()):
    # Fill every element with a small uniform random value in [-1e-4, 1e-4).
    for cell in ti.grouped(data):
        noise = ti.random() * 2.0 - 1.0  # uniform in [-1.0, 1.0)
        data[cell] = noise * 1e-4
|
|
|
|
|
|
@ti.kernel
def ti_copy(data1: ti.template(), data2: ti.template()):
    # Element-wise copy between two Taichi fields of identical shape:
    # data2 is the source, data1 the destination.
    for cell in ti.grouped(data1):
        data1[cell] = data2[cell]
|
|
|
|
|
|
@ti.kernel
def ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()):
    # Element-wise copy between two external ndarrays of identical shape:
    # data2 is the source, data1 the destination.
    for cell in ti.grouped(data1):
        data1[cell] = data2[cell]
|
|
|
|
|
|
@ti.kernel
def ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()):
    # Element-wise copy from an external ndarray (data2) into a Taichi
    # field (data1) of the same shape.
    for cell in ti.grouped(data1):
        data1[cell] = data2[cell]
|
|
|
|
|
|
@ti.func
def fast_hash(pos_grid_local):
    # Spatial hash of a 3D integer grid coordinate: XOR-accumulate each
    # component multiplied by a per-axis prime (first "prime" is 1, matching
    # the instant-ngp coherent hashing scheme).
    primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))
    acc = ti.uint32(0)
    for axis in ti.static(range(3)):
        acc ^= ti.uint32(pos_grid_local[axis]) * primes[axis]
    return acc
|
|
|
|
|
|
@ti.func
def under_hash(pos_grid_local, resolution):
    # Collision-free index for coarse levels: row-major linearization of the
    # 3D coordinate (x + y*res + z*res^2), used when the full grid fits in
    # the hash table.
    acc = ti.uint32(0)
    stride = ti.uint32(1)
    for axis in ti.static(range(3)):
        acc += ti.uint32(pos_grid_local[axis] * stride)
        stride *= resolution
    return acc
|
|
|
|
|
|
@ti.func
def grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size):
    # Map a 3D grid coordinate to a slot in the level's table: dense
    # linearization when indicator == 1 (grid fits without collisions),
    # otherwise the spatial hash; either way wrap into [0, map_size).
    raw_index = ti.uint32(0)
    if indicator == 1:
        raw_index = under_hash(pos_grid_local, resolution)
    else:
        raw_index = fast_hash(pos_grid_local)

    return raw_index % map_size
|
|
|
|
|
|
@ti.kernel
def hash_encode_kernel(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
        per_level_scale: ti.f32):
    # Multiresolution hash-grid encoding (scalar-table variant).
    # For each of B input points and each of 16 levels, trilinearly
    # interpolates a 2-feature embedding from `table` and writes the
    # 32-dim result into `xyzs_embedding` (2 features per level).
    # NOTE(review): assumes xyzs coordinates are normalized to [0, 1]
    # (the caller's forward() maps [-bound, bound] there) — confirm.

    # get hash table embedding
    ti.loop_config(block_dim=16)
    for i, level in ti.ndrange(B, 16):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])

        # Grid scale grows geometrically per level: 16 * b^level - 1
        # (base resolution 16, growth factor per_level_scale).
        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1

        # Offsets are stored in table *entries*; each entry holds 2 scalar
        # features, hence the * 2 into the flat scalar table.
        offset = offsets[level] * 2

        pos = xyz * scale + 0.5
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
        # After this, pos holds the fractional part: the trilinear weights.
        pos -= pos_grid_uint

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        # Accumulators for the level's two interpolated features.
        local_feature_0 = 0.0
        local_feature_1 = 0.0

        # Visit the 8 corners of the enclosing grid cell (unrolled).
        for idx in ti.static(range(8)):
            w = 1.
            pos_grid_local = uvec3(0)

            # Bit d of idx selects the low/high corner along axis d;
            # w accumulates the matching trilinear weight.
            for d in ti.static(range(3)):
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= pos[d]

            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
                                        map_size)
            index_table = offset + index * 2
            index_table_int = ti.cast(index_table, ti.int32)
            local_feature_0 += w * table[index_table_int]
            local_feature_1 += w * table[index_table_int + 1]

        # Interleaved layout: [level0_f0, level0_f1, level1_f0, ...].
        xyzs_embedding[i, level * 2] = local_feature_0
        xyzs_embedding[i, level * 2 + 1] = local_feature_1
|
|
|
|
|
|
@ti.kernel
def hash_encode_kernel_half2(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
        per_level_scale: ti.f16):
    # Multiresolution hash-grid encoding, fp16 `half2` variant: the table
    # and output store one 2-vector of f16 per entry, so both features of a
    # level are fetched/written with a single indexed access (no * 2
    # stride, unlike the scalar kernel above).
    # NOTE(review): assumes xyzs coordinates are normalized to [0, 1] —
    # confirm against the caller's forward().

    # get hash table embedding
    ti.loop_config(block_dim=32)
    for i, level in ti.ndrange(B, 16):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])

        # Grid scale grows geometrically per level: 16 * b^level - 1.
        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1

        # Offsets are already in half2-entry units here.
        offset = offsets[level]

        pos = xyz * scale + 0.5
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
        # After this, pos holds the fractional part: the trilinear weights.
        pos -= pos_grid_uint

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        # Accumulates both features of this level at once as a half2.
        local_feature = half2(0.0)
        # Visit the 8 corners of the enclosing grid cell (unrolled).
        for idx in ti.static(range(8)):
            # Weight accumulated in f32 for precision, applied to f16 data.
            w = ti.f32(1.0)
            pos_grid_local = uvec3(0)

            # Bit d of idx selects the low/high corner along axis d.
            for d in ti.static(range(3)):
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= pos[d]

            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
                                        map_size)

            index_table = offset + index
            index_table_int = ti.cast(index_table, ti.int32)

            local_feature += w * table[index_table_int]

        # One half2 per (point, level) cell in the output field.
        xyzs_embedding[i, level] = local_feature
|
|
|
|
|
|
class HashEncoderTaichi(torch.nn.Module):
    """Instant-ngp-style multiresolution hash encoding backed by Taichi.

    Wraps the Taichi forward/backward kernels in a torch.autograd.Function
    so the hash table is trained as a regular torch Parameter. 16 levels,
    2 features per level -> 32-dim output per input point.
    """

    def __init__(self,
                 b=1.3195079565048218,
                 batch_size=8192,
                 data_type=data_type,
                 half2_opt=False):
        """Build the per-level tables and Taichi fields.

        b: per-level geometric growth factor of the grid resolution.
        batch_size: used only to size the staging fields (min 2048).
        data_type: Taichi scalar dtype for the non-half2 path.
        half2_opt: pack features as fp16 pairs (half2) when True.
        """
        super(HashEncoderTaichi, self).__init__()

        self.per_level_scale = b
        # Staging fields are sized from batch_size; enforce a floor so tiny
        # batches still get usable buffers.
        if batch_size < 2048:
            batch_size = 2048

        # per_level_scale = 1.3195079565048218
        print("per_level_scale: ", b)
        # Per-level metadata consumed by the kernels.
        self.offsets = ti.field(ti.i32, shape=(16, ))
        self.hash_map_sizes_field = ti.field(ti.uint32, shape=(16, ))
        self.hash_map_indicator = ti.field(ti.i32, shape=(16, ))
        base_res = 16
        max_params = 2**19  # cap on entries per level (T in instant-ngp)
        offset_ = 0  # running start offset of each level, in table entries
        hash_map_sizes = []
        for i in range(16):
            # Mirrors the kernel's resolution formula: ceil(16*b^i - 1) + 1.
            resolution = int(
                np.ceil(base_res * np.exp(i * np.log(self.per_level_scale)) -
                        1.0)) + 1
            params_in_level = resolution**3
            # Round the level size up to a multiple of 8 (alignment).
            params_in_level = int(resolution**
                                  3) if params_in_level % 8 == 0 else int(
                                      (params_in_level + 8 - 1) / 8) * 8
            params_in_level = min(max_params, params_in_level)
            self.offsets[i] = offset_
            hash_map_sizes.append(params_in_level)
            # indicator == 1: dense (collision-free) indexing is possible;
            # 0: the kernel falls back to spatial hashing.
            self.hash_map_indicator[
                i] = 1 if resolution**3 <= params_in_level else 0
            offset_ += params_in_level
        print("offset_: ", offset_)
        size = np.uint32(np.array(hash_map_sizes))
        self.hash_map_sizes_field.from_numpy(size)

        # 2 scalar features per entry -> total scalar parameter count.
        self.total_hash_size = offset_ * 2
        print("total_hash_size: ", self.total_hash_size)

        # The trainable table lives on the torch side; Taichi fields below
        # are only staging copies synced each forward/backward.
        self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size,
                                                         dtype=torch_type),
                                             requires_grad=True)
        # Small uniform init in [-1e-4, 1e-4), per instant-ngp.
        random_initialize(self.hash_table)

        if half2_opt:
            # fp16 path: fields hold half2 (2 x f16) entries, halving the
            # element count; use the *_vec copy helpers accordingly.
            assert self.total_hash_size % 2 == 0
            self.parameter_fields = half2.field(shape=(self.total_hash_size //
                                                       2, ),
                                                needs_grad=True)
            self.output_fields = half2.field(shape=(batch_size * 1024, 16),
                                             needs_grad=True)

            self.torch2ti = torch2ti_vec
            self.ti2torch = ti2torch_vec
            self.ti2torch_grad = ti2torch_grad_vec
            self.torch2ti_grad = torch2ti_grad_vec

            self._hash_encode_kernel = hash_encode_kernel_half2
        else:
            # Scalar path: flat table of total_hash_size scalars, 32 scalar
            # outputs per point.
            self.parameter_fields = ti.field(data_type,
                                             shape=(self.total_hash_size, ),
                                             needs_grad=True)
            self.output_fields = ti.field(dtype=data_type,
                                          shape=(batch_size * 1024, 32),
                                          needs_grad=True)
            self.torch2ti = torch2ti
            self.ti2torch = ti2torch
            self.ti2torch_grad = ti2torch_grad
            self.torch2ti_grad = torch2ti_grad

            self._hash_encode_kernel = hash_encode_kernel

        self.input_fields = ti.field(dtype=data_type,
                                     shape=(batch_size * 1024, 3),
                                     needs_grad=True)
        self.output_dim = 32  # the output dim: 16 levels x 2 features/level
        # Reusable torch-side buffers for the parameter gradient and the
        # forward output (registered so .to()/.cuda() moves them).
        self.register_buffer(
            'hash_grad', torch.zeros(self.total_hash_size, dtype=torch_type))
        self.register_buffer(
            'output_embedding',
            torch.zeros(batch_size * 1024, 32, dtype=torch_type))

        # autograd.Function defined as a closure so forward/backward can
        # reach the Taichi fields via `self` (nothing is saved on ctx).
        class _module_function(torch.autograd.Function):

            @staticmethod
            @custom_fwd(cast_inputs=torch_type)
            def forward(ctx, input_pos, params):
                """Copy inputs into Taichi, run the kernel, copy output back.

                input_pos: (N, 3) positions; params: flat hash table.
                Returns a (N, 32) slice of the preallocated output buffer.
                """
                output_embedding = self.output_embedding[:input_pos.
                                                         shape[0]].contiguous(
                                                         )
                # Input staging always uses the scalar copier; only the
                # table/output use the half2-aware variants.
                torch2ti(self.input_fields, input_pos.contiguous())
                self.torch2ti(self.parameter_fields, params.contiguous())

                self._hash_encode_kernel(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator,
                    self.hash_map_sizes_field,
                    self.offsets,
                    input_pos.shape[0],
                    self.per_level_scale,
                )
                self.ti2torch(self.output_fields, output_embedding)

                return output_embedding

            @staticmethod
            @custom_bwd
            def backward(ctx, doutput):
                """Backprop via Taichi autodiff into the hash-table grad."""

                # Clear stale parameter gradients before accumulation.
                self.zero_grad()

                self.torch2ti_grad(self.output_fields, doutput.contiguous())
                # Taichi's reverse-mode kernel fills parameter_fields.grad.
                self._hash_encode_kernel.grad(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator,
                    self.hash_map_sizes_field,
                    self.offsets,
                    doutput.shape[0],
                    self.per_level_scale,
                )
                self.ti2torch_grad(self.parameter_fields,
                                   self.hash_grad.contiguous())
                # No gradient for input positions (None), grad for params.
                return None, self.hash_grad

        self._module_function = _module_function

    def zero_grad(self):
        # Reset the Taichi-side parameter gradient field (not torch grads).
        self.parameter_fields.grad.fill(0.)

    def forward(self, positions, bound=1):
        """Encode positions in [-bound, bound]^3 to (N, 32) features."""
        # Normalize to the [0, 1]^3 domain the kernels expect.
        positions = (positions + bound) / (2 * bound)
        return self._module_function.apply(positions, self.hash_table)
|