first commit
This commit is contained in:
1
gridencoder/__init__.py
Normal file
1
gridencoder/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .grid import GridEncoder
|
||||
40
gridencoder/backend.py
Normal file
40
gridencoder/backend.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
nvcc_flags = [
|
||||
'-O3', '-std=c++14',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
]
|
||||
|
||||
if os.name == "posix":
|
||||
c_flags = ['-O3', '-std=c++14']
|
||||
elif os.name == "nt":
|
||||
c_flags = ['/O2', '/std:c++17']
|
||||
|
||||
# find cl.exe
|
||||
def find_cl_path():
|
||||
import glob
|
||||
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
|
||||
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
||||
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
|
||||
if paths:
|
||||
return paths[0]
|
||||
# If cl.exe is not on path, try to find it.
|
||||
if os.system("where cl.exe >nul 2>nul") != 0:
|
||||
cl_path = find_cl_path()
|
||||
if cl_path is None:
|
||||
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
||||
os.environ["PATH"] += ";" + cl_path
|
||||
|
||||
_backend = load(name='_grid_encoder',
|
||||
extra_cflags=c_flags,
|
||||
extra_cuda_cflags=nvcc_flags,
|
||||
sources=[os.path.join(_src_path, 'src', f) for f in [
|
||||
'gridencoder.cu',
|
||||
'bindings.cpp',
|
||||
]],
|
||||
)
|
||||
|
||||
__all__ = ['_backend']
|
||||
206
gridencoder/grid.py
Normal file
206
gridencoder/grid.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
try:
|
||||
import _gridencoder as _backend
|
||||
except ImportError:
|
||||
from .backend import _backend
|
||||
|
||||
_gridtype_to_id = {
|
||||
'hash': 0,
|
||||
'tiled': 1,
|
||||
}
|
||||
|
||||
_interp_to_id = {
|
||||
'linear': 0,
|
||||
'smoothstep': 1,
|
||||
}
|
||||
|
||||
class _grid_encode(Function):
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0, max_level=None):
|
||||
# inputs: [B, D], float in [0, 1]
|
||||
# embeddings: [sO, C], float
|
||||
# offsets: [L + 1], int
|
||||
# RETURN: [B, F], float
|
||||
|
||||
inputs = inputs.contiguous()
|
||||
|
||||
B, D = inputs.shape # batch size, coord dim
|
||||
L = offsets.shape[0] - 1 # level
|
||||
C = embeddings.shape[1] # embedding dim for each level
|
||||
S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
|
||||
H = base_resolution # base resolution
|
||||
|
||||
max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1)
|
||||
|
||||
# manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
|
||||
# if C % 2 != 0, force float, since half for atomicAdd is very slow.
|
||||
if torch.is_autocast_enabled() and C % 2 == 0:
|
||||
embeddings = embeddings.to(torch.half)
|
||||
|
||||
# L first, optimize cache for cuda kernel, but needs an extra permute later
|
||||
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
|
||||
|
||||
# zero init if we only calculate partial levels
|
||||
if max_level < L: outputs.zero_()
|
||||
|
||||
if calc_grad_inputs:
|
||||
dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
|
||||
if max_level < L: dy_dx.zero_()
|
||||
else:
|
||||
dy_dx = None
|
||||
|
||||
_backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interpolation)
|
||||
|
||||
# permute back to [B, L * C]
|
||||
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
|
||||
|
||||
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
|
||||
ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level]
|
||||
ctx.align_corners = align_corners
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
#@once_differentiable
|
||||
@custom_bwd
|
||||
def backward(ctx, grad):
|
||||
|
||||
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
|
||||
B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims
|
||||
align_corners = ctx.align_corners
|
||||
|
||||
# grad: [B, L * C] --> [L, B, C]
|
||||
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
|
||||
|
||||
grad_embeddings = torch.zeros_like(embeddings)
|
||||
|
||||
if dy_dx is not None:
|
||||
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
|
||||
else:
|
||||
grad_inputs = None
|
||||
|
||||
_backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)
|
||||
|
||||
if dy_dx is not None:
|
||||
grad_inputs = grad_inputs.to(inputs.dtype)
|
||||
|
||||
return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
|
||||
grid_encode = _grid_encode.apply
|
||||
|
||||
|
||||
class GridEncoder(nn.Module):
|
||||
def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
|
||||
super().__init__()
|
||||
|
||||
# the finest resolution desired at the last level, if provided, overridee per_level_scale
|
||||
if desired_resolution is not None:
|
||||
per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
|
||||
|
||||
self.input_dim = input_dim # coord dims, 2 or 3
|
||||
self.num_levels = num_levels # num levels, each level multiply resolution by 2
|
||||
self.level_dim = level_dim # encode channels per level
|
||||
self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
|
||||
self.log2_hashmap_size = log2_hashmap_size
|
||||
self.base_resolution = base_resolution
|
||||
self.output_dim = num_levels * level_dim
|
||||
self.gridtype = gridtype
|
||||
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
|
||||
self.interpolation = interpolation
|
||||
self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep"
|
||||
self.align_corners = align_corners
|
||||
|
||||
# allocate parameters
|
||||
offsets = []
|
||||
offset = 0
|
||||
self.max_params = 2 ** log2_hashmap_size
|
||||
for i in range(num_levels):
|
||||
resolution = int(np.ceil(base_resolution * per_level_scale ** i))
|
||||
params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
|
||||
params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
|
||||
offsets.append(offset)
|
||||
offset += params_in_level
|
||||
offsets.append(offset)
|
||||
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
|
||||
self.register_buffer('offsets', offsets)
|
||||
|
||||
self.n_params = offsets[-1] * level_dim
|
||||
|
||||
# parameters
|
||||
self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
std = 1e-4
|
||||
self.embeddings.data.uniform_(-std, std)
|
||||
|
||||
def __repr__(self):
|
||||
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"
|
||||
|
||||
def forward(self, inputs, bound=1, max_level=None):
|
||||
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
|
||||
# max_level: only calculate first max_level levels (None will use all levels)
|
||||
# return: [..., num_levels * level_dim]
|
||||
|
||||
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
|
||||
|
||||
#print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
|
||||
|
||||
prefix_shape = list(inputs.shape[:-1])
|
||||
inputs = inputs.view(-1, self.input_dim)
|
||||
|
||||
outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id, max_level)
|
||||
outputs = outputs.view(prefix_shape + [self.output_dim])
|
||||
|
||||
#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
|
||||
|
||||
return outputs
|
||||
|
||||
# always run in float precision!
|
||||
@torch.cuda.amp.autocast(enabled=False)
|
||||
def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
|
||||
# inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.
|
||||
|
||||
D = self.input_dim
|
||||
C = self.embeddings.shape[1] # embedding dim for each level
|
||||
L = self.offsets.shape[0] - 1 # level
|
||||
S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
|
||||
H = self.base_resolution # base resolution
|
||||
|
||||
if inputs is None:
|
||||
# randomized in [0, 1]
|
||||
inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
|
||||
else:
|
||||
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
|
||||
inputs = inputs.view(-1, self.input_dim)
|
||||
B = inputs.shape[0]
|
||||
|
||||
if self.embeddings.grad is None:
|
||||
raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
|
||||
|
||||
_backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners)
|
||||
|
||||
@torch.cuda.amp.autocast(enabled=False)
|
||||
def grad_weight_decay(self, weight=0.1):
|
||||
# level-wise meaned weight decay (ref: zip-nerf)
|
||||
|
||||
B = self.embeddings.shape[0] # size of embedding
|
||||
C = self.embeddings.shape[1] # embedding dim for each level
|
||||
L = self.offsets.shape[0] - 1 # level
|
||||
|
||||
if self.embeddings.grad is None:
|
||||
raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
|
||||
|
||||
_backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
|
||||
51
gridencoder/setup.py
Normal file
51
gridencoder/setup.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
_src_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
nvcc_flags = [
|
||||
'-O3', '-std=c++14',
|
||||
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
|
||||
]
|
||||
|
||||
if os.name == "posix":
|
||||
c_flags = ['-O3', '-std=c++14']
|
||||
elif os.name == "nt":
|
||||
c_flags = ['/O2', '/std:c++17']
|
||||
|
||||
# find cl.exe
|
||||
def find_cl_path():
|
||||
import glob
|
||||
for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
|
||||
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
|
||||
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
|
||||
if paths:
|
||||
return paths[0]
|
||||
|
||||
# If cl.exe is not on path, try to find it.
|
||||
if os.system("where cl.exe >nul 2>nul") != 0:
|
||||
cl_path = find_cl_path()
|
||||
if cl_path is None:
|
||||
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
||||
os.environ["PATH"] += ";" + cl_path
|
||||
|
||||
setup(
|
||||
name='gridencoder', # package name, import this to use python API
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name='_gridencoder', # extension name, import this to use CUDA API
|
||||
sources=[os.path.join(_src_path, 'src', f) for f in [
|
||||
'gridencoder.cu',
|
||||
'bindings.cpp',
|
||||
]],
|
||||
extra_compile_args={
|
||||
'cxx': c_flags,
|
||||
'nvcc': nvcc_flags,
|
||||
}
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension,
|
||||
}
|
||||
)
|
||||
10
gridencoder/src/bindings.cpp
Normal file
10
gridencoder/src/bindings.cpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "gridencoder.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
|
||||
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
|
||||
m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)");
|
||||
m.def("grad_weight_decay", &grad_weight_decay, "grad_weight_decay (CUDA)");
|
||||
}
|
||||
713
gridencoder/src/gridencoder.cu
Normal file
713
gridencoder/src/gridencoder.cu
Normal file
@@ -0,0 +1,713 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
||||
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
|
||||
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
||||
|
||||
|
||||
// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
|
||||
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
|
||||
// requires CUDA >= 10 and ARCH >= 70
|
||||
// this is very slow compared to float or __half2, never use it.
|
||||
//return atomicAdd(reinterpret_cast<__half*>(address), val);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ inline T div_round_up(T val, T divisor) {
|
||||
return (val + divisor - 1) / divisor;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T smoothstep(T val) {
|
||||
return val*val*(3.0f - 2.0f * val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T smoothstep_derivative(T val) {
|
||||
return 6*val*(1.0f - val);
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t D>
|
||||
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
|
||||
|
||||
// coherent type of hashing
|
||||
constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };
|
||||
|
||||
uint32_t result = 0;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < D; ++i) {
|
||||
result ^= pos_grid[i] * primes[i];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t D, uint32_t C>
|
||||
__device__ uint32_t get_grid_index(const uint32_t gridtype, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
|
||||
uint32_t stride = 1;
|
||||
uint32_t index = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
|
||||
index += pos_grid[d] * stride;
|
||||
stride *= resolution;
|
||||
}
|
||||
|
||||
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
|
||||
// gridtype: 0 == hash, 1 == tiled
|
||||
if (gridtype == 0 && stride > hashmap_size) {
|
||||
index = fast_hash<D>(pos_grid);
|
||||
}
|
||||
|
||||
return (index % hashmap_size) * C + ch;
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C>
|
||||
__global__ void kernel_grid(
|
||||
const float * __restrict__ inputs,
|
||||
const scalar_t * __restrict__ grid,
|
||||
const int * __restrict__ offsets,
|
||||
scalar_t * __restrict__ outputs,
|
||||
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
||||
scalar_t * __restrict__ dy_dx,
|
||||
const uint32_t gridtype,
|
||||
const bool align_corners,
|
||||
const uint32_t interp
|
||||
) {
|
||||
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (b >= B) return;
|
||||
|
||||
const uint32_t level = blockIdx.y;
|
||||
|
||||
// locate
|
||||
grid += (uint32_t)offsets[level] * C;
|
||||
inputs += b * D;
|
||||
outputs += level * B * C + b * C;
|
||||
|
||||
// check input range (should be in [0, 1])
|
||||
bool flag_oob = false;
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if (inputs[d] < 0 || inputs[d] > 1) {
|
||||
flag_oob = true;
|
||||
}
|
||||
}
|
||||
// if input out of bound, just set output to 0
|
||||
if (flag_oob) {
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
outputs[ch] = 0;
|
||||
}
|
||||
if (dy_dx) {
|
||||
dy_dx += b * D * L * C + level * D * C; // B L D C
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
dy_dx[d * C + ch] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
|
||||
|
||||
// calculate coordinate (always use float for precision!)
|
||||
float pos[D];
|
||||
float pos_deriv[D];
|
||||
uint32_t pos_grid[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
|
||||
// align_corners
|
||||
if (align_corners) {
|
||||
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
|
||||
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
|
||||
} else {
|
||||
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
|
||||
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
|
||||
}
|
||||
pos[d] -= (float)pos_grid[d];
|
||||
|
||||
// smoothstep instead of linear
|
||||
if (interp == 1) {
|
||||
pos_deriv[d] = smoothstep_derivative(pos[d]);
|
||||
pos[d] = smoothstep(pos[d]);
|
||||
} else {
|
||||
pos_deriv[d] = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// verification of alignment
|
||||
// if (level == L - 1 && b < 4) {
|
||||
// printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
||||
// }
|
||||
|
||||
// interpolate
|
||||
scalar_t results[C] = {0}; // temp results in register
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
||||
float w = 1;
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if ((idx & (1 << d)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
// writing to register (fast)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
results[ch] += w * grid[index + ch];
|
||||
}
|
||||
|
||||
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
|
||||
}
|
||||
|
||||
// writing to global memory (slow)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
outputs[ch] = results[ch];
|
||||
}
|
||||
|
||||
// prepare dy_dx
|
||||
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
||||
if (dy_dx) {
|
||||
|
||||
dy_dx += b * D * L * C + level * D * C; // B L D C
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t gd = 0; gd < D; gd++) {
|
||||
|
||||
scalar_t results_grad[C] = {0};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
||||
float w = (float)(align_corners ? resolution - 1 : resolution);
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
||||
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
||||
|
||||
if ((idx & (1 << nd)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
|
||||
}
|
||||
}
|
||||
|
||||
pos_grid_local[gd] = pos_grid[gd];
|
||||
uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
|
||||
pos_grid_local[gd] = min(pos_grid[gd] + 1, resolution - 1);
|
||||
uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
dy_dx[gd * C + ch] = results_grad[ch];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
|
||||
__global__ void kernel_grid_backward(
|
||||
const scalar_t * __restrict__ grad,
|
||||
const float * __restrict__ inputs,
|
||||
const scalar_t * __restrict__ grid,
|
||||
const int * __restrict__ offsets,
|
||||
scalar_t * __restrict__ grad_grid,
|
||||
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
||||
const uint32_t gridtype,
|
||||
const bool align_corners,
|
||||
const uint32_t interp
|
||||
) {
|
||||
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
|
||||
if (b >= B) return;
|
||||
|
||||
const uint32_t level = blockIdx.y;
|
||||
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
|
||||
|
||||
// locate
|
||||
grad_grid += offsets[level] * C;
|
||||
inputs += b * D;
|
||||
grad += level * B * C + b * C + ch; // L, B, C
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
|
||||
|
||||
// check input range (should be in [0, 1])
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if (inputs[d] < 0 || inputs[d] > 1) {
|
||||
return; // grad is init as 0, so we simply return.
|
||||
}
|
||||
}
|
||||
|
||||
// calculate coordinate
|
||||
float pos[D];
|
||||
uint32_t pos_grid[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
// align_corners
|
||||
if (align_corners) {
|
||||
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
|
||||
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
|
||||
} else {
|
||||
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
|
||||
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
|
||||
}
|
||||
pos[d] -= (float)pos_grid[d];
|
||||
// smoothstep instead of linear
|
||||
if (interp == 1) {
|
||||
pos[d] = smoothstep(pos[d]);
|
||||
}
|
||||
}
|
||||
|
||||
scalar_t grad_cur[N_C] = {0}; // fetch to register
|
||||
#pragma unroll
|
||||
for (uint32_t c = 0; c < N_C; c++) {
|
||||
grad_cur[c] = grad[c];
|
||||
}
|
||||
|
||||
// interpolate
|
||||
#pragma unroll
|
||||
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
||||
float w = 1;
|
||||
uint32_t pos_grid_local[D];
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if ((idx & (1 << d)) == 0) {
|
||||
w *= 1 - pos[d];
|
||||
pos_grid_local[d] = pos_grid[d];
|
||||
} else {
|
||||
w *= pos[d];
|
||||
pos_grid_local[d] = min(pos_grid[d] + 1, resolution - 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t index = get_grid_index<D, C>(gridtype, ch, hashmap_size, resolution, pos_grid_local);
|
||||
|
||||
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
|
||||
// TODO: use float which is better than __half, if N_C % 2 != 0
|
||||
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t c = 0; c < N_C; c += 2) {
|
||||
// process two __half at once (by interpreting as a __half2)
|
||||
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
|
||||
atomicAdd((__half2*)&grad_grid[index + c], v);
|
||||
}
|
||||
// float, or __half when N_C % 2 != 0 (which means C == 1)
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t c = 0; c < N_C; c++) {
|
||||
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C>
|
||||
__global__ void kernel_input_backward(
|
||||
const scalar_t * __restrict__ grad,
|
||||
const scalar_t * __restrict__ dy_dx,
|
||||
scalar_t * __restrict__ grad_inputs,
|
||||
uint32_t B, uint32_t L
|
||||
) {
|
||||
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (t >= B * D) return;
|
||||
|
||||
const uint32_t b = t / D;
|
||||
const uint32_t d = t - b * D;
|
||||
|
||||
dy_dx += b * L * D * C;
|
||||
|
||||
scalar_t result = 0;
|
||||
|
||||
# pragma unroll
|
||||
for (int l = 0; l < L; l++) {
|
||||
# pragma unroll
|
||||
for (int ch = 0; ch < C; ch++) {
|
||||
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
|
||||
}
|
||||
}
|
||||
|
||||
grad_inputs[t] = result;
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D>
|
||||
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
||||
static constexpr uint32_t N_THREAD = 512;
|
||||
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), max_level, 1 };
|
||||
switch (C) {
|
||||
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 16: kernel_grid<scalar_t, D, 16><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 32: kernel_grid<scalar_t, D, 32><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
|
||||
}
|
||||
}
|
||||
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
|
||||
// H: base resolution
|
||||
// dy_dx: [B, L * D * C]
|
||||
template <typename scalar_t>
|
||||
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
||||
switch (D) {
|
||||
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, max_level, S, H, dy_dx, gridtype, align_corners, interp); break;
|
||||
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, uint32_t D>
|
||||
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
||||
static constexpr uint32_t N_THREAD = 256;
|
||||
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
|
||||
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), max_level, 1 };
|
||||
switch (C) {
|
||||
case 1:
|
||||
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
||||
if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
||||
break;
|
||||
case 2:
|
||||
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
||||
if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
||||
break;
|
||||
case 4:
|
||||
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
||||
if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
||||
break;
|
||||
case 8:
|
||||
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
||||
if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
||||
break;
|
||||
case 16:
|
||||
kernel_grid_backward<scalar_t, D, 16, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
||||
if (dy_dx) kernel_input_backward<scalar_t, D, 16><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
||||
break;
|
||||
case 32:
|
||||
kernel_grid_backward<scalar_t, D, 32, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
|
||||
if (dy_dx) kernel_input_backward<scalar_t, D, 32><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
|
||||
break;
|
||||
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// grad: [L, B, C], float
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// grad_embeddings: [sO, C]
|
||||
// H: base resolution
|
||||
template <typename scalar_t>
|
||||
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
||||
switch (D) {
|
||||
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
||||
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
||||
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
||||
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, max_level, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
|
||||
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4 or 5."};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
||||
CHECK_CUDA(inputs);
|
||||
CHECK_CUDA(embeddings);
|
||||
CHECK_CUDA(offsets);
|
||||
CHECK_CUDA(outputs);
|
||||
// CHECK_CUDA(dy_dx);
|
||||
|
||||
CHECK_CONTIGUOUS(inputs);
|
||||
CHECK_CONTIGUOUS(embeddings);
|
||||
CHECK_CONTIGUOUS(offsets);
|
||||
CHECK_CONTIGUOUS(outputs);
|
||||
// CHECK_CONTIGUOUS(dy_dx);
|
||||
|
||||
CHECK_IS_FLOATING(inputs);
|
||||
CHECK_IS_FLOATING(embeddings);
|
||||
CHECK_IS_INT(offsets);
|
||||
CHECK_IS_FLOATING(outputs);
|
||||
// CHECK_IS_FLOATING(dy_dx);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
embeddings.scalar_type(), "grid_encode_forward", ([&] {
|
||||
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
|
||||
}));
|
||||
}
|
||||
|
||||
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
|
||||
CHECK_CUDA(grad);
|
||||
CHECK_CUDA(inputs);
|
||||
CHECK_CUDA(embeddings);
|
||||
CHECK_CUDA(offsets);
|
||||
CHECK_CUDA(grad_embeddings);
|
||||
// CHECK_CUDA(dy_dx);
|
||||
// CHECK_CUDA(grad_inputs);
|
||||
|
||||
CHECK_CONTIGUOUS(grad);
|
||||
CHECK_CONTIGUOUS(inputs);
|
||||
CHECK_CONTIGUOUS(embeddings);
|
||||
CHECK_CONTIGUOUS(offsets);
|
||||
CHECK_CONTIGUOUS(grad_embeddings);
|
||||
// CHECK_CONTIGUOUS(dy_dx);
|
||||
// CHECK_CONTIGUOUS(grad_inputs);
|
||||
|
||||
CHECK_IS_FLOATING(grad);
|
||||
CHECK_IS_FLOATING(inputs);
|
||||
CHECK_IS_FLOATING(embeddings);
|
||||
CHECK_IS_INT(offsets);
|
||||
CHECK_IS_FLOATING(grad_embeddings);
|
||||
// CHECK_IS_FLOATING(dy_dx);
|
||||
// CHECK_IS_FLOATING(grad_inputs);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
grad.scalar_type(), "grid_encode_backward", ([&] {
|
||||
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, max_level, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
|
||||
}));
|
||||
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D, uint32_t C>
|
||||
__global__ void kernel_grad_tv(
|
||||
const scalar_t * __restrict__ inputs,
|
||||
const scalar_t * __restrict__ grid,
|
||||
scalar_t * __restrict__ grad,
|
||||
const int * __restrict__ offsets,
|
||||
const float weight,
|
||||
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
||||
const uint32_t gridtype,
|
||||
const bool align_corners
|
||||
) {
|
||||
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (b >= B) return;
|
||||
|
||||
const uint32_t level = blockIdx.y;
|
||||
|
||||
// locate
|
||||
inputs += b * D;
|
||||
grid += (uint32_t)offsets[level] * C;
|
||||
grad += (uint32_t)offsets[level] * C;
|
||||
|
||||
// check input range (should be in [0, 1])
|
||||
bool flag_oob = false;
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
if (inputs[d] < 0 || inputs[d] > 1) {
|
||||
flag_oob = true;
|
||||
}
|
||||
}
|
||||
|
||||
// if input out of bound, do nothing
|
||||
if (flag_oob) return;
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
const uint32_t resolution = (uint32_t)ceil(exp2f(level * S) * H);
|
||||
|
||||
// calculate coordinate
|
||||
float pos[D];
|
||||
uint32_t pos_grid[D]; // [0, resolution]
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
// align_corners
|
||||
if (align_corners) {
|
||||
pos[d] = inputs[d] * (float)(resolution - 1); // [0, resolution - 1]
|
||||
pos_grid[d] = min((uint32_t)floorf(pos[d]), resolution - 2); // left-top corner, [0, resolution - 2]
|
||||
} else {
|
||||
pos[d] = fminf(fmaxf(inputs[d] * (float)resolution - 0.5f, 0.0f), (float)(resolution - 1)); // [-0.5, resolution-0.5] --> [0, resolution - 1]
|
||||
pos_grid[d] = (uint32_t)floorf(pos[d]); // left-top corner, [0, resolution - 1]
|
||||
}
|
||||
}
|
||||
|
||||
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
|
||||
|
||||
// total variation on pos_grid
|
||||
scalar_t results[C] = {0}; // temp results in register
|
||||
scalar_t idelta[C] = {0};
|
||||
|
||||
uint32_t index = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);
|
||||
|
||||
scalar_t w = weight / (2 * D);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t d = 0; d < D; d++) {
|
||||
|
||||
uint32_t cur_d = pos_grid[d];
|
||||
scalar_t grad_val;
|
||||
|
||||
// right side
|
||||
if (cur_d < resolution) {
|
||||
pos_grid[d] = cur_d + 1;
|
||||
uint32_t index_right = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
grad_val = (grid[index + ch] - grid[index_right + ch]);
|
||||
results[ch] += grad_val;
|
||||
idelta[ch] += grad_val * grad_val;
|
||||
}
|
||||
}
|
||||
|
||||
// left side
|
||||
if (cur_d > 0) {
|
||||
pos_grid[d] = cur_d - 1;
|
||||
uint32_t index_left = get_grid_index<D, C>(gridtype, 0, hashmap_size, resolution, pos_grid);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
grad_val = (grid[index + ch] - grid[index_left + ch]);
|
||||
results[ch] += grad_val;
|
||||
idelta[ch] += grad_val * grad_val;
|
||||
}
|
||||
}
|
||||
|
||||
// reset
|
||||
pos_grid[d] = cur_d;
|
||||
}
|
||||
|
||||
// writing to global memory (slow)
|
||||
#pragma unroll
|
||||
for (uint32_t ch = 0; ch < C; ch++) {
|
||||
// index may collide, so use atomic!
|
||||
atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t, uint32_t D>
|
||||
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
|
||||
static constexpr uint32_t N_THREAD = 512;
|
||||
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
|
||||
switch (C) {
|
||||
case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
||||
case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
||||
case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
||||
case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
||||
case 16: kernel_grad_tv<scalar_t, D, 16><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
||||
case 32: kernel_grad_tv<scalar_t, D, 32><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
|
||||
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, 8, 16 or 32."};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
|
||||
switch (D) {
|
||||
case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
||||
case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
||||
case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
||||
case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
|
||||
default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
embeddings.scalar_type(), "grad_total_variation", ([&] {
|
||||
grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
|
||||
}));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void kernel_grad_wd(
|
||||
const scalar_t * __restrict__ grid,
|
||||
scalar_t * __restrict__ grad,
|
||||
const int * __restrict__ offsets,
|
||||
const float weight,
|
||||
const uint32_t B, const uint32_t L, const uint32_t C
|
||||
) {
|
||||
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (b >= B * C) return;
|
||||
|
||||
// locate
|
||||
grid += b;
|
||||
grad += b;
|
||||
|
||||
// decide in which level is this thread...
|
||||
uint32_t level = 0;
|
||||
const uint32_t n = b / C;
|
||||
// binary search b in offsets
|
||||
uint32_t l = 0, r = L;
|
||||
while (l < r) {
|
||||
uint32_t m = (l + r) / 2;
|
||||
if (offsets[m] <= n) {
|
||||
level = m;
|
||||
l = m + 1;
|
||||
} else {
|
||||
r = m;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
||||
grad[0] += 2 * weight * grid[0] / hashmap_size;
|
||||
}
|
||||
|
||||
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L) {
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
embeddings.scalar_type(), "grad_weight_decay", ([&] {
|
||||
static constexpr uint32_t N_THREAD = 1024;
|
||||
const dim3 blocks_hashgrid = { div_round_up(B * C, N_THREAD), 1, 1 };
|
||||
kernel_grad_wd<scalar_t><<<blocks_hashgrid, N_THREAD>>>(embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, L, C);
|
||||
}));
|
||||
}
|
||||
18
gridencoder/src/gridencoder.h
Normal file
18
gridencoder/src/gridencoder.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef _HASH_ENCODE_H
|
||||
#define _HASH_ENCODE_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
// inputs: [B, D], float, in [0, 1]
|
||||
// embeddings: [sO, C], float
|
||||
// offsets: [L + 1], uint32_t
|
||||
// outputs: [B, L * C], float
|
||||
// H: base resolution
|
||||
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
|
||||
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const uint32_t max_level, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
|
||||
|
||||
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners);
|
||||
void grad_weight_decay(const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L);
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user