first commit
This commit is contained in:
305
taichi_modules/hash_encoder.py
Normal file
305
taichi_modules/hash_encoder.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import numpy as np
|
||||
import taichi as ti
|
||||
import torch
|
||||
from taichi.math import uvec3
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec,
|
||||
ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec,
|
||||
torch2ti_vec, torch_type)
|
||||
|
||||
half2 = ti.types.vector(n=2, dtype=ti.f16)
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def random_initialize(data: ti.types.ndarray()):
|
||||
for I in ti.grouped(data):
|
||||
data[I] = (ti.random() * 2.0 - 1.0) * 1e-4
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti_copy(data1: ti.template(), data2: ti.template()):
|
||||
for I in ti.grouped(data1):
|
||||
data1[I] = data2[I]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()):
|
||||
for I in ti.grouped(data1):
|
||||
data1[I] = data2[I]
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()):
|
||||
for I in ti.grouped(data1):
|
||||
data1[I] = data2[I]
|
||||
|
||||
|
||||
@ti.func
|
||||
def fast_hash(pos_grid_local):
|
||||
result = ti.uint32(0)
|
||||
# primes = uvec3(ti.uint32(1), ti.uint32(1958374283), ti.uint32(2654435761))
|
||||
primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))
|
||||
for i in ti.static(range(3)):
|
||||
result ^= ti.uint32(pos_grid_local[i]) * primes[i]
|
||||
return result
|
||||
|
||||
|
||||
@ti.func
|
||||
def under_hash(pos_grid_local, resolution):
|
||||
result = ti.uint32(0)
|
||||
stride = ti.uint32(1)
|
||||
for i in ti.static(range(3)):
|
||||
result += ti.uint32(pos_grid_local[i] * stride)
|
||||
stride *= resolution
|
||||
return result
|
||||
|
||||
|
||||
@ti.func
|
||||
def grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size):
|
||||
hash_result = ti.uint32(0)
|
||||
if indicator == 1:
|
||||
hash_result = under_hash(pos_grid_local, resolution)
|
||||
else:
|
||||
hash_result = fast_hash(pos_grid_local)
|
||||
|
||||
return hash_result % map_size
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def hash_encode_kernel(
|
||||
xyzs: ti.template(), table: ti.template(),
|
||||
xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
|
||||
hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
|
||||
per_level_scale: ti.f32):
|
||||
|
||||
# get hash table embedding
|
||||
ti.loop_config(block_dim=16)
|
||||
for i, level in ti.ndrange(B, 16):
|
||||
xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])
|
||||
|
||||
scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
|
||||
resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1
|
||||
|
||||
offset = offsets[level] * 2
|
||||
|
||||
pos = xyz * scale + 0.5
|
||||
pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
|
||||
pos -= pos_grid_uint
|
||||
|
||||
indicator = hash_map_indicator[level]
|
||||
map_size = hash_map_sizes_field[level]
|
||||
|
||||
local_feature_0 = 0.0
|
||||
local_feature_1 = 0.0
|
||||
|
||||
for idx in ti.static(range(8)):
|
||||
w = 1.
|
||||
pos_grid_local = uvec3(0)
|
||||
|
||||
for d in ti.static(range(3)):
|
||||
if (idx & (1 << d)) == 0:
|
||||
pos_grid_local[d] = pos_grid_uint[d]
|
||||
w *= 1 - pos[d]
|
||||
else:
|
||||
pos_grid_local[d] = pos_grid_uint[d] + 1
|
||||
w *= pos[d]
|
||||
|
||||
index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
|
||||
map_size)
|
||||
index_table = offset + index * 2
|
||||
index_table_int = ti.cast(index_table, ti.int32)
|
||||
local_feature_0 += w * table[index_table_int]
|
||||
local_feature_1 += w * table[index_table_int + 1]
|
||||
|
||||
xyzs_embedding[i, level * 2] = local_feature_0
|
||||
xyzs_embedding[i, level * 2 + 1] = local_feature_1
|
||||
|
||||
|
||||
@ti.kernel
|
||||
def hash_encode_kernel_half2(
|
||||
xyzs: ti.template(), table: ti.template(),
|
||||
xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
|
||||
hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
|
||||
per_level_scale: ti.f16):
|
||||
|
||||
# get hash table embedding
|
||||
ti.loop_config(block_dim=32)
|
||||
for i, level in ti.ndrange(B, 16):
|
||||
xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])
|
||||
|
||||
scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
|
||||
resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1
|
||||
|
||||
offset = offsets[level]
|
||||
|
||||
pos = xyz * scale + 0.5
|
||||
pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
|
||||
pos -= pos_grid_uint
|
||||
|
||||
indicator = hash_map_indicator[level]
|
||||
map_size = hash_map_sizes_field[level]
|
||||
|
||||
local_feature = half2(0.0)
|
||||
for idx in ti.static(range(8)):
|
||||
w = ti.f32(1.0)
|
||||
pos_grid_local = uvec3(0)
|
||||
|
||||
for d in ti.static(range(3)):
|
||||
if (idx & (1 << d)) == 0:
|
||||
pos_grid_local[d] = pos_grid_uint[d]
|
||||
w *= 1 - pos[d]
|
||||
else:
|
||||
pos_grid_local[d] = pos_grid_uint[d] + 1
|
||||
w *= pos[d]
|
||||
|
||||
index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
|
||||
map_size)
|
||||
|
||||
index_table = offset + index
|
||||
index_table_int = ti.cast(index_table, ti.int32)
|
||||
|
||||
local_feature += w * table[index_table_int]
|
||||
xyzs_embedding[i, level] = local_feature
|
||||
|
||||
|
||||
class HashEncoderTaichi(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
b=1.3195079565048218,
|
||||
batch_size=8192,
|
||||
data_type=data_type,
|
||||
half2_opt=False):
|
||||
super(HashEncoderTaichi, self).__init__()
|
||||
|
||||
self.per_level_scale = b
|
||||
if batch_size < 2048:
|
||||
batch_size = 2048
|
||||
|
||||
# per_level_scale = 1.3195079565048218
|
||||
print("per_level_scale: ", b)
|
||||
self.offsets = ti.field(ti.i32, shape=(16, ))
|
||||
self.hash_map_sizes_field = ti.field(ti.uint32, shape=(16, ))
|
||||
self.hash_map_indicator = ti.field(ti.i32, shape=(16, ))
|
||||
base_res = 16
|
||||
max_params = 2**19
|
||||
offset_ = 0
|
||||
hash_map_sizes = []
|
||||
for i in range(16):
|
||||
resolution = int(
|
||||
np.ceil(base_res * np.exp(i * np.log(self.per_level_scale)) -
|
||||
1.0)) + 1
|
||||
params_in_level = resolution**3
|
||||
params_in_level = int(resolution**
|
||||
3) if params_in_level % 8 == 0 else int(
|
||||
(params_in_level + 8 - 1) / 8) * 8
|
||||
params_in_level = min(max_params, params_in_level)
|
||||
self.offsets[i] = offset_
|
||||
hash_map_sizes.append(params_in_level)
|
||||
self.hash_map_indicator[
|
||||
i] = 1 if resolution**3 <= params_in_level else 0
|
||||
offset_ += params_in_level
|
||||
print("offset_: ", offset_)
|
||||
size = np.uint32(np.array(hash_map_sizes))
|
||||
self.hash_map_sizes_field.from_numpy(size)
|
||||
|
||||
self.total_hash_size = offset_ * 2
|
||||
print("total_hash_size: ", self.total_hash_size)
|
||||
|
||||
self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size,
|
||||
dtype=torch_type),
|
||||
requires_grad=True)
|
||||
random_initialize(self.hash_table)
|
||||
|
||||
if half2_opt:
|
||||
assert self.total_hash_size % 2 == 0
|
||||
self.parameter_fields = half2.field(shape=(self.total_hash_size //
|
||||
2, ),
|
||||
needs_grad=True)
|
||||
self.output_fields = half2.field(shape=(batch_size * 1024, 16),
|
||||
needs_grad=True)
|
||||
|
||||
self.torch2ti = torch2ti_vec
|
||||
self.ti2torch = ti2torch_vec
|
||||
self.ti2torch_grad = ti2torch_grad_vec
|
||||
self.torch2ti_grad = torch2ti_grad_vec
|
||||
|
||||
self._hash_encode_kernel = hash_encode_kernel_half2
|
||||
else:
|
||||
self.parameter_fields = ti.field(data_type,
|
||||
shape=(self.total_hash_size, ),
|
||||
needs_grad=True)
|
||||
self.output_fields = ti.field(dtype=data_type,
|
||||
shape=(batch_size * 1024, 32),
|
||||
needs_grad=True)
|
||||
self.torch2ti = torch2ti
|
||||
self.ti2torch = ti2torch
|
||||
self.ti2torch_grad = ti2torch_grad
|
||||
self.torch2ti_grad = torch2ti_grad
|
||||
|
||||
self._hash_encode_kernel = hash_encode_kernel
|
||||
|
||||
self.input_fields = ti.field(dtype=data_type,
|
||||
shape=(batch_size * 1024, 3),
|
||||
needs_grad=True)
|
||||
self.output_dim = 32 # the output dim: num levels (16) x level num (2)
|
||||
self.register_buffer(
|
||||
'hash_grad', torch.zeros(self.total_hash_size, dtype=torch_type))
|
||||
self.register_buffer(
|
||||
'output_embedding',
|
||||
torch.zeros(batch_size * 1024, 32, dtype=torch_type))
|
||||
|
||||
class _module_function(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch_type)
|
||||
def forward(ctx, input_pos, params):
|
||||
output_embedding = self.output_embedding[:input_pos.
|
||||
shape[0]].contiguous(
|
||||
)
|
||||
torch2ti(self.input_fields, input_pos.contiguous())
|
||||
self.torch2ti(self.parameter_fields, params.contiguous())
|
||||
|
||||
self._hash_encode_kernel(
|
||||
self.input_fields,
|
||||
self.parameter_fields,
|
||||
self.output_fields,
|
||||
self.hash_map_indicator,
|
||||
self.hash_map_sizes_field,
|
||||
self.offsets,
|
||||
input_pos.shape[0],
|
||||
self.per_level_scale,
|
||||
)
|
||||
self.ti2torch(self.output_fields, output_embedding)
|
||||
|
||||
return output_embedding
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, doutput):
|
||||
|
||||
self.zero_grad()
|
||||
|
||||
self.torch2ti_grad(self.output_fields, doutput.contiguous())
|
||||
self._hash_encode_kernel.grad(
|
||||
self.input_fields,
|
||||
self.parameter_fields,
|
||||
self.output_fields,
|
||||
self.hash_map_indicator,
|
||||
self.hash_map_sizes_field,
|
||||
self.offsets,
|
||||
doutput.shape[0],
|
||||
self.per_level_scale,
|
||||
)
|
||||
self.ti2torch_grad(self.parameter_fields,
|
||||
self.hash_grad.contiguous())
|
||||
return None, self.hash_grad
|
||||
|
||||
self._module_function = _module_function
|
||||
|
||||
def zero_grad(self):
|
||||
self.parameter_fields.grad.fill(0.)
|
||||
|
||||
def forward(self, positions, bound=1):
|
||||
positions = (positions + bound) / (2 * bound)
|
||||
return self._module_function.apply(positions, self.hash_table)
|
||||
Reference in New Issue
Block a user