Magic123/taichi_modules/hash_encoder.py

import numpy as np
import taichi as ti
import torch
from taichi.math import uvec3
from torch.cuda.amp import custom_bwd, custom_fwd

from .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec,
                    ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec,
                    torch2ti_vec, torch_type)

half2 = ti.types.vector(n=2, dtype=ti.f16)

@ti.kernel
def random_initialize(data: ti.types.ndarray()):
    # Uniform initialization of the hash table in [-1e-4, 1e-4].
    for I in ti.grouped(data):
        data[I] = (ti.random() * 2.0 - 1.0) * 1e-4


@ti.kernel
def ti_copy(data1: ti.template(), data2: ti.template()):
    for I in ti.grouped(data1):
        data1[I] = data2[I]


@ti.kernel
def ti_copy_array(data1: ti.types.ndarray(), data2: ti.types.ndarray()):
    for I in ti.grouped(data1):
        data1[I] = data2[I]


@ti.kernel
def ti_copy_field_array(data1: ti.template(), data2: ti.types.ndarray()):
    for I in ti.grouped(data1):
        data1[I] = data2[I]

@ti.func
def fast_hash(pos_grid_local):
    result = ti.uint32(0)
    # primes = uvec3(ti.uint32(1), ti.uint32(1958374283), ti.uint32(2654435761))
    primes = uvec3(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861))
    for i in ti.static(range(3)):
        result ^= ti.uint32(pos_grid_local[i]) * primes[i]
    return result


@ti.func
def under_hash(pos_grid_local, resolution):
    result = ti.uint32(0)
    stride = ti.uint32(1)
    for i in ti.static(range(3)):
        result += ti.uint32(pos_grid_local[i] * stride)
        stride *= resolution
    return result


@ti.func
def grid_pos2hash_index(indicator, pos_grid_local, resolution, map_size):
    hash_result = ti.uint32(0)
    if indicator == 1:
        hash_result = under_hash(pos_grid_local, resolution)
    else:
        hash_result = fast_hash(pos_grid_local)
    return hash_result % map_size
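

# Per-level table addressing: when a level's full dense grid fits in its
# allotted parameters (indicator == 1), `under_hash` gives a collision-free
# linear index; otherwise `fast_hash` applies the prime-based spatial hash
# used by Instant-NGP, and the result is folded into the table via `% map_size`.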
@ti.kernel
def hash_encode_kernel(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
        per_level_scale: ti.f32):
    # get hash table embedding
    ti.loop_config(block_dim=16)
    for i, level in ti.ndrange(B, 16):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])

        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1

        offset = offsets[level] * 2

        pos = xyz * scale + 0.5
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
        pos -= pos_grid_uint

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        local_feature_0 = 0.0
        local_feature_1 = 0.0

        # Trilinear interpolation of the two features stored at the 8 corners
        # of the enclosing grid cell.
        for idx in ti.static(range(8)):
            w = 1.
            pos_grid_local = uvec3(0)

            for d in ti.static(range(3)):
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= pos[d]

            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
                                        map_size)
            index_table = offset + index * 2
            index_table_int = ti.cast(index_table, ti.int32)
            local_feature_0 += w * table[index_table_int]
            local_feature_1 += w * table[index_table_int + 1]

        xyzs_embedding[i, level * 2] = local_feature_0
        xyzs_embedding[i, level * 2 + 1] = local_feature_1
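

# Half-precision variant: the table holds half2 entries, so each lookup already
# returns both per-level features and the per-level offsets are used directly
# (no *2 stride). Otherwise the computation mirrors hash_encode_kernel above.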
@ti.kernel
def hash_encode_kernel_half2(
        xyzs: ti.template(), table: ti.template(),
        xyzs_embedding: ti.template(), hash_map_indicator: ti.template(),
        hash_map_sizes_field: ti.template(), offsets: ti.template(), B: ti.i32,
        per_level_scale: ti.f16):
    # get hash table embedding
    ti.loop_config(block_dim=32)
    for i, level in ti.ndrange(B, 16):
        xyz = ti.Vector([xyzs[i, 0], xyzs[i, 1], xyzs[i, 2]])

        scale = 16 * ti.exp(level * ti.log(per_level_scale)) - 1.0
        resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1

        offset = offsets[level]

        pos = xyz * scale + 0.5
        pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32)
        pos -= pos_grid_uint

        indicator = hash_map_indicator[level]
        map_size = hash_map_sizes_field[level]

        local_feature = half2(0.0)

        for idx in ti.static(range(8)):
            w = ti.f32(1.0)
            pos_grid_local = uvec3(0)

            for d in ti.static(range(3)):
                if (idx & (1 << d)) == 0:
                    pos_grid_local[d] = pos_grid_uint[d]
                    w *= 1 - pos[d]
                else:
                    pos_grid_local[d] = pos_grid_uint[d] + 1
                    w *= pos[d]

            index = grid_pos2hash_index(indicator, pos_grid_local, resolution,
                                        map_size)
            index_table = offset + index
            index_table_int = ti.cast(index_table, ti.int32)
            local_feature += w * table[index_table_int]

        xyzs_embedding[i, level] = local_feature
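

# HashEncoderTaichi wires the kernels above into a torch.nn.Module: the hash
# table lives in a torch Parameter, is mirrored into Taichi fields for the
# forward/backward kernels, and gradients are copied back to PyTorch through
# a custom torch.autograd.Function.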
class HashEncoderTaichi(torch.nn.Module):

    def __init__(self,
                 b=1.3195079565048218,
                 batch_size=8192,
                 data_type=data_type,
                 half2_opt=False):
        super(HashEncoderTaichi, self).__init__()

        self.per_level_scale = b
        if batch_size < 2048:
            batch_size = 2048
        # per_level_scale = 1.3195079565048218
        print("per_level_scale: ", b)

        self.offsets = ti.field(ti.i32, shape=(16, ))
        self.hash_map_sizes_field = ti.field(ti.uint32, shape=(16, ))
        self.hash_map_indicator = ti.field(ti.i32, shape=(16, ))

        base_res = 16
        max_params = 2**19
        offset_ = 0
        hash_map_sizes = []
        for i in range(16):
            resolution = int(
                np.ceil(base_res * np.exp(i * np.log(self.per_level_scale)) -
                        1.0)) + 1
            # Round each level's parameter count up to a multiple of 8 and cap
            # it at max_params.
            params_in_level = resolution**3
            if params_in_level % 8 != 0:
                params_in_level = int((params_in_level + 8 - 1) / 8) * 8
            params_in_level = min(max_params, params_in_level)

            self.offsets[i] = offset_
            hash_map_sizes.append(params_in_level)
            # Dense (collision-free) indexing when the full grid fits in the level.
            self.hash_map_indicator[
                i] = 1 if resolution**3 <= params_in_level else 0
            offset_ += params_in_level
        print("offset_: ", offset_)
        size = np.uint32(np.array(hash_map_sizes))
        self.hash_map_sizes_field.from_numpy(size)

        # Two features per table entry.
        self.total_hash_size = offset_ * 2
        print("total_hash_size: ", self.total_hash_size)

        self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size,
                                                         dtype=torch_type),
                                             requires_grad=True)
        random_initialize(self.hash_table)

        if half2_opt:
            assert self.total_hash_size % 2 == 0
            self.parameter_fields = half2.field(
                shape=(self.total_hash_size // 2, ), needs_grad=True)
            self.output_fields = half2.field(shape=(batch_size * 1024, 16),
                                             needs_grad=True)

            self.torch2ti = torch2ti_vec
            self.ti2torch = ti2torch_vec
            self.ti2torch_grad = ti2torch_grad_vec
            self.torch2ti_grad = torch2ti_grad_vec

            self._hash_encode_kernel = hash_encode_kernel_half2
        else:
            self.parameter_fields = ti.field(data_type,
                                             shape=(self.total_hash_size, ),
                                             needs_grad=True)
            self.output_fields = ti.field(dtype=data_type,
                                          shape=(batch_size * 1024, 32),
                                          needs_grad=True)

            self.torch2ti = torch2ti
            self.ti2torch = ti2torch
            self.ti2torch_grad = ti2torch_grad
            self.torch2ti_grad = torch2ti_grad

            self._hash_encode_kernel = hash_encode_kernel

        self.input_fields = ti.field(dtype=data_type,
                                     shape=(batch_size * 1024, 3),
                                     needs_grad=True)
        self.output_dim = 32  # output dim: num levels (16) x features per level (2)

        self.register_buffer(
            'hash_grad', torch.zeros(self.total_hash_size, dtype=torch_type))
        self.register_buffer(
            'output_embedding',
            torch.zeros(batch_size * 1024, 32, dtype=torch_type))

        class _module_function(torch.autograd.Function):

            @staticmethod
            @custom_fwd(cast_inputs=torch_type)
            def forward(ctx, input_pos, params):
                output_embedding = self.output_embedding[
                    :input_pos.shape[0]].contiguous()
                # input_fields is always a plain scalar field, so the scalar
                # torch2ti copy is used regardless of half2_opt.
                torch2ti(self.input_fields, input_pos.contiguous())
                self.torch2ti(self.parameter_fields, params.contiguous())

                self._hash_encode_kernel(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator,
                    self.hash_map_sizes_field,
                    self.offsets,
                    input_pos.shape[0],
                    self.per_level_scale,
                )
                self.ti2torch(self.output_fields, output_embedding)

                return output_embedding

            @staticmethod
            @custom_bwd
            def backward(ctx, doutput):
                self.zero_grad()

                self.torch2ti_grad(self.output_fields, doutput.contiguous())
                self._hash_encode_kernel.grad(
                    self.input_fields,
                    self.parameter_fields,
                    self.output_fields,
                    self.hash_map_indicator,
                    self.hash_map_sizes_field,
                    self.offsets,
                    doutput.shape[0],
                    self.per_level_scale,
                )
                self.ti2torch_grad(self.parameter_fields,
                                   self.hash_grad.contiguous())
                # No gradient is propagated to the input positions; only the
                # hash table receives a gradient.
                return None, self.hash_grad

        self._module_function = _module_function

    def zero_grad(self):
        self.parameter_fields.grad.fill(0.)

    def forward(self, positions, bound=1):
        # Map positions from [-bound, bound] into [0, 1] before encoding.
        positions = (positions + bound) / (2 * bound)
        return self._module_function.apply(positions, self.hash_table)
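

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a CUDA build of Taichi and PyTorch and that
# `ti.init` has not been called elsewhere; the batch size, point count, and
# `bound` value below are illustrative only, not the values used by Magic123.
if __name__ == "__main__":
    ti.init(arch=ti.cuda)
    encoder = HashEncoderTaichi(batch_size=2048).cuda()
    xyz = torch.rand(4096, 3, device="cuda") * 2.0 - 1.0  # points in [-1, 1]^3
    feats = encoder(xyz, bound=1)  # (4096, 32) multi-level hash features
    print(feats.shape, feats.dtype)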